1use nomiscript::SymbolTable;
2use tracing::debug;
3use wasmtime::{Config, Engine, Linker, Module, Store};
4
5use crate::error::HookError;
6use crate::format::{BASE_OFFSET, GlobalHeader, OUTPUT_HEADER_SIZE, OutputHeader};
7use crate::host::{WasmHost, define_host_functions};
8use crate::parser::{OutputParser, ParsedEntity};
9
/// Size of one WebAssembly linear-memory page, in bytes.
const WASM_PAGE_SIZE: u32 = 64 * 1024;
/// Size of the guest output region when the caller does not pass one
/// (exactly one WASM page).
const DEFAULT_OUTPUT_SIZE: u32 = WASM_PAGE_SIZE;
/// Number of epoch ticks a script may run for before the store's deadline is hit.
const EPOCH_DEADLINE_TICKS: u64 = 5;
/// Wall-clock time between epoch ticks; with `EPOCH_DEADLINE_TICKS` this gives
/// a budget of roughly five seconds.
const EPOCH_TICK_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1);
14
15pub struct ScriptExecutor {
16 host: WasmHost,
17}
18
19impl Default for ScriptExecutor {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25fn epoch_engine() -> Engine {
26 let mut config = Config::new();
27 config.wasm_gc(true);
28 config.epoch_interruption(true);
29 Engine::new(&config).expect("failed to create WASM engine")
30}
31
32impl ScriptExecutor {
33 #[must_use]
34 pub fn new() -> Self {
35 Self {
36 host: WasmHost::new(epoch_engine(), SymbolTable::new()),
37 }
38 }
39
40 #[must_use]
41 pub fn with_engine(engine: Engine) -> Self {
42 Self {
43 host: WasmHost::new(engine, SymbolTable::new()),
44 }
45 }
46
47 fn get_or_compile_module(&self, bytecode: &[u8]) -> Result<Module, HookError> {
48 let cache = self.host.module_cache().lock()?;
49 if let Some(module) = cache.get(bytecode) {
50 return Ok(module.clone());
51 }
52 drop(cache);
53
54 let module = Module::new(self.host.engine(), bytecode)?;
55
56 let mut cache = self.host.module_cache().lock()?;
57 cache.insert(bytecode.to_vec(), module.clone());
58 Ok(module)
59 }
60
61 pub fn execute(
62 &self,
63 bytecode: &[u8],
64 input: &[u8],
65 output_size: Option<u32>,
66 ) -> Result<Vec<ParsedEntity>, HookError> {
67 debug!(bytecode_size = bytecode.len(), "script execution start");
68 let output_size = output_size.unwrap_or(DEFAULT_OUTPUT_SIZE);
69 let module = self.get_or_compile_module(bytecode)?;
70 debug!("module compiled");
71
72 let header = GlobalHeader::from_bytes(input)
73 .ok_or_else(|| HookError::Parse("Invalid input header".to_string()))?;
74
75 let input_offset = BASE_OFFSET;
76 let output_offset = input_offset + input.len() as u32;
77 let strings_offset = header.strings_pool_offset;
78
79 let exec_state = self
80 .host
81 .execution_state(input_offset, output_offset, strings_offset);
82 let mut store = Store::new(self.host.engine(), exec_state);
83 store.set_epoch_deadline(EPOCH_DEADLINE_TICKS);
84
85 let engine = self.host.engine().clone();
86 let ticker = std::thread::spawn(move || {
87 for _ in 0..EPOCH_DEADLINE_TICKS {
88 std::thread::sleep(EPOCH_TICK_INTERVAL);
89 engine.increment_epoch();
90 }
91 });
92
93 let mut linker = Linker::new(self.host.engine());
94 define_host_functions(&mut linker)?;
95
96 let instance = linker.instantiate(&mut store, &module)?;
97
98 let memory = instance
99 .get_memory(&mut store, "memory")
100 .ok_or(HookError::WASMMem)?;
101
102 store.data_mut().memory = Some(memory);
103
104 let total_size = input.len() + output_size as usize;
105 let required_pages = (BASE_OFFSET as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
106 let current_pages = memory.size(&store) as usize;
107
108 if required_pages > current_pages {
109 memory.grow(&mut store, (required_pages - current_pages) as u64)?;
110 }
111
112 let mem_data = memory.data_mut(&mut store);
113 let input_start = BASE_OFFSET as usize;
114 mem_data[input_start..input_start + input.len()].copy_from_slice(input);
115
116 let output_start = output_offset as usize;
117 let input_entity_count = header.input_entity_count;
118 let output_header = OutputHeader::new(input_entity_count);
119 mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
120 .copy_from_slice(&output_header.to_bytes());
121
122 let should_apply = instance
123 .get_typed_func::<(), i32>(&mut store, "should_apply")
124 .map_err(|e| HookError::Script(format!("Missing should_apply export: {e}")))?;
125
126 let result = should_apply.call(&mut store, ())?;
127 debug!(should_apply = result, "should_apply result");
128 if result == 0 {
129 drop(ticker);
130 return Ok(Vec::new());
131 }
132
133 let process = instance
134 .get_typed_func::<(), ()>(&mut store, "process")
135 .map_err(|e| HookError::Script(format!("Missing process export: {e}")))?;
136
137 debug!("calling process");
138 let call_result = process.call(&mut store, ());
139 drop(ticker);
140 call_result?;
141
142 let mem_data = memory.data(&store);
143 let output_data = &mem_data[output_start..output_start + output_size as usize];
144
145 let output_header = OutputHeader::from_bytes(output_data)
146 .ok_or_else(|| HookError::Parse("Invalid output header".to_string()))?;
147
148 let output_strings_offset = { output_header.strings_offset } as usize;
149
150 let parser = OutputParser::new(output_data, output_strings_offset)?;
151 let entities: Result<Vec<ParsedEntity>, HookError> = parser.entities().collect();
152 debug!(
153 entity_count = entities.as_ref().map_or(0, std::vec::Vec::len),
154 "output parse complete"
155 );
156 entities
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A newly constructed executor has not compiled anything yet.
    #[test]
    fn test_executor_creation() {
        let executor = ScriptExecutor::new();
        let cache = executor.host.module_cache().lock().unwrap();
        assert!(
            cache.is_empty(),
            "a new executor should start with no cached modules"
        );
    }
}