1
use std::collections::HashMap;
2
use std::sync::{Arc, Mutex};
3

            
4
use wasmtime::{Engine, Linker, Module, Store};
5

            
6
use crate::error::HookError;
7
use crate::format::{BASE_OFFSET, GlobalHeader, OUTPUT_HEADER_SIZE, OutputHeader};
8
use crate::host::{HostState, define_host_functions};
9
use crate::parser::{OutputParser, ParsedEntity};
10

            
11
/// Output region size, in bytes, used by `ScriptExecutor::execute` when the caller passes
/// `None` for `output_size` (64 KiB).
const DEFAULT_OUTPUT_SIZE: u32 = 64 * 1024;
12
/// Number of bytes per linear-memory page; used to turn the required byte count into the
/// page count compared against the instance's current memory size.
const WASM_PAGE_SIZE: u32 = 65536;
13

            
14
pub struct ScriptExecutor {
15
    engine: Engine,
16
    module_cache: Arc<Mutex<HashMap<Vec<u8>, Module>>>,
17
}
18

            
19
impl Default for ScriptExecutor {
20
    fn default() -> Self {
21
        Self::new()
22
    }
23
}
24

            
25
impl ScriptExecutor {
26
    #[must_use]
27
49
    pub fn new() -> Self {
28
49
        Self {
29
49
            engine: Engine::default(),
30
49
            module_cache: Arc::new(Mutex::new(HashMap::new())),
31
49
        }
32
49
    }
33

            
34
    #[must_use]
35
    pub fn with_engine(engine: Engine) -> Self {
36
        Self {
37
            engine,
38
            module_cache: Arc::new(Mutex::new(HashMap::new())),
39
        }
40
    }
41

            
42
48
    fn get_or_compile_module(&self, bytecode: &[u8]) -> Result<Module, HookError> {
43
48
        let cache = self.module_cache.lock()?;
44
48
        if let Some(module) = cache.get(bytecode) {
45
            return Ok(module.clone());
46
48
        }
47
48
        drop(cache);
48

            
49
48
        let module = Module::new(&self.engine, bytecode)?;
50

            
51
48
        let mut cache = self.module_cache.lock()?;
52
48
        cache.insert(bytecode.to_vec(), module.clone());
53
48
        Ok(module)
54
48
    }
55

            
56
48
    pub fn execute(
57
48
        &self,
58
48
        bytecode: &[u8],
59
48
        input: &[u8],
60
48
        output_size: Option<u32>,
61
48
    ) -> Result<Vec<ParsedEntity>, HookError> {
62
48
        let output_size = output_size.unwrap_or(DEFAULT_OUTPUT_SIZE);
63
48
        let module = self.get_or_compile_module(bytecode)?;
64

            
65
48
        let header = GlobalHeader::from_bytes(input)
66
48
            .ok_or_else(|| HookError::Parse("Invalid input header".to_string()))?;
67

            
68
48
        let input_offset = BASE_OFFSET;
69
48
        let output_offset = input_offset + input.len() as u32;
70
48
        let strings_offset = header.strings_pool_offset;
71

            
72
48
        let host_state = HostState::new(input_offset, output_offset, strings_offset);
73
48
        let mut store = Store::new(&self.engine, host_state);
74

            
75
48
        let mut linker = Linker::new(&self.engine);
76
48
        define_host_functions(&mut linker)?;
77

            
78
48
        let instance = linker.instantiate(&mut store, &module)?;
79

            
80
48
        let memory = instance
81
48
            .get_memory(&mut store, "memory")
82
48
            .ok_or(HookError::WASMMem)?;
83

            
84
48
        store.data_mut().memory = Some(memory);
85

            
86
48
        let total_size = input.len() + output_size as usize;
87
48
        let required_pages = (BASE_OFFSET as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
88
48
        let current_pages = memory.size(&store) as usize;
89

            
90
48
        if required_pages > current_pages {
91
            memory.grow(&mut store, (required_pages - current_pages) as u64)?;
92
48
        }
93

            
94
48
        let mem_data = memory.data_mut(&mut store);
95
48
        let input_start = BASE_OFFSET as usize;
96
48
        mem_data[input_start..input_start + input.len()].copy_from_slice(input);
97

            
98
48
        let output_start = output_offset as usize;
99
48
        let input_entity_count = header.input_entity_count;
100
48
        let output_header = OutputHeader::new(input_entity_count);
101
48
        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
102
48
            .copy_from_slice(&output_header.to_bytes());
103

            
104
48
        let should_apply = instance
105
48
            .get_typed_func::<(), i32>(&mut store, "should_apply")
106
48
            .map_err(|e| HookError::Script(format!("Missing should_apply export: {e}")))?;
107

            
108
48
        let result = should_apply.call(&mut store, ())?;
109
48
        if result == 0 {
110
32
            return Ok(Vec::new());
111
16
        }
112

            
113
16
        let process = instance
114
16
            .get_typed_func::<(), ()>(&mut store, "process")
115
16
            .map_err(|e| HookError::Script(format!("Missing process export: {e}")))?;
116

            
117
16
        process.call(&mut store, ())?;
118

            
119
16
        let mem_data = memory.data(&store);
120
16
        let output_data = &mem_data[output_start..output_start + output_size as usize];
121

            
122
16
        let output_header = OutputHeader::from_bytes(output_data)
123
16
            .ok_or_else(|| HookError::Parse("Invalid output header".to_string()))?;
124

            
125
16
        let output_strings_offset = { output_header.strings_offset } as usize;
126

            
127
16
        let parser = OutputParser::new(output_data, output_strings_offset)?;
128
16
        parser.entities().collect()
129
48
    }
130
}
131

            
132
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed executor has not compiled anything yet.
    #[test]
    fn test_executor_creation() {
        let executor = ScriptExecutor::new();
        let cache = executor.module_cache.lock().unwrap();
        assert!(cache.is_empty());
    }
}