1
use nomiscript::SymbolTable;
2
use tracing::debug;
3
use wasmtime::{Engine, Linker, Module, Store};
4

            
5
use crate::error::HookError;
6
use crate::format::{BASE_OFFSET, GlobalHeader, OUTPUT_HEADER_SIZE, OutputHeader};
7
use crate::host::{WasmHost, define_host_functions};
8
use crate::parser::{OutputParser, ParsedEntity};
9
use crate::runtime::{EngineOpts, build_engine};
10

            
11
const DEFAULT_OUTPUT_SIZE: u32 = 64 * 1024;
12
const WASM_PAGE_SIZE: u32 = 65536;
13
const EPOCH_DEADLINE_TICKS: u64 = 5;
14
const EPOCH_TICK_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1);
15

            
16
pub struct ScriptExecutor {
17
    host: WasmHost,
18
}
19

            
20
impl Default for ScriptExecutor {
21
    fn default() -> Self {
22
        Self::new()
23
    }
24
}
25

            
26
1153
fn default_engine() -> Engine {
27
1153
    build_engine(EngineOpts::baseline()).expect("WASM engine config rejected")
28
1153
}
29

            
30
impl ScriptExecutor {
31
    #[must_use]
32
1153
    pub fn new() -> Self {
33
1153
        Self {
34
1153
            host: WasmHost::new(default_engine(), SymbolTable::new()),
35
1153
        }
36
1153
    }
37

            
38
    #[must_use]
39
1344
    pub fn with_engine(engine: Engine) -> Self {
40
1344
        Self {
41
1344
            host: WasmHost::new(engine, SymbolTable::new()),
42
1344
        }
43
1344
    }
44

            
45
2592
    fn get_or_compile_module(&self, bytecode: &[u8]) -> Result<Module, HookError> {
46
2592
        self.host
47
2592
            .module_cache()
48
2592
            .get_or_compile(self.host.engine(), bytecode)
49
2592
            .map_err(HookError::from)
50
2592
    }
51

            
52
2592
    pub fn execute(
53
2592
        &self,
54
2592
        bytecode: &[u8],
55
2592
        input: &[u8],
56
2592
        output_size: Option<u32>,
57
2592
    ) -> Result<Vec<ParsedEntity>, HookError> {
58
2592
        debug!(bytecode_size = bytecode.len(), "script execution start");
59
2592
        let output_size = output_size.unwrap_or(DEFAULT_OUTPUT_SIZE);
60
2592
        let module = self.get_or_compile_module(bytecode)?;
61
2496
        debug!("module compiled");
62

            
63
2496
        let header = GlobalHeader::from_bytes(input)
64
2496
            .ok_or_else(|| HookError::Parse("Invalid input header".to_string()))?;
65

            
66
2496
        let input_offset = BASE_OFFSET;
67
2496
        let output_offset = input_offset + input.len() as u32;
68
2496
        let strings_offset = header.strings_pool_offset;
69

            
70
2496
        let exec_state = self
71
2496
            .host
72
2496
            .execution_state(input_offset, output_offset, strings_offset);
73
2496
        let mut store = Store::new(self.host.engine(), exec_state);
74
2496
        store.set_epoch_deadline(EPOCH_DEADLINE_TICKS);
75

            
76
2496
        let engine = self.host.engine().clone();
77
2496
        let ticker = std::thread::spawn(move || {
78
4032
            for _ in 0..EPOCH_DEADLINE_TICKS {
79
4032
                std::thread::sleep(EPOCH_TICK_INTERVAL);
80
4032
                engine.increment_epoch();
81
4032
            }
82
2496
        });
83

            
84
2496
        let mut linker = Linker::new(self.host.engine());
85
2496
        define_host_functions(&mut linker)?;
86

            
87
2496
        let instance = linker.instantiate(&mut store, &module)?;
88

            
89
2496
        let memory = instance
90
2496
            .get_memory(&mut store, "memory")
91
2496
            .ok_or(HookError::WASMMem)?;
92

            
93
2496
        store.data_mut().memory = Some(memory);
94

            
95
2496
        let total_size = input.len() + output_size as usize;
96
2496
        let required_pages = (BASE_OFFSET as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
97
2496
        let current_pages = memory.size(&store) as usize;
98

            
99
2496
        if required_pages > current_pages {
100
            memory.grow(&mut store, (required_pages - current_pages) as u64)?;
101
2496
        }
102

            
103
2496
        let mem_data = memory.data_mut(&mut store);
104
2496
        let input_start = BASE_OFFSET as usize;
105
2496
        mem_data[input_start..input_start + input.len()].copy_from_slice(input);
106

            
107
2496
        let output_start = output_offset as usize;
108
2496
        let input_entity_count = header.input_entity_count;
109
2496
        let output_header = OutputHeader::new(input_entity_count);
110
2496
        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
111
2496
            .copy_from_slice(&output_header.to_bytes());
112

            
113
2496
        let should_apply = instance
114
2496
            .get_typed_func::<(), i32>(&mut store, "should_apply")
115
2496
            .map_err(|e| HookError::Script(format!("Missing should_apply export: {e}")))?;
116

            
117
2496
        let result = should_apply.call(&mut store, ())?;
118
2496
        debug!(should_apply = result, "should_apply result");
119
2496
        if result == 0 {
120
384
            drop(ticker);
121
384
            return Ok(Vec::new());
122
2112
        }
123

            
124
2112
        let process = instance
125
2112
            .get_typed_func::<(), ()>(&mut store, "process")
126
2112
            .map_err(|e| HookError::Script(format!("Missing process export: {e}")))?;
127

            
128
2112
        debug!("calling process");
129
2112
        let call_result = process.call(&mut store, ());
130
2112
        drop(ticker);
131
2112
        call_result?;
132

            
133
2064
        let mem_data = memory.data(&store);
134
2064
        let output_data = &mem_data[output_start..output_start + output_size as usize];
135

            
136
2064
        let output_header = OutputHeader::from_bytes(output_data)
137
2064
            .ok_or_else(|| HookError::Parse("Invalid output header".to_string()))?;
138

            
139
2064
        let output_strings_offset = { output_header.strings_offset } as usize;
140

            
141
2064
        let parser = OutputParser::new(output_data, output_strings_offset)?;
142
2064
        let entities: Result<Vec<ParsedEntity>, HookError> = parser.entities().collect();
143
2064
        debug!(
144
240
            entity_count = entities.as_ref().map_or(0, std::vec::Vec::len),
145
            "output parse complete"
146
        );
147
2064
        entities
148
2592
    }
149
}
150

            
151
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed executor must start with an empty module cache.
    #[test]
    fn test_executor_creation() {
        let executor = ScriptExecutor::new();
        let cache_is_empty = executor
            .host
            .module_cache()
            .is_empty()
            .expect("cache lock must not be poisoned in fresh executor");
        assert!(cache_is_empty);
    }
}