// nms/interpreter.rs

1use scripting::host::{WasmHost, define_host_functions};
2use scripting::nomiscript::{
3    Compiler, Expr, Fraction, GIT_REVISION, Program, Reader, Symbol, SymbolKind, SymbolTable, Value,
4};
5use scripting_format::{
6    ContextType, DebugValueData, EntityType, GlobalHeader, OUTPUT_HEADER_SIZE, OutputHeader,
7    ValueType,
8};
9use thiserror::Error;
10use tracing::debug;
11use tracing_subscriber::EnvFilter;
12use tracing_subscriber::prelude::*;
13use tracing_subscriber::reload;
14use wasmtime::{Config, Engine, Linker, Module, Store};
15
16#[derive(Error, Debug)]
17pub enum Error {
18    #[error("{0}")]
19    Script(#[from] scripting::nomiscript::Error),
20
21    #[error("{0}")]
22    Runtime(String),
23}
24
25impl Error {
26    #[must_use]
27    pub fn render(&self, use_color: bool) -> String {
28        match self {
29            Error::Script(e) => e.render(use_color),
30            Error::Runtime(msg) => {
31                if use_color {
32                    format!("\x1b[31merror:\x1b[0m {msg}")
33                } else {
34                    format!("error: {msg}")
35                }
36            }
37        }
38    }
39}
40
41pub type Result<T> = std::result::Result<T, Error>;
42
43pub struct Interpreter {
44    host: WasmHost,
45    compiler: Compiler,
46    reload_handle: reload::Handle<EnvFilter, tracing_subscriber::Registry>,
47}
48
49const DEFAULT_OUTPUT_SIZE: u32 = 64 * 1024;
50const WASM_PAGE_SIZE: u32 = 65536;
51
52impl Interpreter {
53    pub fn new(debug_mode: bool) -> anyhow::Result<Self> {
54        let mut config = Config::new();
55        config.wasm_gc(true);
56        let engine = Engine::new(&config)?;
57
58        let mut symbols = SymbolTable::with_builtins();
59        symbols.define(
60            Symbol::new("REVISION", SymbolKind::Variable)
61                .with_value(Expr::String(GIT_REVISION.to_string())),
62        );
63
64        let default_filter = if debug_mode { "debug" } else { "warn" };
65        let filter =
66            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(default_filter));
67        let (filter_layer, reload_handle) = reload::Layer::new(filter);
68        tracing_subscriber::registry()
69            .with(filter_layer)
70            .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
71            .try_init()
72            .ok();
73
74        let mut interp = Self {
75            host: WasmHost::new(engine, symbols),
76            compiler: Compiler::new(),
77            reload_handle,
78        };
79        interp.load_stdlib()?;
80        Ok(interp)
81    }
82
83    fn load_stdlib(&mut self) -> Result<()> {
84        const STDLIB: &str = include_str!("stdlib.lisp");
85        let program = Reader::parse(STDLIB)?;
86        let mut symbols = self
87            .host
88            .symbol_table()
89            .write()
90            .map_err(|e| Error::Runtime(format!("failed to write symbol table: {e}")))?;
91        self.compiler.compile(&program, &mut symbols)?;
92        Ok(())
93    }
94
95    pub fn eval(&mut self, input: &str) -> Result<Vec<Value>> {
96        let program = Reader::parse(input)?;
97        let debug_mode = program.annotations.iter().any(|a| a.name == "debug");
98        if debug_mode {
99            self.reload_handle
100                .modify(|f| *f = EnvFilter::new("debug"))
101                .ok();
102        }
103        let result = self.eval_program(&program);
104        if debug_mode {
105            self.reload_handle
106                .modify(|f| {
107                    *f = EnvFilter::try_from_default_env()
108                        .unwrap_or_else(|_| EnvFilter::new("warn"));
109                })
110                .ok();
111        }
112        result
113    }
114
115    fn eval_program(&mut self, program: &Program) -> Result<Vec<Value>> {
116        if program.exprs.is_empty() {
117            return Ok(vec![]);
118        }
119
120        let wasm = {
121            let mut symbols = self
122                .host
123                .symbol_table()
124                .write()
125                .map_err(|e| Error::Runtime(format!("failed to write symbol table: {e}")))?;
126            self.compiler.compile(program, &mut symbols)?
127        };
128
129        let value = self.run_wasm(&wasm)?;
130        Ok(vec![value])
131    }
132
133    #[must_use]
134    pub fn struct_fields(&self, name: &str) -> Option<Vec<String>> {
135        self.host
136            .symbol_table()
137            .read()
138            .ok()
139            .and_then(|st| st.struct_fields(name).map(<[std::string::String]>::to_vec))
140    }
141
142    pub fn compile_to_wasm(&mut self, input: &str) -> Result<Vec<u8>> {
143        let program = Reader::parse(input)?;
144        if program.exprs.is_empty() {
145            return Err(Error::Runtime("nothing to compile".to_string()));
146        }
147        let mut symbols = self
148            .host
149            .symbol_table()
150            .write()
151            .map_err(|e| Error::Runtime(format!("failed to write symbol table: {e}")))?;
152        Ok(self.compiler.compile(&program, &mut symbols)?)
153    }
154
155    pub fn run_wasm(&self, wasm: &[u8]) -> Result<Value> {
156        let input = build_minimal_input(DEFAULT_OUTPUT_SIZE);
157        self.run_wasm_with_input(wasm, &input)
158    }
159
160    pub fn run_wasm_with_input(&self, wasm: &[u8], input: &[u8]) -> Result<Value> {
161        debug!(wasm_size = wasm.len(), "creating WASM module");
162        let module =
163            Module::new(self.host.engine(), wasm).map_err(|e| Error::Runtime(e.to_string()))?;
164
165        let output_size = DEFAULT_OUTPUT_SIZE;
166        let input_offset = scripting_format::BASE_OFFSET;
167        let output_offset = input_offset + input.len() as u32;
168        let strings_offset = {
169            let header = GlobalHeader::from_bytes(input).expect("minimal input must be valid");
170            header.strings_pool_offset
171        };
172        debug!(
173            input_offset,
174            output_offset, strings_offset, "memory layout offsets"
175        );
176
177        let exec_state = self
178            .host
179            .execution_state(input_offset, output_offset, strings_offset);
180        let mut store = Store::new(self.host.engine(), exec_state);
181
182        let mut linker = Linker::new(self.host.engine());
183        define_host_functions(&mut linker).map_err(|e| Error::Runtime(e.to_string()))?;
184
185        let instance = linker
186            .instantiate(&mut store, &module)
187            .map_err(|e| Error::Runtime(e.to_string()))?;
188
189        let memory = instance
190            .get_memory(&mut store, "memory")
191            .ok_or_else(|| Error::Runtime("no memory export".to_string()))?;
192
193        store.data_mut().memory = Some(memory);
194
195        let total_size = input.len() + output_size as usize;
196        let required_pages = (input_offset as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
197        let current_pages = memory.size(&store) as usize;
198
199        if required_pages > current_pages {
200            debug!(current_pages, required_pages, "growing memory");
201            memory
202                .grow(&mut store, (required_pages - current_pages) as u64)
203                .map_err(|e| Error::Runtime(e.to_string()))?;
204        }
205
206        let mem_data = memory.data_mut(&mut store);
207        let input_start = input_offset as usize;
208        mem_data[input_start..input_start + input.len()].copy_from_slice(input);
209
210        let output_start = output_offset as usize;
211        let output_header = OutputHeader::new(0);
212        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
213            .copy_from_slice(&output_header.to_bytes());
214
215        let should_apply = instance
216            .get_typed_func::<(), i32>(&mut store, "should_apply")
217            .map_err(|e| Error::Runtime(e.to_string()))?;
218        debug!("calling should_apply");
219        let apply = should_apply
220            .call(&mut store, ())
221            .map_err(|e| Error::Runtime(e.to_string()))?;
222        debug!(result = apply, "should_apply returned");
223        if apply != 1 {
224            return Err(Error::Runtime(format!(
225                "should_apply returned {apply}, expected 1"
226            )));
227        }
228
229        let process = instance
230            .get_typed_func::<(), ()>(&mut store, "process")
231            .map_err(|e| Error::Runtime(e.to_string()))?;
232        debug!("calling process");
233        process
234            .call(&mut store, ())
235            .map_err(|e| Error::Runtime(e.to_string()))?;
236
237        let mem_data = memory.data(&store);
238        let output_data = &mem_data[output_start..];
239
240        let result = decode_result(output_data);
241        debug!("result decoded");
242        result
243    }
244
245    pub fn run_wasm_with_input_raw(&self, wasm: &[u8], input: &[u8]) -> Result<Vec<u8>> {
246        let module =
247            Module::new(self.host.engine(), wasm).map_err(|e| Error::Runtime(e.to_string()))?;
248
249        let output_size = DEFAULT_OUTPUT_SIZE;
250        let input_offset = scripting_format::BASE_OFFSET;
251        let output_offset = input_offset + input.len() as u32;
252        let strings_offset = {
253            let header = GlobalHeader::from_bytes(input).expect("minimal input must be valid");
254            header.strings_pool_offset
255        };
256
257        let exec_state = self
258            .host
259            .execution_state(input_offset, output_offset, strings_offset);
260        let mut store = Store::new(self.host.engine(), exec_state);
261
262        let mut linker = Linker::new(self.host.engine());
263        define_host_functions(&mut linker).map_err(|e| Error::Runtime(e.to_string()))?;
264
265        let instance = linker
266            .instantiate(&mut store, &module)
267            .map_err(|e| Error::Runtime(e.to_string()))?;
268
269        let memory = instance
270            .get_memory(&mut store, "memory")
271            .ok_or_else(|| Error::Runtime("no memory export".to_string()))?;
272
273        store.data_mut().memory = Some(memory);
274
275        let total_size = input.len() + output_size as usize;
276        let required_pages = (input_offset as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
277        let current_pages = memory.size(&store) as usize;
278
279        if required_pages > current_pages {
280            memory
281                .grow(&mut store, (required_pages - current_pages) as u64)
282                .map_err(|e| Error::Runtime(e.to_string()))?;
283        }
284
285        let mem_data = memory.data_mut(&mut store);
286        let input_start = input_offset as usize;
287        mem_data[input_start..input_start + input.len()].copy_from_slice(input);
288
289        let output_start = output_offset as usize;
290        let output_header = OutputHeader::new(0);
291        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
292            .copy_from_slice(&output_header.to_bytes());
293
294        let should_apply = instance
295            .get_typed_func::<(), i32>(&mut store, "should_apply")
296            .map_err(|e| Error::Runtime(e.to_string()))?;
297        let apply = should_apply
298            .call(&mut store, ())
299            .map_err(|e| Error::Runtime(e.to_string()))?;
300        if apply != 1 {
301            return Err(Error::Runtime(format!(
302                "should_apply returned {apply}, expected 1"
303            )));
304        }
305
306        let process = instance
307            .get_typed_func::<(), ()>(&mut store, "process")
308            .map_err(|e| Error::Runtime(e.to_string()))?;
309        process
310            .call(&mut store, ())
311            .map_err(|e| Error::Runtime(e.to_string()))?;
312
313        let mem_data = memory.data(&store);
314        Ok(mem_data[output_start..output_start + output_size as usize].to_vec())
315    }
316}
317
318impl Default for Interpreter {
319    fn default() -> Self {
320        Self::new(false).expect("failed to create Interpreter")
321    }
322}
323
324fn build_minimal_input(output_size: u32) -> Vec<u8> {
325    use scripting::MemorySerializer;
326    let mut ser = MemorySerializer::new();
327    ser.set_context(ContextType::BatchProcess, EntityType::Transaction);
328    ser.finalize(output_size)
329}
330
331fn decode_result(data: &[u8]) -> Result<Value> {
332    let output_header = OutputHeader::from_bytes(data)
333        .ok_or_else(|| Error::Runtime("invalid output header".to_string()))?;
334
335    if output_header.output_entity_count == 0 {
336        return Err(Error::Runtime("no output entities".to_string()));
337    }
338
339    let entity_header = scripting_format::EntityHeader::from_bytes(&data[OUTPUT_HEADER_SIZE..])
340        .ok_or_else(|| Error::Runtime("invalid entity header".to_string()))?;
341
342    let data_offset = entity_header.data_offset as usize;
343    let value_data = DebugValueData::from_bytes(&data[data_offset..])
344        .ok_or_else(|| Error::Runtime("invalid debug value data".to_string()))?;
345
346    let value_type = ValueType::try_from(value_data.value_type)
347        .map_err(|()| Error::Runtime("unknown value type".to_string()))?;
348
349    match value_type {
350        ValueType::Nil => Ok(Value::Nil),
351        ValueType::Bool => Ok(Value::Bool(value_data.data1 != 0)),
352        ValueType::Number => Ok(Value::Number(Fraction::new(
353            value_data.data1,
354            value_data.data2,
355        ))),
356        ValueType::String | ValueType::Symbol => {
357            let pool_offset = value_data.data1 as usize;
358            let len = value_data.data2 as usize;
359            let strings_offset = output_header.strings_offset as usize;
360            let start = strings_offset + pool_offset;
361            let end = start + len;
362            if end > data.len() {
363                return Err(Error::Runtime("string data out of bounds".to_string()));
364            }
365            let s = std::str::from_utf8(&data[start..end])
366                .map_err(|_| Error::Runtime("invalid UTF-8 in string".to_string()))?;
367            if value_type == ValueType::Symbol {
368                Ok(Value::Symbol(s.to_string()))
369            } else {
370                Ok(Value::String(s.to_string()))
371            }
372        }
373    }
374}