1
use anyhow::Result;
2
use std::sync::{Arc, Mutex};
3
use wasmtime::{Caller, Func, Memory, Store, ToWasmtimeResult};
4

            
5
pub struct DataStore<Data> {
6
    pub data: Mutex<Data>,
7
    pub memory: Mutex<Option<Memory>>,
8
}
9

            
10
// Helper function for builder operations, allowing any type
11
fn with_builder<F, Builder>(caller: &Caller<'_, Arc<DataStore<Builder>>>, action: F) -> Result<()>
12
where
13
    F: FnOnce(&mut Builder) -> Result<()>,
14
{
15
    let data = caller.data();
16

            
17
    // Lock the builder
18
    let mut builder = match data.data.lock() {
19
        Ok(builder) => builder,
20
        Err(_) => anyhow::bail!(t!("Can't lock the builder")),
21
    };
22

            
23
    // Execute the action on the builder
24
    action(&mut builder)
25
}
26

            
27
// Helper function for setting a str from memory with two numbers
28
fn with_data_set_str<F, Data>(
29
    caller: &Caller<'_, Arc<DataStore<Data>>>,
30
    loc: i32,
31
    len: i32,
32
    set_field: F,
33
) -> Result<()>
34
where
35
    F: FnOnce(&mut Data, String),
36
{
37
    with_builder(caller, |builder| {
38
        let data = caller.data();
39

            
40
        // Lock the memory
41
        let m = match data.memory.lock() {
42
            Ok(m) => m,
43
            Err(_) => anyhow::bail!(t!("Can't lock the memory")),
44
        };
45

            
46
        // Access memory and set the field if data is found
47
        if let Some(memory) = &*m
48
            && let Some(bytes) = memory
49
                .data(caller)
50
                .get(loc as usize..loc as usize + len as usize)
51
        {
52
            let string_value = String::from_utf8_lossy(bytes).to_string();
53
            set_field(builder, string_value);
54
        }
55

            
56
        Ok(())
57
    })
58
}
59

            
60
// Wrapper for WASM numbers
61
pub fn wrap_wasm_i64<Data, Setter>(store: &mut Store<Arc<DataStore<Data>>>, setter: Setter) -> Func
62
where
63
    Data: Send + 'static,
64
    Setter: Fn(&mut Data, i64) + Send + Sync + 'static,
65
{
66
    Func::wrap(
67
        store,
68
        move |caller: Caller<'_, Arc<DataStore<Data>>>, value: i64| {
69
            with_builder(&caller, |builder| {
70
                setter(builder, value);
71
                Ok(())
72
            })
73
            .to_wasmtime_result()
74
        },
75
    )
76
}
77

            
78
// Generic wrapper for creating `Func` to set string fields in the builder
79
pub fn wrap_wasm_str<Data, Setter>(store: &mut Store<Arc<DataStore<Data>>>, setter: Setter) -> Func
80
where
81
    Data: Send + 'static,
82
    Setter: Fn(&mut Data, String) + Send + Sync + 'static,
83
{
84
    Func::wrap(
85
        store,
86
        move |caller: Caller<'_, Arc<DataStore<Data>>>, loc: i32, len: i32| {
87
            with_data_set_str(&caller, loc, len, |builder, value| {
88
                setter(builder, value);
89
            })
90
            .to_wasmtime_result()
91
        },
92
    )
93
}
94

            
95
// Wrapper for WASM i64 getter
96
pub fn wrap_wasm_i64_get<Data, Getter>(
97
    store: &mut Store<Arc<DataStore<Data>>>,
98
    getter: Getter,
99
) -> Func
100
where
101
    Data: Send + 'static,
102
    Getter: Fn(&Data) -> i64 + Send + Sync + 'static,
103
{
104
    Func::wrap(
105
        store,
106
        move |caller: Caller<'_, Arc<DataStore<Data>>>| -> i64 {
107
            let data = caller.data();
108
            let builder = data.data.lock().expect("Failed to lock data");
109

            
110
            // Call the getter to retrieve the i64 value
111
            getter(&builder)
112
        },
113
    )
114
}
115

            
116
// Wrapper for WASM string getter that writes the string to memory and returns its actual length
117
pub fn wrap_wasm_str_get<Data, Getter>(
118
    store: &mut Store<Arc<DataStore<Data>>>,
119
    getter: Getter,
120
) -> Func
121
where
122
    Data: Send + 'static,
123
    Getter: Fn(&Data) -> String + Send + Sync + 'static,
124
{
125
    Func::wrap(
126
        store,
127
        move |mut caller: Caller<'_, Arc<DataStore<Data>>>, loc: i32, max_len: i32| -> i32 {
128
            // Clone the Arc to avoid holding an immutable borrow on `caller`
129
            let data = Arc::clone(caller.data());
130

            
131
            // Extract the value from the builder in a separate scope
132
            let value = {
133
                let builder = data.data.lock().expect("Failed to lock data");
134
                getter(&builder)
135
            }; // `builder` is dropped here, ending the immutable borrow
136

            
137
            let bytes = value.as_bytes();
138

            
139
            // Lock the memory
140
            let m = data.memory.lock().expect("Failed to lock memory");
141

            
142
            if let Some(memory) = &*m {
143
                // Obtain a mutable context from the caller
144
                let mem = memory.data_mut(&mut caller);
145

            
146
                // Calculate the length to copy, limited by `max_len`
147
                let len = bytes.len().min(max_len as usize);
148

            
149
                // Ensure the memory bounds are valid
150
                if (loc as usize + len) <= mem.len() {
151
                    // Copy bytes to the specified location in WASM memory
152
                    mem[loc as usize..loc as usize + len].copy_from_slice(&bytes[..len]);
153

            
154
                    // Return the actual length of the copied string
155
                    return len as i32;
156
                }
157
                // Memory bounds are invalid; indicate failure
158
                return -1;
159
            }
160

            
161
            // Return -1 to indicate failure if memory is unavailable
162
            -1
163
        },
164
    )
165
}
166

            
167
38
pub fn wrap_wasm_str_transform<Data, Transform>(
168
38
    store: &mut Store<Arc<DataStore<Data>>>,
169
38
    transform: Transform,
170
38
) -> Func
171
38
where
172
38
    Data: Send + 'static,
173
38
    Transform: Fn(&Data, String) -> Result<String, anyhow::Error> + Send + Sync + 'static,
174
{
175
38
    Func::wrap(
176
38
        store,
177
        move |mut caller: Caller<'_, Arc<DataStore<Data>>>,
178
              tag_ptr: i32,
179
              tag_len: i32,
180
              buf_ptr: i32,
181
              buf_len: i32|
182
19
              -> i32 {
183
19
            let data = Arc::clone(caller.data());
184

            
185
            // Define the inner logic
186
19
            let result = (|| -> Result<i32, anyhow::Error> {
187
                // Lock the WASM memory
188
19
                let memory = data
189
19
                    .memory
190
19
                    .lock()
191
19
                    .map_err(|_| anyhow::anyhow!("Failed to lock memory"))?;
192
19
                let memory = memory
193
19
                    .as_ref()
194
19
                    .ok_or_else(|| anyhow::anyhow!("Memory not initialized"))?;
195
19
                let mem = memory.data(&caller);
196

            
197
                // Validate memory bounds for the input key
198
19
                if (tag_ptr as usize + tag_len as usize) > mem.len() {
199
                    return Err(anyhow::anyhow!("Input memory bounds error"));
200
19
                }
201

            
202
                // Extract the key string from memory
203
19
                let key = String::from_utf8(
204
19
                    mem[tag_ptr as usize..tag_ptr as usize + tag_len as usize].to_vec(),
205
                )
206
19
                .map_err(|_| anyhow::anyhow!("Failed to decode input string"))?;
207

            
208
                // Lock the data and call the transform function
209
19
                let data_lock = data
210
19
                    .data
211
19
                    .lock()
212
19
                    .map_err(|_| anyhow::anyhow!("Failed to lock data"))?;
213
19
                let output = transform(&*data_lock, key)?;
214

            
215
                // Get output bytes
216
19
                let output_bytes = output.as_bytes();
217
19
                let output_len = output_bytes.len();
218

            
219
                // Validate that the output fits within the provided buffer length
220
19
                if output_len > buf_len as usize {
221
                    return Err(anyhow::anyhow!("Output exceeds buffer size"));
222
19
                }
223

            
224
                // Validate memory bounds for the output buffer
225
19
                if (buf_ptr as usize + output_len) > mem.len() {
226
                    return Err(anyhow::anyhow!("Output memory bounds error"));
227
19
                }
228

            
229
                // Write the result string into the output buffer
230
19
                memory.data_mut(&mut caller)[buf_ptr as usize..buf_ptr as usize + output_len]
231
19
                    .copy_from_slice(output_bytes);
232

            
233
19
                Ok(output_len as i32)
234
            })();
235

            
236
            // Handle the result or return an error code
237
19
            match result {
238
19
                Ok(len) => len,
239
                Err(e) => {
240
                    tracing::error!("Error in wrap_wasm_str_transform: {e:?}");
241
                    -1 // Indicate failure
242
                }
243
            }
244
19
        },
245
    )
246
38
}
247

            
248
38
pub fn wrap_wasm_str_list_get<Data, Getter>(
249
38
    store: &mut Store<Arc<DataStore<Data>>>,
250
38
    getter: Getter,
251
38
) -> Func
252
38
where
253
38
    Data: Send + 'static,
254
38
    Getter: Fn(&Data) -> Result<Vec<String>, anyhow::Error> + Send + Sync + 'static,
255
{
256
38
    Func::wrap(
257
38
        store,
258
38
        move |mut caller: Caller<'_, Arc<DataStore<Data>>>, loc: i32, max_len: i32| -> i32 {
259
38
            let data = Arc::clone(caller.data());
260

            
261
            // Try to execute the getter and handle potential errors
262
38
            let result = (|| -> Result<i32> {
263
38
                let builder = data
264
38
                    .data
265
38
                    .lock()
266
38
                    .map_err(|_| anyhow::anyhow!("Failed to lock data"))?;
267
38
                let list = getter(&builder)?; // Call getter, which may return an error
268
38
                let mut bytes = Vec::new();
269

            
270
57
                for s in list {
271
57
                    bytes.extend_from_slice(s.as_bytes());
272
57
                    bytes.push(0); // Null terminator
273
57
                }
274

            
275
38
                let len = bytes.len().min(max_len as usize);
276

            
277
38
                let memory = data
278
38
                    .memory
279
38
                    .lock()
280
38
                    .map_err(|_| anyhow::anyhow!("Failed to lock memory"))?;
281
38
                if let Some(memory) = &*memory {
282
38
                    let mem = memory.data_mut(&mut caller);
283

            
284
38
                    if (loc as usize + len) <= mem.len() {
285
38
                        mem[loc as usize..loc as usize + len].copy_from_slice(&bytes[..len]);
286
38
                        return Ok(len as i32);
287
                    }
288
                    return Err(anyhow::anyhow!("Memory bounds error"));
289
                }
290

            
291
                Err(anyhow::anyhow!("Memory not available"))
292
            })();
293

            
294
            // Return the result or indicate failure (-1)
295
38
            match result {
296
38
                Ok(len) => len,
297
                Err(e) => {
298
                    tracing::error!("Error in wrap_wasm_str_list_get: {e:?}");
299
                    -1
300
                }
301
            }
302
38
        },
303
    )
304
38
}
305

            
306
38
pub fn wrap_wasm_str_map<Data>(
307
38
    store: &mut Store<Arc<DataStore<Data>>>,
308
38
    updater: impl Fn(&mut Data, String, String) -> Result<(), anyhow::Error> + Send + Sync + 'static,
309
38
) -> Func
310
38
where
311
38
    Data: Send + 'static,
312
{
313
38
    Func::wrap(
314
38
        store,
315
        move |caller: Caller<'_, Arc<DataStore<Data>>>,
316
              key_loc: i32,
317
              key_len: i32,
318
              value_loc: i32,
319
              value_len: i32|
320
57
              -> i32 {
321
57
            let data = caller.data();
322

            
323
57
            let result = (|| -> Result<()> {
324
57
                let memory = data
325
57
                    .memory
326
57
                    .lock()
327
57
                    .map_err(|_| anyhow::anyhow!("Failed to lock memory"))?;
328
57
                if let Some(memory) = &*memory {
329
57
                    let mem = memory.data(&caller);
330

            
331
57
                    if (key_loc as usize + key_len as usize) > mem.len()
332
57
                        || (value_loc as usize + value_len as usize) > mem.len()
333
                    {
334
                        return Err(anyhow::anyhow!("Memory bounds error"));
335
57
                    }
336

            
337
57
                    let key = String::from_utf8(
338
57
                        mem[key_loc as usize..key_loc as usize + key_len as usize].to_vec(),
339
                    )?;
340
57
                    let value = String::from_utf8(
341
57
                        mem[value_loc as usize..value_loc as usize + value_len as usize].to_vec(),
342
                    )?;
343

            
344
57
                    let mut builder = data
345
57
                        .data
346
57
                        .lock()
347
57
                        .map_err(|_| anyhow::anyhow!("Failed to lock data"))?;
348
57
                    updater(&mut builder, key, value)?;
349

            
350
57
                    Ok(())
351
                } else {
352
                    Err(anyhow::anyhow!("Memory not available"))
353
                }
354
            })();
355

            
356
57
            match result {
357
57
                Ok(()) => 0,
358
                Err(e) => {
359
                    tracing::error!("Error in wrap_wasm_str_map: {e:?}");
360
                    -1
361
                }
362
            }
363
57
        },
364
    )
365
38
}