Skip to main content

scripting/
host.rs

1use std::sync::{Arc, Mutex, RwLock};
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use nomiscript::SymbolTable;
5use wasmtime::{Caller, Engine, Linker, Memory};
6
7use crate::runtime::ModuleCache;
8
9pub struct WasmHost {
10    engine: Engine,
11    symbol_table: Arc<RwLock<SymbolTable>>,
12    module_cache: ModuleCache,
13}
14
15impl WasmHost {
16    #[must_use]
17    pub fn new(engine: Engine, symbol_table: SymbolTable) -> Self {
18        Self {
19            engine,
20            symbol_table: Arc::new(RwLock::new(symbol_table)),
21            module_cache: ModuleCache::new(),
22        }
23    }
24
25    #[must_use]
26    pub fn engine(&self) -> &Engine {
27        &self.engine
28    }
29
30    #[must_use]
31    pub fn symbol_table(&self) -> &Arc<RwLock<SymbolTable>> {
32        &self.symbol_table
33    }
34
35    #[must_use]
36    pub fn module_cache(&self) -> &ModuleCache {
37        &self.module_cache
38    }
39
40    #[must_use]
41    pub fn execution_state(
42        &self,
43        input_offset: u32,
44        output_offset: u32,
45        strings_offset: u32,
46    ) -> ExecutionState {
47        ExecutionState {
48            input_offset,
49            output_offset,
50            strings_offset,
51            output_strings_offset: Arc::new(Mutex::new(0)),
52            memory: None,
53            symbol_table: Arc::clone(&self.symbol_table),
54        }
55    }
56}
57
58pub struct ExecutionState {
59    pub input_offset: u32,
60    pub output_offset: u32,
61    pub strings_offset: u32,
62    pub output_strings_offset: Arc<Mutex<u32>>,
63    pub memory: Option<Memory>,
64    pub symbol_table: Arc<RwLock<SymbolTable>>,
65}
66
67impl ExecutionState {
68    #[must_use]
69    pub fn new(input_offset: u32, output_offset: u32, strings_offset: u32) -> Self {
70        Self {
71            input_offset,
72            output_offset,
73            strings_offset,
74            output_strings_offset: Arc::new(Mutex::new(0)),
75            memory: None,
76            symbol_table: Arc::new(RwLock::new(SymbolTable::new())),
77        }
78    }
79}
80
81pub fn define_host_functions(linker: &mut Linker<ExecutionState>) -> wasmtime::Result<()> {
82    linker.func_wrap(
83        "env",
84        "get_input_offset",
85        |caller: Caller<ExecutionState>| -> u32 { caller.data().input_offset },
86    )?;
87
88    linker.func_wrap(
89        "env",
90        "get_output_offset",
91        |caller: Caller<ExecutionState>| -> u32 { caller.data().output_offset },
92    )?;
93
94    linker.func_wrap(
95        "env",
96        "get_strings_offset",
97        |caller: Caller<ExecutionState>| -> u32 { caller.data().strings_offset },
98    )?;
99
100    linker.func_wrap(
101        "env",
102        "symbol_resolve",
103        |caller: Caller<ExecutionState>, _name_ptr: u32, _name_len: u32| {
104            let _memory = match caller.data().memory {
105                Some(mem) => mem,
106                None => return,
107            };
108            tracing::debug!(
109                name_ptr = _name_ptr,
110                name_len = _name_len,
111                "symbol_resolve called"
112            );
113        },
114    )?;
115
116    linker.func_wrap(
117        "env",
118        "write_bytes",
119        |mut caller: Caller<ExecutionState>, dst: u32, src: u32, len: u32| -> u32 {
120            let memory = match caller.data().memory {
121                Some(mem) => mem,
122                None => return 0,
123            };
124            let data = memory.data_mut(&mut caller);
125            let src_start = src as usize;
126            let src_end = src_start + len as usize;
127            let dst_start = dst as usize;
128
129            if src_end > data.len() || dst_start + len as usize > data.len() {
130                return 0;
131            }
132
133            let bytes: Vec<u8> = data[src_start..src_end].to_vec();
134            data[dst_start..dst_start + len as usize].copy_from_slice(&bytes);
135            len
136        },
137    )?;
138
139    linker.func_wrap(
140        "env",
141        "write_string",
142        |mut caller: Caller<ExecutionState>, ptr: u32, len: u32| -> u32 {
143            let output_offset = caller.data().output_offset;
144            let output_strings = caller.data().output_strings_offset.clone();
145
146            let memory = match caller.data().memory {
147                Some(mem) => mem,
148                None => return 0,
149            };
150
151            let data = memory.data_mut(&mut caller);
152            let src_start = ptr as usize;
153            let src_end = src_start + len as usize;
154
155            if src_end > data.len() {
156                return 0;
157            }
158
159            let mut strings_offset = match output_strings.lock() {
160                Ok(guard) => guard,
161                Err(_) => return 0,
162            };
163
164            let current_offset = *strings_offset;
165            let dst = output_offset as usize + current_offset as usize;
166
167            if dst + len as usize > data.len() {
168                return 0;
169            }
170
171            let bytes: Vec<u8> = data[src_start..src_end].to_vec();
172            data[dst..dst + len as usize].copy_from_slice(&bytes);
173            *strings_offset += len;
174
175            current_offset
176        },
177    )?;
178
179    linker.func_wrap(
180        "env",
181        "log",
182        |caller: Caller<ExecutionState>, level: u32, msg_ptr: u32, msg_len: u32| {
183            tracing::debug!(level, msg_ptr, msg_len, "host log called");
184            let memory = match caller.data().memory {
185                Some(mem) => mem,
186                None => return,
187            };
188
189            let data = memory.data(&caller);
190            let start = msg_ptr as usize;
191            let end = start + msg_len as usize;
192
193            if end > data.len() {
194                return;
195            }
196
197            let msg = match std::str::from_utf8(&data[start..end]) {
198                Ok(s) => s,
199                Err(_) => return,
200            };
201
202            match level {
203                0 => tracing::debug!("[script] {msg}"),
204                1 => tracing::info!("[script] {msg}"),
205                2 => tracing::warn!("[script] {msg}"),
206                _ => tracing::error!("[script] {msg}"),
207            }
208        },
209    )?;
210
211    linker.func_wrap("env", "get_timestamp", || -> i64 {
212        SystemTime::now()
213            .duration_since(UNIX_EPOCH)
214            .map_or(0, |d| d.as_millis() as i64)
215    })?;
216
217    linker.func_wrap(
218        "env",
219        "generate_uuid",
220        |mut caller: Caller<ExecutionState>, out_ptr: u32| {
221            let memory = match caller.data().memory {
222                Some(mem) => mem,
223                None => return,
224            };
225
226            let uuid_bytes = uuid::Uuid::new_v4().into_bytes();
227            let data = memory.data_mut(&mut caller);
228            let start = out_ptr as usize;
229
230            if start + 16 > data.len() {
231                return;
232            }
233
234            data[start..start + 16].copy_from_slice(&uuid_bytes);
235        },
236    )?;
237
238    linker.func_wrap(
239        "env",
240        "get_input_entities_count",
241        |caller: Caller<ExecutionState>| -> i32 {
242            use crate::format::GlobalHeader;
243
244            let memory = match caller.data().memory {
245                Some(mem) => mem,
246                None => return 0,
247            };
248
249            let input_offset = caller.data().input_offset;
250            let data = memory.data(&caller);
251            let input_start = input_offset as usize;
252
253            if input_start + std::mem::size_of::<GlobalHeader>() > data.len() {
254                return 0;
255            }
256
257            if let Some(header) = GlobalHeader::from_bytes(&data[input_start..]) {
258                header.input_entity_count as i32
259            } else {
260                0
261            }
262        },
263    )?;
264
265    // Tier 3 boundary bridge: the compiler wraps each host-invoked body in
266    // a `try_table` that catches an uncaught `$nomi_error`, reads its
267    // code+message, and calls `__nomi_raise`. Script-mode modules
268    // (ScriptExecutor, nms) instantiate through this linker, so the bridge
269    // must live here too — mirrors `rpc::natives::raise`. Returns the
270    // `__nomi_raise:CODE:MSG` marker `classify_runtime_error` recognises.
271    linker.func_wrap(
272        "nomi",
273        "__nomi_raise",
274        |mut caller: Caller<ExecutionState>,
275         code_arg: Option<wasmtime::Rooted<wasmtime::ArrayRef>>,
276         msg_arg: Option<wasmtime::Rooted<wasmtime::ArrayRef>>|
277         -> wasmtime::Result<()> {
278            let code = crate::runtime::read_string_arg(&mut caller, code_arg)?
279                .ok_or_else(|| wasmtime::Error::msg("error: missing :code arg"))?;
280            let message =
281                crate::runtime::read_string_arg(&mut caller, msg_arg)?.unwrap_or_default();
282            Err(wasmtime::Error::msg(format!(
283                "{}{code}:{message}",
284                crate::runtime::NOMI_RAISE_MARKER
285            )))
286        },
287    )?;
288
289    Ok(())
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::BASE_OFFSET;

    /// `ExecutionState::new` stores the offsets verbatim, zeroes the output-string
    /// cursor and leaves memory unattached.
    #[test]
    fn test_execution_state_creation() {
        let state = ExecutionState::new(BASE_OFFSET, BASE_OFFSET + 1024, BASE_OFFSET + 512);
        assert_eq!(state.input_offset, BASE_OFFSET);
        assert_eq!(state.output_offset, BASE_OFFSET + 1024);
        assert_eq!(state.strings_offset, BASE_OFFSET + 512);
        assert_eq!(
            *state
                .output_strings_offset
                .lock()
                .expect("fresh mutex cannot be poisoned"),
            0
        );
        assert!(state.memory.is_none());
    }

    #[test]
    fn test_wasm_host_creation() {
        let host = WasmHost::new(Engine::default(), SymbolTable::new());
        assert!(
            host.module_cache()
                .is_empty()
                .expect("cache lock must not be poisoned in fresh host")
        );
    }

    /// A state built by the host shares the host's symbol table rather than a copy.
    #[test]
    fn test_execution_state_shares_host_symbol_table() {
        let host = WasmHost::new(Engine::default(), SymbolTable::new());
        let state = host.execution_state(BASE_OFFSET, BASE_OFFSET + 1024, BASE_OFFSET + 512);
        assert!(Arc::ptr_eq(&state.symbol_table, host.symbol_table()));
        assert_eq!(state.input_offset, BASE_OFFSET);
        assert_eq!(state.output_offset, BASE_OFFSET + 1024);
        assert_eq!(state.strings_offset, BASE_OFFSET + 512);
        assert_eq!(
            *state
                .output_strings_offset
                .lock()
                .expect("fresh mutex cannot be poisoned"),
            0
        );
        assert!(state.memory.is_none());
    }
}