1
use nomiscript::SymbolTable;
2
use tracing::debug;
3
use wasmtime::{Engine, Linker, Module, Store};
4

            
5
use crate::error::HookError;
6
use crate::format::{BASE_OFFSET, GlobalHeader, OUTPUT_HEADER_SIZE, OutputHeader};
7
use crate::host::{WasmHost, define_host_functions};
8
use crate::parser::{OutputParser, ParsedEntity};
9

            
10
const DEFAULT_OUTPUT_SIZE: u32 = 64 * 1024;
11
const WASM_PAGE_SIZE: u32 = 65536;
12

            
13
pub struct ScriptExecutor {
14
    host: WasmHost,
15
}
16

            
17
impl Default for ScriptExecutor {
18
    fn default() -> Self {
19
        Self::new()
20
    }
21
}
22

            
23
impl ScriptExecutor {
24
    #[must_use]
25
153
    pub fn new() -> Self {
26
153
        Self {
27
153
            host: WasmHost::new(Engine::default(), SymbolTable::new()),
28
153
        }
29
153
    }
30

            
31
    #[must_use]
32
38
    pub fn with_engine(engine: Engine) -> Self {
33
38
        Self {
34
38
            host: WasmHost::new(engine, SymbolTable::new()),
35
38
        }
36
38
    }
37

            
38
190
    fn get_or_compile_module(&self, bytecode: &[u8]) -> Result<Module, HookError> {
39
190
        let cache = self.host.module_cache().lock()?;
40
190
        if let Some(module) = cache.get(bytecode) {
41
            return Ok(module.clone());
42
190
        }
43
190
        drop(cache);
44

            
45
190
        let module = Module::new(self.host.engine(), bytecode)?;
46

            
47
190
        let mut cache = self.host.module_cache().lock()?;
48
190
        cache.insert(bytecode.to_vec(), module.clone());
49
190
        Ok(module)
50
190
    }
51

            
52
190
    pub fn execute(
53
190
        &self,
54
190
        bytecode: &[u8],
55
190
        input: &[u8],
56
190
        output_size: Option<u32>,
57
190
    ) -> Result<Vec<ParsedEntity>, HookError> {
58
190
        debug!(bytecode_size = bytecode.len(), "script execution start");
59
190
        let output_size = output_size.unwrap_or(DEFAULT_OUTPUT_SIZE);
60
190
        let module = self.get_or_compile_module(bytecode)?;
61
190
        debug!("module compiled");
62

            
63
190
        let header = GlobalHeader::from_bytes(input)
64
190
            .ok_or_else(|| HookError::Parse("Invalid input header".to_string()))?;
65

            
66
190
        let input_offset = BASE_OFFSET;
67
190
        let output_offset = input_offset + input.len() as u32;
68
190
        let strings_offset = header.strings_pool_offset;
69

            
70
190
        let exec_state = self
71
190
            .host
72
190
            .execution_state(input_offset, output_offset, strings_offset);
73
190
        let mut store = Store::new(self.host.engine(), exec_state);
74

            
75
190
        let mut linker = Linker::new(self.host.engine());
76
190
        define_host_functions(&mut linker)?;
77

            
78
190
        let instance = linker.instantiate(&mut store, &module)?;
79

            
80
190
        let memory = instance
81
190
            .get_memory(&mut store, "memory")
82
190
            .ok_or(HookError::WASMMem)?;
83

            
84
190
        store.data_mut().memory = Some(memory);
85

            
86
190
        let total_size = input.len() + output_size as usize;
87
190
        let required_pages = (BASE_OFFSET as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
88
190
        let current_pages = memory.size(&store) as usize;
89

            
90
190
        if required_pages > current_pages {
91
            memory.grow(&mut store, (required_pages - current_pages) as u64)?;
92
190
        }
93

            
94
190
        let mem_data = memory.data_mut(&mut store);
95
190
        let input_start = BASE_OFFSET as usize;
96
190
        mem_data[input_start..input_start + input.len()].copy_from_slice(input);
97

            
98
190
        let output_start = output_offset as usize;
99
190
        let input_entity_count = header.input_entity_count;
100
190
        let output_header = OutputHeader::new(input_entity_count);
101
190
        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
102
190
            .copy_from_slice(&output_header.to_bytes());
103

            
104
190
        let should_apply = instance
105
190
            .get_typed_func::<(), i32>(&mut store, "should_apply")
106
190
            .map_err(|e| HookError::Script(format!("Missing should_apply export: {e}")))?;
107

            
108
190
        let result = should_apply.call(&mut store, ())?;
109
190
        debug!(should_apply = result, "should_apply result");
110
190
        if result == 0 {
111
95
            return Ok(Vec::new());
112
95
        }
113

            
114
95
        let process = instance
115
95
            .get_typed_func::<(), ()>(&mut store, "process")
116
95
            .map_err(|e| HookError::Script(format!("Missing process export: {e}")))?;
117

            
118
95
        debug!("calling process");
119
95
        process.call(&mut store, ())?;
120

            
121
95
        let mem_data = memory.data(&store);
122
95
        let output_data = &mem_data[output_start..output_start + output_size as usize];
123

            
124
95
        let output_header = OutputHeader::from_bytes(output_data)
125
95
            .ok_or_else(|| HookError::Parse("Invalid output header".to_string()))?;
126

            
127
95
        let output_strings_offset = { output_header.strings_offset } as usize;
128

            
129
95
        let parser = OutputParser::new(output_data, output_strings_offset)?;
130
95
        let entities: Result<Vec<ParsedEntity>, HookError> = parser.entities().collect();
131
95
        debug!(
132
19
            entity_count = entities.as_ref().map_or(0, std::vec::Vec::len),
133
            "output parse complete"
134
        );
135
95
        entities
136
190
    }
137
}
138

            
139
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed executor starts with nothing in its module cache.
    #[test]
    fn test_executor_creation() {
        let executor = ScriptExecutor::new();
        let cache = executor
            .host
            .module_cache()
            .lock()
            .expect("module cache mutex cannot be poisoned on a fresh executor");
        assert!(cache.is_empty());
    }
}