Skip to main content

scripting/
executor.rs

1use nomiscript::SymbolTable;
2use tracing::debug;
3use wasmtime::{Engine, Linker, Module, Store};
4
5use crate::error::HookError;
6use crate::format::{BASE_OFFSET, GlobalHeader, OUTPUT_HEADER_SIZE, OutputHeader};
7use crate::host::{WasmHost, define_host_functions};
8use crate::parser::{OutputParser, ParsedEntity};
9use crate::runtime::{EngineOpts, build_engine};
10
/// Size in bytes of the guest output region when the caller does not specify one (64 KiB).
const DEFAULT_OUTPUT_SIZE: u32 = 1 << 16;
/// Size in bytes of one WebAssembly linear-memory page.
const WASM_PAGE_SIZE: u32 = 64 * 1024;
/// Epoch ticks a script may consume before its store's deadline is reached.
const EPOCH_DEADLINE_TICKS: u64 = 5;
/// Wall-clock spacing between epoch ticks.
const EPOCH_TICK_INTERVAL: std::time::Duration = std::time::Duration::from_millis(1_000);
15
16pub struct ScriptExecutor {
17    host: WasmHost,
18}
19
20impl Default for ScriptExecutor {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26fn default_engine() -> Engine {
27    build_engine(EngineOpts::baseline()).expect("WASM engine config rejected")
28}
29
30impl ScriptExecutor {
31    #[must_use]
32    pub fn new() -> Self {
33        Self {
34            host: WasmHost::new(default_engine(), SymbolTable::new()),
35        }
36    }
37
38    #[must_use]
39    pub fn with_engine(engine: Engine) -> Self {
40        Self {
41            host: WasmHost::new(engine, SymbolTable::new()),
42        }
43    }
44
45    fn get_or_compile_module(&self, bytecode: &[u8]) -> Result<Module, HookError> {
46        self.host
47            .module_cache()
48            .get_or_compile(self.host.engine(), bytecode)
49            .map_err(HookError::from)
50    }
51
52    pub fn execute(
53        &self,
54        bytecode: &[u8],
55        input: &[u8],
56        output_size: Option<u32>,
57    ) -> Result<Vec<ParsedEntity>, HookError> {
58        debug!(bytecode_size = bytecode.len(), "script execution start");
59        let output_size = output_size.unwrap_or(DEFAULT_OUTPUT_SIZE);
60        let module = self.get_or_compile_module(bytecode)?;
61        debug!("module compiled");
62
63        let header = GlobalHeader::from_bytes(input)
64            .ok_or_else(|| HookError::Parse("Invalid input header".to_string()))?;
65
66        let input_offset = BASE_OFFSET;
67        let output_offset = input_offset + input.len() as u32;
68        let strings_offset = header.strings_pool_offset;
69
70        let exec_state = self
71            .host
72            .execution_state(input_offset, output_offset, strings_offset);
73        let mut store = Store::new(self.host.engine(), exec_state);
74        store.set_epoch_deadline(EPOCH_DEADLINE_TICKS);
75
76        let engine = self.host.engine().clone();
77        let ticker = std::thread::spawn(move || {
78            for _ in 0..EPOCH_DEADLINE_TICKS {
79                std::thread::sleep(EPOCH_TICK_INTERVAL);
80                engine.increment_epoch();
81            }
82        });
83
84        let mut linker = Linker::new(self.host.engine());
85        define_host_functions(&mut linker)?;
86
87        let instance = linker.instantiate(&mut store, &module)?;
88
89        let memory = instance
90            .get_memory(&mut store, "memory")
91            .ok_or(HookError::WASMMem)?;
92
93        store.data_mut().memory = Some(memory);
94
95        let total_size = input.len() + output_size as usize;
96        let required_pages = (BASE_OFFSET as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
97        let current_pages = memory.size(&store) as usize;
98
99        if required_pages > current_pages {
100            memory.grow(&mut store, (required_pages - current_pages) as u64)?;
101        }
102
103        let mem_data = memory.data_mut(&mut store);
104        let input_start = BASE_OFFSET as usize;
105        mem_data[input_start..input_start + input.len()].copy_from_slice(input);
106
107        let output_start = output_offset as usize;
108        let input_entity_count = header.input_entity_count;
109        let output_header = OutputHeader::new(input_entity_count);
110        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
111            .copy_from_slice(&output_header.to_bytes());
112
113        let should_apply = instance
114            .get_typed_func::<(), i32>(&mut store, "should_apply")
115            .map_err(|e| HookError::Script(format!("Missing should_apply export: {e}")))?;
116
117        let result = should_apply.call(&mut store, ())?;
118        debug!(should_apply = result, "should_apply result");
119        if result == 0 {
120            drop(ticker);
121            return Ok(Vec::new());
122        }
123
124        let process = instance
125            .get_typed_func::<(), ()>(&mut store, "process")
126            .map_err(|e| HookError::Script(format!("Missing process export: {e}")))?;
127
128        debug!("calling process");
129        let call_result = process.call(&mut store, ());
130        drop(ticker);
131        call_result?;
132
133        let mem_data = memory.data(&store);
134        let output_data = &mem_data[output_start..output_start + output_size as usize];
135
136        let output_header = OutputHeader::from_bytes(output_data)
137            .ok_or_else(|| HookError::Parse("Invalid output header".to_string()))?;
138
139        let output_strings_offset = { output_header.strings_offset } as usize;
140
141        let parser = OutputParser::new(output_data, output_strings_offset)?;
142        let entities: Result<Vec<ParsedEntity>, HookError> = parser.entities().collect();
143        debug!(
144            entity_count = entities.as_ref().map_or(0, std::vec::Vec::len),
145            "output parse complete"
146        );
147        entities
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed executor must start with nothing in its module cache.
    #[test]
    fn test_executor_creation() {
        let executor = ScriptExecutor::new();
        let cache = executor.host.module_cache();
        let cache_is_empty = cache
            .is_empty()
            .expect("cache lock must not be poisoned in fresh executor");
        assert!(cache_is_empty);
    }
}