1
use nomiscript::SymbolTable;
2
use tracing::debug;
3
use wasmtime::{Config, Engine, Linker, Module, Store};
4

            
5
use crate::error::HookError;
6
use crate::format::{BASE_OFFSET, GlobalHeader, OUTPUT_HEADER_SIZE, OutputHeader};
7
use crate::host::{WasmHost, define_host_functions};
8
use crate::parser::{OutputParser, ParsedEntity};
9

            
10
/// Output buffer size in bytes used by `ScriptExecutor::execute` when the caller passes `None`.
const DEFAULT_OUTPUT_SIZE: u32 = 64 * 1024;
11
/// Page size in bytes used to convert the required byte count into a number of WASM memory pages.
const WASM_PAGE_SIZE: u32 = 65536;
12
/// Epoch deadline set on each store, and the maximum number of epoch increments the ticker thread
/// performs per execution.
const EPOCH_DEADLINE_TICKS: u64 = 5;
13
/// Interval between consecutive epoch increments made by the ticker thread in `execute`.
const EPOCH_TICK_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1);
14

            
15
pub struct ScriptExecutor {
16
    host: WasmHost,
17
}
18

            
19
impl Default for ScriptExecutor {
20
    fn default() -> Self {
21
        Self::new()
22
    }
23
}
24

            
25
589
fn epoch_engine() -> Engine {
26
589
    let mut config = Config::new();
27
589
    config.wasm_gc(true);
28
589
    config.epoch_interruption(true);
29
589
    Engine::new(&config).expect("failed to create WASM engine")
30
589
}
31

            
32
impl ScriptExecutor {
33
    #[must_use]
34
589
    pub fn new() -> Self {
35
589
        Self {
36
589
            host: WasmHost::new(epoch_engine(), SymbolTable::new()),
37
589
        }
38
589
    }
39

            
40
    #[must_use]
41
448
    pub fn with_engine(engine: Engine) -> Self {
42
448
        Self {
43
448
            host: WasmHost::new(engine, SymbolTable::new()),
44
448
        }
45
448
    }
46

            
47
1064
    fn get_or_compile_module(&self, bytecode: &[u8]) -> Result<Module, HookError> {
48
1064
        let cache = self.host.module_cache().lock()?;
49
1064
        if let Some(module) = cache.get(bytecode) {
50
            return Ok(module.clone());
51
1064
        }
52
1064
        drop(cache);
53

            
54
1064
        let module = Module::new(self.host.engine(), bytecode)?;
55

            
56
1064
        let mut cache = self.host.module_cache().lock()?;
57
1064
        cache.insert(bytecode.to_vec(), module.clone());
58
1064
        Ok(module)
59
1064
    }
60

            
61
1064
    pub fn execute(
62
1064
        &self,
63
1064
        bytecode: &[u8],
64
1064
        input: &[u8],
65
1064
        output_size: Option<u32>,
66
1064
    ) -> Result<Vec<ParsedEntity>, HookError> {
67
1064
        debug!(bytecode_size = bytecode.len(), "script execution start");
68
1064
        let output_size = output_size.unwrap_or(DEFAULT_OUTPUT_SIZE);
69
1064
        let module = self.get_or_compile_module(bytecode)?;
70
1064
        debug!("module compiled");
71

            
72
1064
        let header = GlobalHeader::from_bytes(input)
73
1064
            .ok_or_else(|| HookError::Parse("Invalid input header".to_string()))?;
74

            
75
1064
        let input_offset = BASE_OFFSET;
76
1064
        let output_offset = input_offset + input.len() as u32;
77
1064
        let strings_offset = header.strings_pool_offset;
78

            
79
1064
        let exec_state = self
80
1064
            .host
81
1064
            .execution_state(input_offset, output_offset, strings_offset);
82
1064
        let mut store = Store::new(self.host.engine(), exec_state);
83
1064
        store.set_epoch_deadline(EPOCH_DEADLINE_TICKS);
84

            
85
1064
        let engine = self.host.engine().clone();
86
1064
        let ticker = std::thread::spawn(move || {
87
1064
            for _ in 0..EPOCH_DEADLINE_TICKS {
88
700
                std::thread::sleep(EPOCH_TICK_INTERVAL);
89
700
                engine.increment_epoch();
90
700
            }
91
1064
        });
92

            
93
1064
        let mut linker = Linker::new(self.host.engine());
94
1064
        define_host_functions(&mut linker)?;
95

            
96
1064
        let instance = linker.instantiate(&mut store, &module)?;
97

            
98
1064
        let memory = instance
99
1064
            .get_memory(&mut store, "memory")
100
1064
            .ok_or(HookError::WASMMem)?;
101

            
102
1064
        store.data_mut().memory = Some(memory);
103

            
104
1064
        let total_size = input.len() + output_size as usize;
105
1064
        let required_pages = (BASE_OFFSET as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
106
1064
        let current_pages = memory.size(&store) as usize;
107

            
108
1064
        if required_pages > current_pages {
109
            memory.grow(&mut store, (required_pages - current_pages) as u64)?;
110
1064
        }
111

            
112
1064
        let mem_data = memory.data_mut(&mut store);
113
1064
        let input_start = BASE_OFFSET as usize;
114
1064
        mem_data[input_start..input_start + input.len()].copy_from_slice(input);
115

            
116
1064
        let output_start = output_offset as usize;
117
1064
        let input_entity_count = header.input_entity_count;
118
1064
        let output_header = OutputHeader::new(input_entity_count);
119
1064
        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
120
1064
            .copy_from_slice(&output_header.to_bytes());
121

            
122
1064
        let should_apply = instance
123
1064
            .get_typed_func::<(), i32>(&mut store, "should_apply")
124
1064
            .map_err(|e| HookError::Script(format!("Missing should_apply export: {e}")))?;
125

            
126
1064
        let result = should_apply.call(&mut store, ())?;
127
1064
        debug!(should_apply = result, "should_apply result");
128
1064
        if result == 0 {
129
224
            drop(ticker);
130
224
            return Ok(Vec::new());
131
840
        }
132

            
133
840
        let process = instance
134
840
            .get_typed_func::<(), ()>(&mut store, "process")
135
840
            .map_err(|e| HookError::Script(format!("Missing process export: {e}")))?;
136

            
137
840
        debug!("calling process");
138
840
        let call_result = process.call(&mut store, ());
139
840
        drop(ticker);
140
840
        call_result?;
141

            
142
840
        let mem_data = memory.data(&store);
143
840
        let output_data = &mem_data[output_start..output_start + output_size as usize];
144

            
145
840
        let output_header = OutputHeader::from_bytes(output_data)
146
840
            .ok_or_else(|| HookError::Parse("Invalid output header".to_string()))?;
147

            
148
840
        let output_strings_offset = { output_header.strings_offset } as usize;
149

            
150
840
        let parser = OutputParser::new(output_data, output_strings_offset)?;
151
840
        let entities: Result<Vec<ParsedEntity>, HookError> = parser.entities().collect();
152
840
        debug!(
153
112
            entity_count = entities.as_ref().map_or(0, std::vec::Vec::len),
154
            "output parse complete"
155
        );
156
840
        entities
157
1064
    }
158
}
159

            
160
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed executor starts with an empty module cache.
    #[test]
    fn test_executor_creation() {
        let executor = ScriptExecutor::new();
        let cache = executor.host.module_cache().lock().unwrap();
        assert!(cache.is_empty());
    }
}