1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, RwLock};
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use nomiscript::SymbolTable;
6use wasmtime::{Caller, Engine, Linker, Memory, Module};
7
8pub struct WasmHost {
9 engine: Engine,
10 symbol_table: Arc<RwLock<SymbolTable>>,
11 module_cache: Arc<Mutex<HashMap<Vec<u8>, Module>>>,
12}
13
14impl WasmHost {
15 #[must_use]
16 pub fn new(engine: Engine, symbol_table: SymbolTable) -> Self {
17 Self {
18 engine,
19 symbol_table: Arc::new(RwLock::new(symbol_table)),
20 module_cache: Arc::new(Mutex::new(HashMap::new())),
21 }
22 }
23
24 #[must_use]
25 pub fn engine(&self) -> &Engine {
26 &self.engine
27 }
28
29 #[must_use]
30 pub fn symbol_table(&self) -> &Arc<RwLock<SymbolTable>> {
31 &self.symbol_table
32 }
33
34 #[must_use]
35 pub fn module_cache(&self) -> &Arc<Mutex<HashMap<Vec<u8>, Module>>> {
36 &self.module_cache
37 }
38
39 #[must_use]
40 pub fn execution_state(
41 &self,
42 input_offset: u32,
43 output_offset: u32,
44 strings_offset: u32,
45 ) -> ExecutionState {
46 ExecutionState {
47 input_offset,
48 output_offset,
49 strings_offset,
50 output_strings_offset: Arc::new(Mutex::new(0)),
51 memory: None,
52 symbol_table: Arc::clone(&self.symbol_table),
53 }
54 }
55}
56
57pub struct ExecutionState {
58 pub input_offset: u32,
59 pub output_offset: u32,
60 pub strings_offset: u32,
61 pub output_strings_offset: Arc<Mutex<u32>>,
62 pub memory: Option<Memory>,
63 pub symbol_table: Arc<RwLock<SymbolTable>>,
64}
65
66impl ExecutionState {
67 #[must_use]
68 pub fn new(input_offset: u32, output_offset: u32, strings_offset: u32) -> Self {
69 Self {
70 input_offset,
71 output_offset,
72 strings_offset,
73 output_strings_offset: Arc::new(Mutex::new(0)),
74 memory: None,
75 symbol_table: Arc::new(RwLock::new(SymbolTable::new())),
76 }
77 }
78}
79
80pub fn define_host_functions(linker: &mut Linker<ExecutionState>) -> wasmtime::Result<()> {
81 linker.func_wrap(
82 "env",
83 "get_input_offset",
84 |caller: Caller<ExecutionState>| -> u32 { caller.data().input_offset },
85 )?;
86
87 linker.func_wrap(
88 "env",
89 "get_output_offset",
90 |caller: Caller<ExecutionState>| -> u32 { caller.data().output_offset },
91 )?;
92
93 linker.func_wrap(
94 "env",
95 "get_strings_offset",
96 |caller: Caller<ExecutionState>| -> u32 { caller.data().strings_offset },
97 )?;
98
99 linker.func_wrap(
100 "env",
101 "symbol_resolve",
102 |caller: Caller<ExecutionState>, _name_ptr: u32, _name_len: u32| {
103 let _memory = match caller.data().memory {
104 Some(mem) => mem,
105 None => return,
106 };
107 tracing::debug!(
108 name_ptr = _name_ptr,
109 name_len = _name_len,
110 "symbol_resolve called"
111 );
112 },
113 )?;
114
115 linker.func_wrap(
116 "env",
117 "write_bytes",
118 |mut caller: Caller<ExecutionState>, dst: u32, src: u32, len: u32| -> u32 {
119 let memory = match caller.data().memory {
120 Some(mem) => mem,
121 None => return 0,
122 };
123 let data = memory.data_mut(&mut caller);
124 let src_start = src as usize;
125 let src_end = src_start + len as usize;
126 let dst_start = dst as usize;
127
128 if src_end > data.len() || dst_start + len as usize > data.len() {
129 return 0;
130 }
131
132 let bytes: Vec<u8> = data[src_start..src_end].to_vec();
133 data[dst_start..dst_start + len as usize].copy_from_slice(&bytes);
134 len
135 },
136 )?;
137
138 linker.func_wrap(
139 "env",
140 "write_string",
141 |mut caller: Caller<ExecutionState>, ptr: u32, len: u32| -> u32 {
142 let output_offset = caller.data().output_offset;
143 let output_strings = caller.data().output_strings_offset.clone();
144
145 let memory = match caller.data().memory {
146 Some(mem) => mem,
147 None => return 0,
148 };
149
150 let data = memory.data_mut(&mut caller);
151 let src_start = ptr as usize;
152 let src_end = src_start + len as usize;
153
154 if src_end > data.len() {
155 return 0;
156 }
157
158 let mut strings_offset = match output_strings.lock() {
159 Ok(guard) => guard,
160 Err(_) => return 0,
161 };
162
163 let current_offset = *strings_offset;
164 let dst = output_offset as usize + current_offset as usize;
165
166 if dst + len as usize > data.len() {
167 return 0;
168 }
169
170 let bytes: Vec<u8> = data[src_start..src_end].to_vec();
171 data[dst..dst + len as usize].copy_from_slice(&bytes);
172 *strings_offset += len;
173
174 current_offset
175 },
176 )?;
177
178 linker.func_wrap(
179 "env",
180 "log",
181 |caller: Caller<ExecutionState>, level: u32, msg_ptr: u32, msg_len: u32| {
182 tracing::debug!(level, msg_ptr, msg_len, "host log called");
183 let memory = match caller.data().memory {
184 Some(mem) => mem,
185 None => return,
186 };
187
188 let data = memory.data(&caller);
189 let start = msg_ptr as usize;
190 let end = start + msg_len as usize;
191
192 if end > data.len() {
193 return;
194 }
195
196 let msg = match std::str::from_utf8(&data[start..end]) {
197 Ok(s) => s,
198 Err(_) => return,
199 };
200
201 match level {
202 0 => tracing::debug!("[script] {msg}"),
203 1 => tracing::info!("[script] {msg}"),
204 2 => tracing::warn!("[script] {msg}"),
205 _ => tracing::error!("[script] {msg}"),
206 }
207 },
208 )?;
209
210 linker.func_wrap("env", "get_timestamp", || -> i64 {
211 SystemTime::now()
212 .duration_since(UNIX_EPOCH)
213 .map_or(0, |d| d.as_millis() as i64)
214 })?;
215
216 linker.func_wrap(
217 "env",
218 "generate_uuid",
219 |mut caller: Caller<ExecutionState>, out_ptr: u32| {
220 let memory = match caller.data().memory {
221 Some(mem) => mem,
222 None => return,
223 };
224
225 let uuid_bytes = uuid::Uuid::new_v4().into_bytes();
226 let data = memory.data_mut(&mut caller);
227 let start = out_ptr as usize;
228
229 if start + 16 > data.len() {
230 return;
231 }
232
233 data[start..start + 16].copy_from_slice(&uuid_bytes);
234 },
235 )?;
236
237 linker.func_wrap(
238 "env",
239 "get_input_entities_count",
240 |caller: Caller<ExecutionState>| -> i32 {
241 use crate::format::GlobalHeader;
242
243 let memory = match caller.data().memory {
244 Some(mem) => mem,
245 None => return 0,
246 };
247
248 let input_offset = caller.data().input_offset;
249 let data = memory.data(&caller);
250 let input_start = input_offset as usize;
251
252 if input_start + std::mem::size_of::<GlobalHeader>() > data.len() {
253 return 0;
254 }
255
256 if let Some(header) = GlobalHeader::from_bytes(&data[input_start..]) {
257 header.input_entity_count as i32
258 } else {
259 0
260 }
261 },
262 )?;
263
264 Ok(())
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::BASE_OFFSET;

    #[test]
    fn test_execution_state_creation() {
        let state = ExecutionState::new(BASE_OFFSET, BASE_OFFSET + 1024, BASE_OFFSET + 512);
        assert_eq!(state.input_offset, BASE_OFFSET);
        assert_eq!(state.output_offset, BASE_OFFSET + 1024);
        assert_eq!(state.strings_offset, BASE_OFFSET + 512);
        // A fresh state has no memory attached and has not appended any strings.
        assert!(state.memory.is_none());
        assert_eq!(*state.output_strings_offset.lock().unwrap(), 0);
    }

    #[test]
    fn test_wasm_host_creation() {
        let host = WasmHost::new(Engine::default(), SymbolTable::new());
        assert!(host.module_cache().lock().unwrap().is_empty());
    }

    #[test]
    fn test_execution_state_shares_host_symbol_table() {
        let host = WasmHost::new(Engine::default(), SymbolTable::new());
        let state = host.execution_state(BASE_OFFSET, BASE_OFFSET + 1024, BASE_OFFSET + 512);

        assert!(Arc::ptr_eq(host.symbol_table(), &state.symbol_table));
        assert_eq!(state.input_offset, BASE_OFFSET);
        assert_eq!(state.output_offset, BASE_OFFSET + 1024);
        assert_eq!(state.strings_offset, BASE_OFFSET + 512);
        assert!(state.memory.is_none());
        assert_eq!(*state.output_strings_offset.lock().unwrap(), 0);
    }

    #[test]
    fn test_define_host_functions_registers_on_fresh_linker() {
        let engine = Engine::default();
        let mut linker = Linker::new(&engine);
        assert!(define_host_functions(&mut linker).is_ok());
    }
}