1
use scripting::host::{WasmHost, define_host_functions};
2
use scripting::nomiscript::{
3
    Compiler, Expr, Fraction, GIT_REVISION, Program, Reader, Symbol, SymbolKind, SymbolTable, Value,
4
};
5
use scripting_format::{
6
    ContextType, DebugValueData, EntityType, GlobalHeader, OUTPUT_HEADER_SIZE, OutputHeader,
7
    ValueType,
8
};
9
use thiserror::Error;
10
use tracing::debug;
11
use tracing_subscriber::EnvFilter;
12
use tracing_subscriber::prelude::*;
13
use tracing_subscriber::reload;
14
use wasmtime::{Config, Engine, Linker, Module, Store};
15

            
16
#[derive(Error, Debug)]
17
pub enum Error {
18
    #[error("{0}")]
19
    Script(#[from] scripting::nomiscript::Error),
20

            
21
    #[error("{0}")]
22
    Runtime(String),
23
}
24

            
25
impl Error {
26
    #[must_use]
27
    pub fn render(&self, use_color: bool) -> String {
28
        match self {
29
            Error::Script(e) => e.render(use_color),
30
            Error::Runtime(msg) => {
31
                if use_color {
32
                    format!("\x1b[31merror:\x1b[0m {msg}")
33
                } else {
34
                    format!("error: {msg}")
35
                }
36
            }
37
        }
38
    }
39
}
40

            
41
pub type Result<T> = std::result::Result<T, Error>;
42

            
43
pub struct Interpreter {
44
    host: WasmHost,
45
    compiler: Compiler,
46
    reload_handle: reload::Handle<EnvFilter, tracing_subscriber::Registry>,
47
}
48

            
49
/// Size in bytes of the output region reserved right after the input in
/// linear memory (also handed to the input serializer).
const DEFAULT_OUTPUT_SIZE: u32 = 0x1_0000;
/// Size of one WebAssembly linear-memory page (64 KiB).
const WASM_PAGE_SIZE: u32 = 1 << 16;
51

            
52
impl Interpreter {
53
2189
    pub fn new(debug_mode: bool) -> anyhow::Result<Self> {
54
2189
        let mut config = Config::new();
55
2189
        config.wasm_gc(true);
56
2189
        let engine = Engine::new(&config)?;
57

            
58
2189
        let mut symbols = SymbolTable::with_builtins();
59
2189
        symbols.define(
60
2189
            Symbol::new("REVISION", SymbolKind::Variable)
61
2189
                .with_value(Expr::String(GIT_REVISION.to_string())),
62
        );
63

            
64
2189
        let default_filter = if debug_mode { "debug" } else { "warn" };
65
2189
        let filter =
66
2189
            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(default_filter));
67
2189
        let (filter_layer, reload_handle) = reload::Layer::new(filter);
68
2189
        tracing_subscriber::registry()
69
2189
            .with(filter_layer)
70
2189
            .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
71
2189
            .try_init()
72
2189
            .ok();
73

            
74
2189
        let mut interp = Self {
75
2189
            host: WasmHost::new(engine, symbols),
76
2189
            compiler: Compiler::new(),
77
2189
            reload_handle,
78
2189
        };
79
2189
        interp.load_stdlib()?;
80
2189
        Ok(interp)
81
2189
    }
82

            
83
2189
    fn load_stdlib(&mut self) -> Result<()> {
84
        const STDLIB: &str = include_str!("stdlib.lisp");
85
2189
        let program = Reader::parse(STDLIB)?;
86
2189
        let mut symbols = self
87
2189
            .host
88
2189
            .symbol_table()
89
2189
            .write()
90
2189
            .map_err(|e| Error::Runtime(format!("failed to write symbol table: {e}")))?;
91
2189
        self.compiler.compile(&program, &mut symbols)?;
92
2189
        Ok(())
93
2189
    }
94

            
95
2057
    pub fn eval(&mut self, input: &str) -> Result<Vec<Value>> {
96
2057
        let program = Reader::parse(input)?;
97
2057
        let debug_mode = program.annotations.iter().any(|a| a.name == "debug");
98
2057
        if debug_mode {
99
            self.reload_handle
100
                .modify(|f| *f = EnvFilter::new("debug"))
101
                .ok();
102
2057
        }
103
2057
        let result = self.eval_program(&program);
104
2057
        if debug_mode {
105
            self.reload_handle
106
                .modify(|f| {
107
                    *f = EnvFilter::try_from_default_env()
108
                        .unwrap_or_else(|_| EnvFilter::new("warn"));
109
                })
110
                .ok();
111
2057
        }
112
2057
        result
113
2057
    }
114

            
115
2057
    fn eval_program(&mut self, program: &Program) -> Result<Vec<Value>> {
116
2057
        if program.exprs.is_empty() {
117
11
            return Ok(vec![]);
118
2046
        }
119

            
120
1705
        let wasm = {
121
2046
            let mut symbols = self
122
2046
                .host
123
2046
                .symbol_table()
124
2046
                .write()
125
2046
                .map_err(|e| Error::Runtime(format!("failed to write symbol table: {e}")))?;
126
2046
            self.compiler.compile(program, &mut symbols)?
127
        };
128

            
129
1705
        let value = self.run_wasm(&wasm)?;
130
1705
        Ok(vec![value])
131
2057
    }
132

            
133
    #[must_use]
134
    pub fn struct_fields(&self, name: &str) -> Option<Vec<String>> {
135
        self.host
136
            .symbol_table()
137
            .read()
138
            .ok()
139
            .and_then(|st| st.struct_fields(name).map(<[std::string::String]>::to_vec))
140
    }
141

            
142
407
    pub fn compile_to_wasm(&mut self, input: &str) -> Result<Vec<u8>> {
143
407
        let program = Reader::parse(input)?;
144
407
        if program.exprs.is_empty() {
145
            return Err(Error::Runtime("nothing to compile".to_string()));
146
407
        }
147
407
        let mut symbols = self
148
407
            .host
149
407
            .symbol_table()
150
407
            .write()
151
407
            .map_err(|e| Error::Runtime(format!("failed to write symbol table: {e}")))?;
152
407
        Ok(self.compiler.compile(&program, &mut symbols)?)
153
407
    }
154

            
155
1738
    pub fn run_wasm(&self, wasm: &[u8]) -> Result<Value> {
156
1738
        let input = build_minimal_input(DEFAULT_OUTPUT_SIZE);
157
1738
        self.run_wasm_with_input(wasm, &input)
158
1738
    }
159

            
160
1947
    pub fn run_wasm_with_input(&self, wasm: &[u8], input: &[u8]) -> Result<Value> {
161
1947
        debug!(wasm_size = wasm.len(), "creating WASM module");
162
1947
        let module =
163
1947
            Module::new(self.host.engine(), wasm).map_err(|e| Error::Runtime(e.to_string()))?;
164

            
165
1947
        let output_size = DEFAULT_OUTPUT_SIZE;
166
1947
        let input_offset = scripting_format::BASE_OFFSET;
167
1947
        let output_offset = input_offset + input.len() as u32;
168
1947
        let strings_offset = {
169
1947
            let header = GlobalHeader::from_bytes(input).expect("minimal input must be valid");
170
1947
            header.strings_pool_offset
171
        };
172
1947
        debug!(
173
            input_offset,
174
            output_offset, strings_offset, "memory layout offsets"
175
        );
176

            
177
1947
        let exec_state = self
178
1947
            .host
179
1947
            .execution_state(input_offset, output_offset, strings_offset);
180
1947
        let mut store = Store::new(self.host.engine(), exec_state);
181

            
182
1947
        let mut linker = Linker::new(self.host.engine());
183
1947
        define_host_functions(&mut linker).map_err(|e| Error::Runtime(e.to_string()))?;
184

            
185
1947
        let instance = linker
186
1947
            .instantiate(&mut store, &module)
187
1947
            .map_err(|e| Error::Runtime(e.to_string()))?;
188

            
189
1947
        let memory = instance
190
1947
            .get_memory(&mut store, "memory")
191
1947
            .ok_or_else(|| Error::Runtime("no memory export".to_string()))?;
192

            
193
1947
        store.data_mut().memory = Some(memory);
194

            
195
1947
        let total_size = input.len() + output_size as usize;
196
1947
        let required_pages = (input_offset as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
197
1947
        let current_pages = memory.size(&store) as usize;
198

            
199
1947
        if required_pages > current_pages {
200
1947
            debug!(current_pages, required_pages, "growing memory");
201
1947
            memory
202
1947
                .grow(&mut store, (required_pages - current_pages) as u64)
203
1947
                .map_err(|e| Error::Runtime(e.to_string()))?;
204
        }
205

            
206
1947
        let mem_data = memory.data_mut(&mut store);
207
1947
        let input_start = input_offset as usize;
208
1947
        mem_data[input_start..input_start + input.len()].copy_from_slice(input);
209

            
210
1947
        let output_start = output_offset as usize;
211
1947
        let output_header = OutputHeader::new(0);
212
1947
        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
213
1947
            .copy_from_slice(&output_header.to_bytes());
214

            
215
1947
        let should_apply = instance
216
1947
            .get_typed_func::<(), i32>(&mut store, "should_apply")
217
1947
            .map_err(|e| Error::Runtime(e.to_string()))?;
218
1947
        debug!("calling should_apply");
219
1947
        let apply = should_apply
220
1947
            .call(&mut store, ())
221
1947
            .map_err(|e| Error::Runtime(e.to_string()))?;
222
1947
        debug!(result = apply, "should_apply returned");
223
1947
        if apply != 1 {
224
            return Err(Error::Runtime(format!(
225
                "should_apply returned {apply}, expected 1"
226
            )));
227
1947
        }
228

            
229
1947
        let process = instance
230
1947
            .get_typed_func::<(), ()>(&mut store, "process")
231
1947
            .map_err(|e| Error::Runtime(e.to_string()))?;
232
1947
        debug!("calling process");
233
1947
        process
234
1947
            .call(&mut store, ())
235
1947
            .map_err(|e| Error::Runtime(e.to_string()))?;
236

            
237
1947
        let mem_data = memory.data(&store);
238
1947
        let output_data = &mem_data[output_start..];
239

            
240
1947
        let result = decode_result(output_data);
241
1947
        debug!("result decoded");
242
1947
        result
243
1947
    }
244

            
245
    pub fn run_wasm_with_input_raw(&self, wasm: &[u8], input: &[u8]) -> Result<Vec<u8>> {
246
        let module =
247
            Module::new(self.host.engine(), wasm).map_err(|e| Error::Runtime(e.to_string()))?;
248

            
249
        let output_size = DEFAULT_OUTPUT_SIZE;
250
        let input_offset = scripting_format::BASE_OFFSET;
251
        let output_offset = input_offset + input.len() as u32;
252
        let strings_offset = {
253
            let header = GlobalHeader::from_bytes(input).expect("minimal input must be valid");
254
            header.strings_pool_offset
255
        };
256

            
257
        let exec_state = self
258
            .host
259
            .execution_state(input_offset, output_offset, strings_offset);
260
        let mut store = Store::new(self.host.engine(), exec_state);
261

            
262
        let mut linker = Linker::new(self.host.engine());
263
        define_host_functions(&mut linker).map_err(|e| Error::Runtime(e.to_string()))?;
264

            
265
        let instance = linker
266
            .instantiate(&mut store, &module)
267
            .map_err(|e| Error::Runtime(e.to_string()))?;
268

            
269
        let memory = instance
270
            .get_memory(&mut store, "memory")
271
            .ok_or_else(|| Error::Runtime("no memory export".to_string()))?;
272

            
273
        store.data_mut().memory = Some(memory);
274

            
275
        let total_size = input.len() + output_size as usize;
276
        let required_pages = (input_offset as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
277
        let current_pages = memory.size(&store) as usize;
278

            
279
        if required_pages > current_pages {
280
            memory
281
                .grow(&mut store, (required_pages - current_pages) as u64)
282
                .map_err(|e| Error::Runtime(e.to_string()))?;
283
        }
284

            
285
        let mem_data = memory.data_mut(&mut store);
286
        let input_start = input_offset as usize;
287
        mem_data[input_start..input_start + input.len()].copy_from_slice(input);
288

            
289
        let output_start = output_offset as usize;
290
        let output_header = OutputHeader::new(0);
291
        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
292
            .copy_from_slice(&output_header.to_bytes());
293

            
294
        let should_apply = instance
295
            .get_typed_func::<(), i32>(&mut store, "should_apply")
296
            .map_err(|e| Error::Runtime(e.to_string()))?;
297
        let apply = should_apply
298
            .call(&mut store, ())
299
            .map_err(|e| Error::Runtime(e.to_string()))?;
300
        if apply != 1 {
301
            return Err(Error::Runtime(format!(
302
                "should_apply returned {apply}, expected 1"
303
            )));
304
        }
305

            
306
        let process = instance
307
            .get_typed_func::<(), ()>(&mut store, "process")
308
            .map_err(|e| Error::Runtime(e.to_string()))?;
309
        process
310
            .call(&mut store, ())
311
            .map_err(|e| Error::Runtime(e.to_string()))?;
312

            
313
        let mem_data = memory.data(&store);
314
        Ok(mem_data[output_start..output_start + output_size as usize].to_vec())
315
    }
316
}
317

            
318
impl Default for Interpreter {
319
    fn default() -> Self {
320
        Self::new(false).expect("failed to create Interpreter")
321
    }
322
}
323

            
324
1738
fn build_minimal_input(output_size: u32) -> Vec<u8> {
325
    use scripting::MemorySerializer;
326
1738
    let mut ser = MemorySerializer::new();
327
1738
    ser.set_context(ContextType::BatchProcess, EntityType::Transaction);
328
1738
    ser.finalize(output_size)
329
1738
}
330

            
331
1947
fn decode_result(data: &[u8]) -> Result<Value> {
332
1947
    let output_header = OutputHeader::from_bytes(data)
333
1947
        .ok_or_else(|| Error::Runtime("invalid output header".to_string()))?;
334

            
335
1947
    if output_header.output_entity_count == 0 {
336
        return Err(Error::Runtime("no output entities".to_string()));
337
1947
    }
338

            
339
1947
    let entity_header = scripting_format::EntityHeader::from_bytes(&data[OUTPUT_HEADER_SIZE..])
340
1947
        .ok_or_else(|| Error::Runtime("invalid entity header".to_string()))?;
341

            
342
1947
    let data_offset = entity_header.data_offset as usize;
343
1947
    let value_data = DebugValueData::from_bytes(&data[data_offset..])
344
1947
        .ok_or_else(|| Error::Runtime("invalid debug value data".to_string()))?;
345

            
346
1947
    let value_type = ValueType::try_from(value_data.value_type)
347
1947
        .map_err(|()| Error::Runtime("unknown value type".to_string()))?;
348

            
349
1947
    match value_type {
350
264
        ValueType::Nil => Ok(Value::Nil),
351
198
        ValueType::Bool => Ok(Value::Bool(value_data.data1 != 0)),
352
1100
        ValueType::Number => Ok(Value::Number(Fraction::new(
353
1100
            value_data.data1,
354
1100
            value_data.data2,
355
1100
        ))),
356
        ValueType::String | ValueType::Symbol => {
357
385
            let pool_offset = value_data.data1 as usize;
358
385
            let len = value_data.data2 as usize;
359
385
            let strings_offset = output_header.strings_offset as usize;
360
385
            let start = strings_offset + pool_offset;
361
385
            let end = start + len;
362
385
            if end > data.len() {
363
                return Err(Error::Runtime("string data out of bounds".to_string()));
364
385
            }
365
385
            let s = std::str::from_utf8(&data[start..end])
366
385
                .map_err(|_| Error::Runtime("invalid UTF-8 in string".to_string()))?;
367
385
            if value_type == ValueType::Symbol {
368
77
                Ok(Value::Symbol(s.to_string()))
369
            } else {
370
308
                Ok(Value::String(s.to_string()))
371
            }
372
        }
373
    }
374
1947
}