1
use std::collections::HashMap;
2
use std::sync::{Arc, Mutex, RwLock};
3
use std::time::{SystemTime, UNIX_EPOCH};
4

            
5
use nomiscript::SymbolTable;
6
use wasmtime::{Caller, Engine, Linker, Memory, Module};
7

            
8
pub struct WasmHost {
9
    engine: Engine,
10
    symbol_table: Arc<RwLock<SymbolTable>>,
11
    module_cache: Arc<Mutex<HashMap<Vec<u8>, Module>>>,
12
}
13

            
14
impl WasmHost {
15
    #[must_use]
16
6610
    pub fn new(engine: Engine, symbol_table: SymbolTable) -> Self {
17
6610
        Self {
18
6610
            engine,
19
6610
            symbol_table: Arc::new(RwLock::new(symbol_table)),
20
6610
            module_cache: Arc::new(Mutex::new(HashMap::new())),
21
6610
        }
22
6610
    }
23

            
24
    #[must_use]
25
19124
    pub fn engine(&self) -> &Engine {
26
19124
        &self.engine
27
19124
    }
28

            
29
    #[must_use]
30
11816
    pub fn symbol_table(&self) -> &Arc<RwLock<SymbolTable>> {
31
11816
        &self.symbol_table
32
11816
    }
33

            
34
    #[must_use]
35
2130
    pub fn module_cache(&self) -> &Arc<Mutex<HashMap<Vec<u8>, Module>>> {
36
2130
        &self.module_cache
37
2130
    }
38

            
39
    #[must_use]
40
6020
    pub fn execution_state(
41
6020
        &self,
42
6020
        input_offset: u32,
43
6020
        output_offset: u32,
44
6020
        strings_offset: u32,
45
6020
    ) -> ExecutionState {
46
6020
        ExecutionState {
47
6020
            input_offset,
48
6020
            output_offset,
49
6020
            strings_offset,
50
6020
            output_strings_offset: Arc::new(Mutex::new(0)),
51
6020
            memory: None,
52
6020
            symbol_table: Arc::clone(&self.symbol_table),
53
6020
        }
54
6020
    }
55
}
56

            
57
pub struct ExecutionState {
58
    pub input_offset: u32,
59
    pub output_offset: u32,
60
    pub strings_offset: u32,
61
    pub output_strings_offset: Arc<Mutex<u32>>,
62
    pub memory: Option<Memory>,
63
    pub symbol_table: Arc<RwLock<SymbolTable>>,
64
}
65

            
66
impl ExecutionState {
67
    #[must_use]
68
1
    pub fn new(input_offset: u32, output_offset: u32, strings_offset: u32) -> Self {
69
1
        Self {
70
1
            input_offset,
71
1
            output_offset,
72
1
            strings_offset,
73
1
            output_strings_offset: Arc::new(Mutex::new(0)),
74
1
            memory: None,
75
1
            symbol_table: Arc::new(RwLock::new(SymbolTable::new())),
76
1
        }
77
1
    }
78
}
79

            
80
6020
pub fn define_host_functions(linker: &mut Linker<ExecutionState>) -> wasmtime::Result<()> {
81
6020
    linker.func_wrap(
82
6020
        "env",
83
6020
        "get_input_offset",
84
5012
        |caller: Caller<ExecutionState>| -> u32 { caller.data().input_offset },
85
    )?;
86

            
87
6020
    linker.func_wrap(
88
6020
        "env",
89
6020
        "get_output_offset",
90
6496
        |caller: Caller<ExecutionState>| -> u32 { caller.data().output_offset },
91
    )?;
92

            
93
6020
    linker.func_wrap(
94
6020
        "env",
95
6020
        "get_strings_offset",
96
        |caller: Caller<ExecutionState>| -> u32 { caller.data().strings_offset },
97
    )?;
98

            
99
6020
    linker.func_wrap(
100
6020
        "env",
101
6020
        "symbol_resolve",
102
        |caller: Caller<ExecutionState>, _name_ptr: u32, _name_len: u32| {
103
            let _memory = match caller.data().memory {
104
                Some(mem) => mem,
105
                None => return,
106
            };
107
            tracing::debug!(
108
                name_ptr = _name_ptr,
109
                name_len = _name_len,
110
                "symbol_resolve called"
111
            );
112
        },
113
    )?;
114

            
115
6020
    linker.func_wrap(
116
6020
        "env",
117
6020
        "write_bytes",
118
        |mut caller: Caller<ExecutionState>, dst: u32, src: u32, len: u32| -> u32 {
119
            let memory = match caller.data().memory {
120
                Some(mem) => mem,
121
                None => return 0,
122
            };
123
            let data = memory.data_mut(&mut caller);
124
            let src_start = src as usize;
125
            let src_end = src_start + len as usize;
126
            let dst_start = dst as usize;
127

            
128
            if src_end > data.len() || dst_start + len as usize > data.len() {
129
                return 0;
130
            }
131

            
132
            let bytes: Vec<u8> = data[src_start..src_end].to_vec();
133
            data[dst_start..dst_start + len as usize].copy_from_slice(&bytes);
134
            len
135
        },
136
    )?;
137

            
138
6020
    linker.func_wrap(
139
6020
        "env",
140
6020
        "write_string",
141
        |mut caller: Caller<ExecutionState>, ptr: u32, len: u32| -> u32 {
142
            let output_offset = caller.data().output_offset;
143
            let output_strings = caller.data().output_strings_offset.clone();
144

            
145
            let memory = match caller.data().memory {
146
                Some(mem) => mem,
147
                None => return 0,
148
            };
149

            
150
            let data = memory.data_mut(&mut caller);
151
            let src_start = ptr as usize;
152
            let src_end = src_start + len as usize;
153

            
154
            if src_end > data.len() {
155
                return 0;
156
            }
157

            
158
            let mut strings_offset = match output_strings.lock() {
159
                Ok(guard) => guard,
160
                Err(_) => return 0,
161
            };
162

            
163
            let current_offset = *strings_offset;
164
            let dst = output_offset as usize + current_offset as usize;
165

            
166
            if dst + len as usize > data.len() {
167
                return 0;
168
            }
169

            
170
            let bytes: Vec<u8> = data[src_start..src_end].to_vec();
171
            data[dst..dst + len as usize].copy_from_slice(&bytes);
172
            *strings_offset += len;
173

            
174
            current_offset
175
        },
176
    )?;
177

            
178
6020
    linker.func_wrap(
179
6020
        "env",
180
6020
        "log",
181
84
        |caller: Caller<ExecutionState>, level: u32, msg_ptr: u32, msg_len: u32| {
182
84
            tracing::debug!(level, msg_ptr, msg_len, "host log called");
183
84
            let memory = match caller.data().memory {
184
84
                Some(mem) => mem,
185
                None => return,
186
            };
187

            
188
84
            let data = memory.data(&caller);
189
84
            let start = msg_ptr as usize;
190
84
            let end = start + msg_len as usize;
191

            
192
84
            if end > data.len() {
193
                return;
194
84
            }
195

            
196
84
            let msg = match std::str::from_utf8(&data[start..end]) {
197
84
                Ok(s) => s,
198
                Err(_) => return,
199
            };
200

            
201
84
            match level {
202
84
                0 => tracing::debug!("[script] {msg}"),
203
                1 => tracing::info!("[script] {msg}"),
204
                2 => tracing::warn!("[script] {msg}"),
205
                _ => tracing::error!("[script] {msg}"),
206
            }
207
84
        },
208
    )?;
209

            
210
6020
    linker.func_wrap("env", "get_timestamp", || -> i64 {
211
        SystemTime::now()
212
            .duration_since(UNIX_EPOCH)
213
            .map_or(0, |d| d.as_millis() as i64)
214
    })?;
215

            
216
6020
    linker.func_wrap(
217
6020
        "env",
218
6020
        "generate_uuid",
219
560
        |mut caller: Caller<ExecutionState>, out_ptr: u32| {
220
560
            let memory = match caller.data().memory {
221
560
                Some(mem) => mem,
222
                None => return,
223
            };
224

            
225
560
            let uuid_bytes = uuid::Uuid::new_v4().into_bytes();
226
560
            let data = memory.data_mut(&mut caller);
227
560
            let start = out_ptr as usize;
228

            
229
560
            if start + 16 > data.len() {
230
                return;
231
560
            }
232

            
233
560
            data[start..start + 16].copy_from_slice(&uuid_bytes);
234
560
        },
235
    )?;
236

            
237
6020
    linker.func_wrap(
238
6020
        "env",
239
6020
        "get_input_entities_count",
240
2324
        |caller: Caller<ExecutionState>| -> i32 {
241
            use crate::format::GlobalHeader;
242

            
243
2324
            let memory = match caller.data().memory {
244
2324
                Some(mem) => mem,
245
                None => return 0,
246
            };
247

            
248
2324
            let input_offset = caller.data().input_offset;
249
2324
            let data = memory.data(&caller);
250
2324
            let input_start = input_offset as usize;
251

            
252
2324
            if input_start + std::mem::size_of::<GlobalHeader>() > data.len() {
253
                return 0;
254
2324
            }
255

            
256
2324
            if let Some(header) = GlobalHeader::from_bytes(&data[input_start..]) {
257
2324
                header.input_entity_count as i32
258
            } else {
259
                0
260
            }
261
2324
        },
262
    )?;
263

            
264
6020
    Ok(())
265
6020
}
266

            
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::BASE_OFFSET;

    /// `ExecutionState::new` stores the offsets verbatim and starts with no
    /// memory and a zeroed output-string cursor.
    #[test]
    fn test_execution_state_creation() {
        let state = ExecutionState::new(BASE_OFFSET, BASE_OFFSET + 1024, BASE_OFFSET + 512);
        assert_eq!(state.input_offset, BASE_OFFSET);
        assert_eq!(state.output_offset, BASE_OFFSET + 1024);
        assert_eq!(state.strings_offset, BASE_OFFSET + 512);
        assert!(state.memory.is_none());
        assert_eq!(*state.output_strings_offset.lock().unwrap(), 0);
    }

    /// A new host starts with an empty module cache.
    #[test]
    fn test_wasm_host_creation() {
        let host = WasmHost::new(Engine::default(), SymbolTable::new());
        assert!(host.module_cache().lock().unwrap().is_empty());
    }

    /// `WasmHost::execution_state` shares the host's symbol table (same `Arc`)
    /// and starts from a clean per-run state.
    #[test]
    fn test_execution_state_shares_host_symbol_table() {
        let host = WasmHost::new(Engine::default(), SymbolTable::new());
        let state = host.execution_state(BASE_OFFSET, BASE_OFFSET + 1024, BASE_OFFSET + 512);
        assert!(Arc::ptr_eq(&state.symbol_table, host.symbol_table()));
        assert_eq!(state.input_offset, BASE_OFFSET);
        assert_eq!(state.output_offset, BASE_OFFSET + 1024);
        assert_eq!(state.strings_offset, BASE_OFFSET + 512);
        assert!(state.memory.is_none());
        assert_eq!(*state.output_strings_offset.lock().unwrap(), 0);
    }

    /// All host functions register on a fresh linker without conflicts.
    #[test]
    fn test_define_host_functions_registers_all_imports() {
        let engine = Engine::default();
        let mut linker = Linker::new(&engine);
        define_host_functions(&mut linker).expect("host functions should register");
    }
}