1
use scripting::host::{WasmHost, define_host_functions};
2
use scripting::nomiscript::{
3
    Compiler, Expr, Fraction, GIT_REVISION, Program, Reader, Symbol, SymbolKind, SymbolTable, Value,
4
};
5
use scripting_format::{
6
    ContextType, DebugValueData, EntityType, GlobalHeader, OUTPUT_HEADER_SIZE, OutputHeader,
7
    ValueType,
8
};
9
use thiserror::Error;
10
use tracing::debug;
11
use tracing_subscriber::EnvFilter;
12
use tracing_subscriber::prelude::*;
13
use tracing_subscriber::reload;
14
use wasmtime::{Config, Engine, Linker, Module, Store};
15

            
16
#[derive(Error, Debug)]
17
pub enum Error {
18
    #[error("{0}")]
19
    Script(#[from] scripting::nomiscript::Error),
20

            
21
    #[error("{0}")]
22
    Runtime(String),
23
}
24

            
25
impl Error {
26
    #[must_use]
27
    pub fn render(&self, use_color: bool) -> String {
28
        match self {
29
            Error::Script(e) => e.render(use_color),
30
            Error::Runtime(msg) => {
31
                if use_color {
32
                    format!("\x1b[31merror:\x1b[0m {msg}")
33
                } else {
34
                    format!("error: {msg}")
35
                }
36
            }
37
        }
38
    }
39
}
40

            
41
pub type Result<T> = std::result::Result<T, Error>;
42

            
43
pub struct Interpreter {
44
    host: WasmHost,
45
    compiler: Compiler,
46
    reload_handle: reload::Handle<EnvFilter, tracing_subscriber::Registry>,
47
}
48

            
49
const DEFAULT_OUTPUT_SIZE: u32 = 64 * 1024;
50
const WASM_PAGE_SIZE: u32 = 65536;
51

            
52
impl Interpreter {
53
1826
    pub fn new(debug_mode: bool) -> anyhow::Result<Self> {
54
1826
        let mut config = Config::new();
55
1826
        config.wasm_gc(true);
56
1826
        let engine = Engine::new(&config)?;
57

            
58
1826
        let mut symbols = SymbolTable::with_builtins();
59
1826
        symbols.define(
60
1826
            Symbol::new("REVISION", SymbolKind::Variable)
61
1826
                .with_value(Expr::String(GIT_REVISION.to_string())),
62
        );
63

            
64
1826
        let default_filter = if debug_mode { "debug" } else { "warn" };
65
1826
        let filter =
66
1826
            EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(default_filter));
67
1826
        let (filter_layer, reload_handle) = reload::Layer::new(filter);
68
1826
        tracing_subscriber::registry()
69
1826
            .with(filter_layer)
70
1826
            .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
71
1826
            .try_init()
72
1826
            .ok();
73

            
74
1826
        let mut interp = Self {
75
1826
            host: WasmHost::new(engine, symbols),
76
1826
            compiler: Compiler::new(),
77
1826
            reload_handle,
78
1826
        };
79
1826
        interp.load_stdlib()?;
80
1826
        Ok(interp)
81
1826
    }
82

            
83
1826
    fn load_stdlib(&mut self) -> Result<()> {
84
        const STDLIB: &str = include_str!("stdlib.lisp");
85
1826
        let program = Reader::parse(STDLIB)?;
86
1826
        let mut symbols = self
87
1826
            .host
88
1826
            .symbol_table()
89
1826
            .write()
90
1826
            .map_err(|e| Error::Runtime(format!("failed to write symbol table: {e}")))?;
91
1826
        self.compiler.compile(&program, &mut symbols)?;
92
1826
        Ok(())
93
1826
    }
94

            
95
2035
    pub fn eval(&mut self, input: &str) -> Result<Vec<Value>> {
96
2035
        let program = Reader::parse(input)?;
97
2035
        let debug_mode = program.annotations.iter().any(|a| a.name == "debug");
98
2035
        if debug_mode {
99
            self.reload_handle
100
                .modify(|f| *f = EnvFilter::new("debug"))
101
                .ok();
102
2035
        }
103
2035
        let result = self.eval_program(&program);
104
2035
        if debug_mode {
105
            self.reload_handle
106
                .modify(|f| {
107
                    *f = EnvFilter::try_from_default_env()
108
                        .unwrap_or_else(|_| EnvFilter::new("warn"));
109
                })
110
                .ok();
111
2035
        }
112
2035
        result
113
2035
    }
114

            
115
2035
    fn eval_program(&mut self, program: &Program) -> Result<Vec<Value>> {
116
2035
        if program.exprs.is_empty() {
117
11
            return Ok(vec![]);
118
2024
        }
119

            
120
1683
        let wasm = {
121
2024
            let mut symbols = self
122
2024
                .host
123
2024
                .symbol_table()
124
2024
                .write()
125
2024
                .map_err(|e| Error::Runtime(format!("failed to write symbol table: {e}")))?;
126
2024
            self.compiler.compile(program, &mut symbols)?
127
        };
128

            
129
1683
        let value = self.run_wasm(&wasm)?;
130
1683
        Ok(vec![value])
131
2035
    }
132

            
133
    #[must_use]
134
    pub fn struct_fields(&self, name: &str) -> Option<Vec<String>> {
135
        self.host
136
            .symbol_table()
137
            .read()
138
            .ok()
139
            .and_then(|st| st.struct_fields(name).map(<[std::string::String]>::to_vec))
140
    }
141

            
142
66
    pub fn compile_to_wasm(&mut self, input: &str) -> Result<Vec<u8>> {
143
66
        let program = Reader::parse(input)?;
144
66
        if program.exprs.is_empty() {
145
            return Err(Error::Runtime("nothing to compile".to_string()));
146
66
        }
147
66
        let mut symbols = self
148
66
            .host
149
66
            .symbol_table()
150
66
            .write()
151
66
            .map_err(|e| Error::Runtime(format!("failed to write symbol table: {e}")))?;
152
66
        Ok(self.compiler.compile(&program, &mut symbols)?)
153
66
    }
154

            
155
1716
    pub fn run_wasm(&self, wasm: &[u8]) -> Result<Value> {
156
1716
        debug!(wasm_size = wasm.len(), "creating WASM module");
157
1716
        let module =
158
1716
            Module::new(self.host.engine(), wasm).map_err(|e| Error::Runtime(e.to_string()))?;
159

            
160
1716
        let output_size = DEFAULT_OUTPUT_SIZE;
161
1716
        let input = build_minimal_input(output_size);
162
1716
        let input_offset = scripting_format::BASE_OFFSET;
163
1716
        let output_offset = input_offset + input.len() as u32;
164
1716
        let strings_offset = {
165
1716
            let header = GlobalHeader::from_bytes(&input).expect("minimal input must be valid");
166
1716
            header.strings_pool_offset
167
        };
168
1716
        debug!(
169
            input_offset,
170
            output_offset, strings_offset, "memory layout offsets"
171
        );
172

            
173
1716
        let exec_state = self
174
1716
            .host
175
1716
            .execution_state(input_offset, output_offset, strings_offset);
176
1716
        let mut store = Store::new(self.host.engine(), exec_state);
177

            
178
1716
        let mut linker = Linker::new(self.host.engine());
179
1716
        define_host_functions(&mut linker).map_err(|e| Error::Runtime(e.to_string()))?;
180

            
181
1716
        let instance = linker
182
1716
            .instantiate(&mut store, &module)
183
1716
            .map_err(|e| Error::Runtime(e.to_string()))?;
184

            
185
1716
        let memory = instance
186
1716
            .get_memory(&mut store, "memory")
187
1716
            .ok_or_else(|| Error::Runtime("no memory export".to_string()))?;
188

            
189
1716
        store.data_mut().memory = Some(memory);
190

            
191
1716
        let total_size = input.len() + output_size as usize;
192
1716
        let required_pages = (input_offset as usize + total_size).div_ceil(WASM_PAGE_SIZE as usize);
193
1716
        let current_pages = memory.size(&store) as usize;
194

            
195
1716
        if required_pages > current_pages {
196
1716
            debug!(current_pages, required_pages, "growing memory");
197
1716
            memory
198
1716
                .grow(&mut store, (required_pages - current_pages) as u64)
199
1716
                .map_err(|e| Error::Runtime(e.to_string()))?;
200
        }
201

            
202
1716
        let mem_data = memory.data_mut(&mut store);
203
1716
        let input_start = input_offset as usize;
204
1716
        mem_data[input_start..input_start + input.len()].copy_from_slice(&input);
205

            
206
1716
        let output_start = output_offset as usize;
207
1716
        let output_header = OutputHeader::new(0);
208
1716
        mem_data[output_start..output_start + OUTPUT_HEADER_SIZE]
209
1716
            .copy_from_slice(&output_header.to_bytes());
210

            
211
1716
        let should_apply = instance
212
1716
            .get_typed_func::<(), i32>(&mut store, "should_apply")
213
1716
            .map_err(|e| Error::Runtime(e.to_string()))?;
214
1716
        debug!("calling should_apply");
215
1716
        let apply = should_apply
216
1716
            .call(&mut store, ())
217
1716
            .map_err(|e| Error::Runtime(e.to_string()))?;
218
1716
        debug!(result = apply, "should_apply returned");
219
1716
        if apply != 1 {
220
            return Err(Error::Runtime(format!(
221
                "should_apply returned {apply}, expected 1"
222
            )));
223
1716
        }
224

            
225
1716
        let process = instance
226
1716
            .get_typed_func::<(), ()>(&mut store, "process")
227
1716
            .map_err(|e| Error::Runtime(e.to_string()))?;
228
1716
        debug!("calling process");
229
1716
        process
230
1716
            .call(&mut store, ())
231
1716
            .map_err(|e| Error::Runtime(e.to_string()))?;
232

            
233
1716
        let mem_data = memory.data(&store);
234
1716
        let output_data = &mem_data[output_start..];
235

            
236
1716
        let result = decode_result(output_data);
237
1716
        debug!("result decoded");
238
1716
        result
239
1716
    }
240
}
241

            
242
impl Default for Interpreter {
243
    fn default() -> Self {
244
        Self::new(false).expect("failed to create Interpreter")
245
    }
246
}
247

            
248
1716
fn build_minimal_input(output_size: u32) -> Vec<u8> {
249
    use scripting::MemorySerializer;
250
1716
    let mut ser = MemorySerializer::new();
251
1716
    ser.set_context(ContextType::BatchProcess, EntityType::Transaction);
252
1716
    ser.finalize(output_size)
253
1716
}
254

            
255
1716
fn decode_result(data: &[u8]) -> Result<Value> {
256
1716
    let output_header = OutputHeader::from_bytes(data)
257
1716
        .ok_or_else(|| Error::Runtime("invalid output header".to_string()))?;
258

            
259
1716
    if output_header.output_entity_count == 0 {
260
        return Err(Error::Runtime("no output entities".to_string()));
261
1716
    }
262

            
263
1716
    let entity_header = scripting_format::EntityHeader::from_bytes(&data[OUTPUT_HEADER_SIZE..])
264
1716
        .ok_or_else(|| Error::Runtime("invalid entity header".to_string()))?;
265

            
266
1716
    let data_offset = entity_header.data_offset as usize;
267
1716
    let value_data = DebugValueData::from_bytes(&data[data_offset..])
268
1716
        .ok_or_else(|| Error::Runtime("invalid debug value data".to_string()))?;
269

            
270
1716
    let value_type = ValueType::try_from(value_data.value_type)
271
1716
        .map_err(|()| Error::Runtime("unknown value type".to_string()))?;
272

            
273
1716
    match value_type {
274
264
        ValueType::Nil => Ok(Value::Nil),
275
198
        ValueType::Bool => Ok(Value::Bool(value_data.data1 != 0)),
276
869
        ValueType::Number => Ok(Value::Number(Fraction::new(
277
869
            value_data.data1,
278
869
            value_data.data2,
279
869
        ))),
280
        ValueType::String | ValueType::Symbol => {
281
385
            let pool_offset = value_data.data1 as usize;
282
385
            let len = value_data.data2 as usize;
283
385
            let strings_offset = output_header.strings_offset as usize;
284
385
            let start = strings_offset + pool_offset;
285
385
            let end = start + len;
286
385
            if end > data.len() {
287
                return Err(Error::Runtime("string data out of bounds".to_string()));
288
385
            }
289
385
            let s = std::str::from_utf8(&data[start..end])
290
385
                .map_err(|_| Error::Runtime("invalid UTF-8 in string".to_string()))?;
291
385
            if value_type == ValueType::Symbol {
292
77
                Ok(Value::Symbol(s.to_string()))
293
            } else {
294
308
                Ok(Value::String(s.to_string()))
295
            }
296
        }
297
    }
298
1716
}