1
use anyhow::Result;
2
use std::sync::{Arc, Mutex};
3
use wasmtime::{Caller, Func, Memory, Store, ToWasmtimeResult};
4

            
5
pub struct DataStore<Data> {
6
    pub data: Mutex<Data>,
7
    pub memory: Mutex<Option<Memory>>,
8
}
9

            
10
// Helper function for builder operations, allowing any type
11
fn with_builder<F, Builder>(caller: &Caller<'_, Arc<DataStore<Builder>>>, action: F) -> Result<()>
12
where
13
    F: FnOnce(&mut Builder) -> Result<()>,
14
{
15
    let data = caller.data();
16

            
17
    // Lock the builder
18
    let mut builder = match data.data.lock() {
19
        Ok(builder) => builder,
20
        Err(_) => anyhow::bail!(t!("Can't lock the builder")),
21
    };
22

            
23
    // Execute the action on the builder
24
    action(&mut builder)
25
}
26

            
27
// Helper function for setting a str from memory with two numbers
28
fn with_data_set_str<F, Data>(
29
    caller: &Caller<'_, Arc<DataStore<Data>>>,
30
    loc: i32,
31
    len: i32,
32
    set_field: F,
33
) -> Result<()>
34
where
35
    F: FnOnce(&mut Data, String),
36
{
37
    with_builder(caller, |builder| {
38
        let data = caller.data();
39

            
40
        // Lock the memory
41
        let m = match data.memory.lock() {
42
            Ok(m) => m,
43
            Err(_) => anyhow::bail!(t!("Can't lock the memory")),
44
        };
45

            
46
        // Access memory and set the field if data is found
47
        if let Some(memory) = &*m
48
            && let Some(bytes) = memory
49
                .data(caller)
50
                .get(loc as usize..loc as usize + len as usize)
51
        {
52
            let string_value = String::from_utf8_lossy(bytes).to_string();
53
            set_field(builder, string_value);
54
        }
55

            
56
        Ok(())
57
    })
58
}
59

            
60
// Wrapper for WASM numbers
61
pub fn wrap_wasm_i64<Data, Setter>(store: &mut Store<Arc<DataStore<Data>>>, setter: Setter) -> Func
62
where
63
    Data: Send + 'static,
64
    Setter: Fn(&mut Data, i64) + Send + Sync + 'static,
65
{
66
    Func::wrap(
67
        store,
68
        move |caller: Caller<'_, Arc<DataStore<Data>>>, value: i64| {
69
            with_builder(&caller, |builder| {
70
                setter(builder, value);
71
                Ok(())
72
            })
73
            .to_wasmtime_result()
74
        },
75
    )
76
}
77

            
78
// Generic wrapper for creating `Func` to set string fields in the builder
79
pub fn wrap_wasm_str<Data, Setter>(store: &mut Store<Arc<DataStore<Data>>>, setter: Setter) -> Func
80
where
81
    Data: Send + 'static,
82
    Setter: Fn(&mut Data, String) + Send + Sync + 'static,
83
{
84
    Func::wrap(
85
        store,
86
        move |caller: Caller<'_, Arc<DataStore<Data>>>, loc: i32, len: i32| {
87
            with_data_set_str(&caller, loc, len, |builder, value| {
88
                setter(builder, value);
89
            })
90
            .to_wasmtime_result()
91
        },
92
    )
93
}
94

            
95
// Wrapper for WASM i64 getter
96
pub fn wrap_wasm_i64_get<Data, Getter>(
97
    store: &mut Store<Arc<DataStore<Data>>>,
98
    getter: Getter,
99
) -> Func
100
where
101
    Data: Send + 'static,
102
    Getter: Fn(&Data) -> i64 + Send + Sync + 'static,
103
{
104
    Func::wrap(
105
        store,
106
        move |caller: Caller<'_, Arc<DataStore<Data>>>| -> i64 {
107
            let data = caller.data();
108
            let builder = data.data.lock().expect("Failed to lock data");
109

            
110
            // Call the getter to retrieve the i64 value
111
            getter(&builder)
112
        },
113
    )
114
}
115

            
116
// Wrapper for WASM string getter that writes the string to memory and returns its actual length
117
pub fn wrap_wasm_str_get<Data, Getter>(
118
    store: &mut Store<Arc<DataStore<Data>>>,
119
    getter: Getter,
120
) -> Func
121
where
122
    Data: Send + 'static,
123
    Getter: Fn(&Data) -> String + Send + Sync + 'static,
124
{
125
    Func::wrap(
126
        store,
127
        move |mut caller: Caller<'_, Arc<DataStore<Data>>>, loc: i32, max_len: i32| -> i32 {
128
            // Clone the Arc to avoid holding an immutable borrow on `caller`
129
            let data = Arc::clone(caller.data());
130

            
131
            // Extract the value from the builder in a separate scope
132
            let value = {
133
                let builder = data.data.lock().expect("Failed to lock data");
134
                getter(&builder)
135
            }; // `builder` is dropped here, ending the immutable borrow
136

            
137
            let bytes = value.as_bytes();
138

            
139
            // Lock the memory
140
            let m = data.memory.lock().expect("Failed to lock memory");
141

            
142
            if let Some(memory) = &*m {
143
                // Obtain a mutable context from the caller
144
                let mem = memory.data_mut(&mut caller);
145

            
146
                // Calculate the length to copy, limited by `max_len`
147
                let len = bytes.len().min(max_len as usize);
148

            
149
                // Ensure the memory bounds are valid
150
                if (loc as usize + len) <= mem.len() {
151
                    // Copy bytes to the specified location in WASM memory
152
                    mem[loc as usize..loc as usize + len].copy_from_slice(&bytes[..len]);
153

            
154
                    // Return the actual length of the copied string
155
                    return len as i32;
156
                }
157
                // Memory bounds are invalid; indicate failure
158
                return -1;
159
            }
160

            
161
            // Return -1 to indicate failure if memory is unavailable
162
            -1
163
        },
164
    )
165
}
166

            
167
56
pub fn wrap_wasm_str_transform<Data, Transform>(
168
56
    store: &mut Store<Arc<DataStore<Data>>>,
169
56
    transform: Transform,
170
56
) -> Func
171
56
where
172
56
    Data: Send + 'static,
173
56
    Transform: Fn(&Data, String) -> Result<String, anyhow::Error> + Send + Sync + 'static,
174
{
175
56
    Func::wrap(
176
56
        store,
177
        move |mut caller: Caller<'_, Arc<DataStore<Data>>>,
178
              tag_ptr: i32,
179
              tag_len: i32,
180
              buf_ptr: i32,
181
              buf_len: i32|
182
28
              -> i32 {
183
28
            let data = Arc::clone(caller.data());
184

            
185
            // Define the inner logic
186
28
            let result = (|| -> Result<i32, anyhow::Error> {
187
                // Lock the WASM memory
188
28
                let memory = data
189
28
                    .memory
190
28
                    .lock()
191
28
                    .map_err(|_| anyhow::anyhow!("Failed to lock memory"))?;
192
28
                let memory = memory
193
28
                    .as_ref()
194
28
                    .ok_or_else(|| anyhow::anyhow!("Memory not initialized"))?;
195
28
                let mem = memory.data(&caller);
196

            
197
                // Validate memory bounds for the input key
198
28
                if (tag_ptr as usize + tag_len as usize) > mem.len() {
199
                    return Err(anyhow::anyhow!("Input memory bounds error"));
200
28
                }
201

            
202
                // Extract the key string from memory
203
28
                let key = String::from_utf8(
204
28
                    mem[tag_ptr as usize..tag_ptr as usize + tag_len as usize].to_vec(),
205
                )
206
28
                .map_err(|_| anyhow::anyhow!("Failed to decode input string"))?;
207

            
208
                // Lock the data and call the transform function
209
28
                let data_lock = data
210
28
                    .data
211
28
                    .lock()
212
28
                    .map_err(|_| anyhow::anyhow!("Failed to lock data"))?;
213
28
                let output = transform(&*data_lock, key)?;
214

            
215
                // Get output bytes
216
28
                let output_bytes = output.as_bytes();
217
28
                let output_len = output_bytes.len();
218

            
219
                // Validate that the output fits within the provided buffer length
220
28
                if output_len > buf_len as usize {
221
                    return Err(anyhow::anyhow!("Output exceeds buffer size"));
222
28
                }
223

            
224
                // Validate memory bounds for the output buffer
225
28
                if (buf_ptr as usize + output_len) > mem.len() {
226
                    return Err(anyhow::anyhow!("Output memory bounds error"));
227
28
                }
228

            
229
                // Write the result string into the output buffer
230
28
                memory.data_mut(&mut caller)[buf_ptr as usize..buf_ptr as usize + output_len]
231
28
                    .copy_from_slice(output_bytes);
232

            
233
28
                Ok(output_len as i32)
234
            })();
235

            
236
            // Handle the result or return an error code
237
28
            match result {
238
28
                Ok(len) => len,
239
                Err(e) => {
240
                    tracing::error!("Error in wrap_wasm_str_transform: {e:?}");
241
                    -1 // Indicate failure
242
                }
243
            }
244
28
        },
245
    )
246
56
}
247

            
248
56
pub fn wrap_wasm_str_list_get<Data, Getter>(
249
56
    store: &mut Store<Arc<DataStore<Data>>>,
250
56
    getter: Getter,
251
56
) -> Func
252
56
where
253
56
    Data: Send + 'static,
254
56
    Getter: Fn(&Data) -> Result<Vec<String>, anyhow::Error> + Send + Sync + 'static,
255
{
256
56
    Func::wrap(
257
56
        store,
258
56
        move |mut caller: Caller<'_, Arc<DataStore<Data>>>, loc: i32, max_len: i32| -> i32 {
259
56
            let data = Arc::clone(caller.data());
260

            
261
            // Try to execute the getter and handle potential errors
262
56
            let result = (|| -> Result<i32> {
263
56
                let builder = data
264
56
                    .data
265
56
                    .lock()
266
56
                    .map_err(|_| anyhow::anyhow!("Failed to lock data"))?;
267
56
                let list = getter(&builder)?; // Call getter, which may return an error
268
56
                let mut bytes = Vec::new();
269

            
270
84
                for s in list {
271
84
                    bytes.extend_from_slice(s.as_bytes());
272
84
                    bytes.push(0); // Null terminator
273
84
                }
274

            
275
56
                let len = bytes.len().min(max_len as usize);
276

            
277
56
                let memory = data
278
56
                    .memory
279
56
                    .lock()
280
56
                    .map_err(|_| anyhow::anyhow!("Failed to lock memory"))?;
281
56
                if let Some(memory) = &*memory {
282
56
                    let mem = memory.data_mut(&mut caller);
283

            
284
56
                    if (loc as usize + len) <= mem.len() {
285
56
                        mem[loc as usize..loc as usize + len].copy_from_slice(&bytes[..len]);
286
56
                        return Ok(len as i32);
287
                    }
288
                    return Err(anyhow::anyhow!("Memory bounds error"));
289
                }
290

            
291
                Err(anyhow::anyhow!("Memory not available"))
292
            })();
293

            
294
            // Return the result or indicate failure (-1)
295
56
            match result {
296
56
                Ok(len) => len,
297
                Err(e) => {
298
                    tracing::error!("Error in wrap_wasm_str_list_get: {e:?}");
299
                    -1
300
                }
301
            }
302
56
        },
303
    )
304
56
}
305

            
306
56
pub fn wrap_wasm_str_map<Data>(
307
56
    store: &mut Store<Arc<DataStore<Data>>>,
308
56
    updater: impl Fn(&mut Data, String, String) -> Result<(), anyhow::Error> + Send + Sync + 'static,
309
56
) -> Func
310
56
where
311
56
    Data: Send + 'static,
312
{
313
56
    Func::wrap(
314
56
        store,
315
        move |caller: Caller<'_, Arc<DataStore<Data>>>,
316
              key_loc: i32,
317
              key_len: i32,
318
              value_loc: i32,
319
              value_len: i32|
320
84
              -> i32 {
321
84
            let data = caller.data();
322

            
323
84
            let result = (|| -> Result<()> {
324
84
                let memory = data
325
84
                    .memory
326
84
                    .lock()
327
84
                    .map_err(|_| anyhow::anyhow!("Failed to lock memory"))?;
328
84
                if let Some(memory) = &*memory {
329
84
                    let mem = memory.data(&caller);
330

            
331
84
                    if (key_loc as usize + key_len as usize) > mem.len()
332
84
                        || (value_loc as usize + value_len as usize) > mem.len()
333
                    {
334
                        return Err(anyhow::anyhow!("Memory bounds error"));
335
84
                    }
336

            
337
84
                    let key = String::from_utf8(
338
84
                        mem[key_loc as usize..key_loc as usize + key_len as usize].to_vec(),
339
                    )?;
340
84
                    let value = String::from_utf8(
341
84
                        mem[value_loc as usize..value_loc as usize + value_len as usize].to_vec(),
342
                    )?;
343

            
344
84
                    let mut builder = data
345
84
                        .data
346
84
                        .lock()
347
84
                        .map_err(|_| anyhow::anyhow!("Failed to lock data"))?;
348
84
                    updater(&mut builder, key, value)?;
349

            
350
84
                    Ok(())
351
                } else {
352
                    Err(anyhow::anyhow!("Memory not available"))
353
                }
354
            })();
355

            
356
84
            match result {
357
84
                Ok(()) => 0,
358
                Err(e) => {
359
                    tracing::error!("Error in wrap_wasm_str_map: {e:?}");
360
                    -1
361
                }
362
            }
363
84
        },
364
    )
365
56
}