Skip to main content

scripting/
host.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, RwLock};
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use nomiscript::SymbolTable;
6use wasmtime::{Caller, Engine, Linker, Memory, Module};
7
8pub struct WasmHost {
9    engine: Engine,
10    symbol_table: Arc<RwLock<SymbolTable>>,
11    module_cache: Arc<Mutex<HashMap<Vec<u8>, Module>>>,
12}
13
14impl WasmHost {
15    #[must_use]
16    pub fn new(engine: Engine, symbol_table: SymbolTable) -> Self {
17        Self {
18            engine,
19            symbol_table: Arc::new(RwLock::new(symbol_table)),
20            module_cache: Arc::new(Mutex::new(HashMap::new())),
21        }
22    }
23
24    #[must_use]
25    pub fn engine(&self) -> &Engine {
26        &self.engine
27    }
28
29    #[must_use]
30    pub fn symbol_table(&self) -> &Arc<RwLock<SymbolTable>> {
31        &self.symbol_table
32    }
33
34    #[must_use]
35    pub fn module_cache(&self) -> &Arc<Mutex<HashMap<Vec<u8>, Module>>> {
36        &self.module_cache
37    }
38
39    #[must_use]
40    pub fn execution_state(
41        &self,
42        input_offset: u32,
43        output_offset: u32,
44        strings_offset: u32,
45    ) -> ExecutionState {
46        ExecutionState {
47            input_offset,
48            output_offset,
49            strings_offset,
50            output_strings_offset: Arc::new(Mutex::new(0)),
51            memory: None,
52            symbol_table: Arc::clone(&self.symbol_table),
53        }
54    }
55}
56
57pub struct ExecutionState {
58    pub input_offset: u32,
59    pub output_offset: u32,
60    pub strings_offset: u32,
61    pub output_strings_offset: Arc<Mutex<u32>>,
62    pub memory: Option<Memory>,
63    pub symbol_table: Arc<RwLock<SymbolTable>>,
64}
65
66impl ExecutionState {
67    #[must_use]
68    pub fn new(input_offset: u32, output_offset: u32, strings_offset: u32) -> Self {
69        Self {
70            input_offset,
71            output_offset,
72            strings_offset,
73            output_strings_offset: Arc::new(Mutex::new(0)),
74            memory: None,
75            symbol_table: Arc::new(RwLock::new(SymbolTable::new())),
76        }
77    }
78}
79
80pub fn define_host_functions(linker: &mut Linker<ExecutionState>) -> wasmtime::Result<()> {
81    linker.func_wrap(
82        "env",
83        "get_input_offset",
84        |caller: Caller<ExecutionState>| -> u32 { caller.data().input_offset },
85    )?;
86
87    linker.func_wrap(
88        "env",
89        "get_output_offset",
90        |caller: Caller<ExecutionState>| -> u32 { caller.data().output_offset },
91    )?;
92
93    linker.func_wrap(
94        "env",
95        "get_strings_offset",
96        |caller: Caller<ExecutionState>| -> u32 { caller.data().strings_offset },
97    )?;
98
99    linker.func_wrap(
100        "env",
101        "symbol_resolve",
102        |caller: Caller<ExecutionState>, _name_ptr: u32, _name_len: u32| {
103            let _memory = match caller.data().memory {
104                Some(mem) => mem,
105                None => return,
106            };
107            tracing::debug!(
108                name_ptr = _name_ptr,
109                name_len = _name_len,
110                "symbol_resolve called"
111            );
112        },
113    )?;
114
115    linker.func_wrap(
116        "env",
117        "write_bytes",
118        |mut caller: Caller<ExecutionState>, dst: u32, src: u32, len: u32| -> u32 {
119            let memory = match caller.data().memory {
120                Some(mem) => mem,
121                None => return 0,
122            };
123            let data = memory.data_mut(&mut caller);
124            let src_start = src as usize;
125            let src_end = src_start + len as usize;
126            let dst_start = dst as usize;
127
128            if src_end > data.len() || dst_start + len as usize > data.len() {
129                return 0;
130            }
131
132            let bytes: Vec<u8> = data[src_start..src_end].to_vec();
133            data[dst_start..dst_start + len as usize].copy_from_slice(&bytes);
134            len
135        },
136    )?;
137
138    linker.func_wrap(
139        "env",
140        "write_string",
141        |mut caller: Caller<ExecutionState>, ptr: u32, len: u32| -> u32 {
142            let output_offset = caller.data().output_offset;
143            let output_strings = caller.data().output_strings_offset.clone();
144
145            let memory = match caller.data().memory {
146                Some(mem) => mem,
147                None => return 0,
148            };
149
150            let data = memory.data_mut(&mut caller);
151            let src_start = ptr as usize;
152            let src_end = src_start + len as usize;
153
154            if src_end > data.len() {
155                return 0;
156            }
157
158            let mut strings_offset = match output_strings.lock() {
159                Ok(guard) => guard,
160                Err(_) => return 0,
161            };
162
163            let current_offset = *strings_offset;
164            let dst = output_offset as usize + current_offset as usize;
165
166            if dst + len as usize > data.len() {
167                return 0;
168            }
169
170            let bytes: Vec<u8> = data[src_start..src_end].to_vec();
171            data[dst..dst + len as usize].copy_from_slice(&bytes);
172            *strings_offset += len;
173
174            current_offset
175        },
176    )?;
177
178    linker.func_wrap(
179        "env",
180        "log",
181        |caller: Caller<ExecutionState>, level: u32, msg_ptr: u32, msg_len: u32| {
182            tracing::debug!(level, msg_ptr, msg_len, "host log called");
183            let memory = match caller.data().memory {
184                Some(mem) => mem,
185                None => return,
186            };
187
188            let data = memory.data(&caller);
189            let start = msg_ptr as usize;
190            let end = start + msg_len as usize;
191
192            if end > data.len() {
193                return;
194            }
195
196            let msg = match std::str::from_utf8(&data[start..end]) {
197                Ok(s) => s,
198                Err(_) => return,
199            };
200
201            match level {
202                0 => tracing::debug!("[script] {msg}"),
203                1 => tracing::info!("[script] {msg}"),
204                2 => tracing::warn!("[script] {msg}"),
205                _ => tracing::error!("[script] {msg}"),
206            }
207        },
208    )?;
209
210    linker.func_wrap("env", "get_timestamp", || -> i64 {
211        SystemTime::now()
212            .duration_since(UNIX_EPOCH)
213            .map_or(0, |d| d.as_millis() as i64)
214    })?;
215
216    linker.func_wrap(
217        "env",
218        "generate_uuid",
219        |mut caller: Caller<ExecutionState>, out_ptr: u32| {
220            let memory = match caller.data().memory {
221                Some(mem) => mem,
222                None => return,
223            };
224
225            let uuid_bytes = uuid::Uuid::new_v4().into_bytes();
226            let data = memory.data_mut(&mut caller);
227            let start = out_ptr as usize;
228
229            if start + 16 > data.len() {
230                return;
231            }
232
233            data[start..start + 16].copy_from_slice(&uuid_bytes);
234        },
235    )?;
236
237    linker.func_wrap(
238        "env",
239        "get_input_entities_count",
240        |caller: Caller<ExecutionState>| -> i32 {
241            use crate::format::GlobalHeader;
242
243            let memory = match caller.data().memory {
244                Some(mem) => mem,
245                None => return 0,
246            };
247
248            let input_offset = caller.data().input_offset;
249            let data = memory.data(&caller);
250            let input_start = input_offset as usize;
251
252            if input_start + std::mem::size_of::<GlobalHeader>() > data.len() {
253                return 0;
254            }
255
256            if let Some(header) = GlobalHeader::from_bytes(&data[input_start..]) {
257                header.input_entity_count as i32
258            } else {
259                0
260            }
261        },
262    )?;
263
264    Ok(())
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::BASE_OFFSET;

    #[test]
    fn test_execution_state_creation() {
        let (input, strings, output) = (BASE_OFFSET, BASE_OFFSET + 512, BASE_OFFSET + 1024);
        let state = ExecutionState::new(input, output, strings);
        assert_eq!(state.input_offset, input);
        assert_eq!(state.output_offset, output);
        assert_eq!(state.strings_offset, strings);
    }

    #[test]
    fn test_wasm_host_creation() {
        let host = WasmHost::new(Engine::default(), SymbolTable::new());
        let cache = host.module_cache().lock().unwrap();
        assert!(cache.is_empty());
    }
}