1
use std::collections::HashMap;
2

            
3
use tracing::debug;
4
use wasm_encoder::{
5
    ArrayType, CodeSection, CompositeInnerType, CompositeType, DataCountSection, DataSection,
6
    EntityType as WasmEntityType, ExportKind, ExportSection, FieldType, Function, FunctionSection,
7
    ImportSection, Instruction, MemorySection, MemoryType, Module, StorageType, SubType,
8
    TypeSection, ValType,
9
};
10

            
11
pub struct CompileContext {
12
    types: TypeSection,
13
    imports: ImportSection,
14
    functions: FunctionSection,
15
    memories: MemorySection,
16
    exports: ExportSection,
17
    data: DataSection,
18
    codes: CodeSection,
19
    data_count: u32,
20
    type_count: u32,
21
    import_func_count: u32,
22
    local_func_count: u32,
23
    type_cache: HashMap<Vec<ValType>, HashMap<Vec<ValType>, u32>>,
24
    type_names: HashMap<String, u32>,
25
    func_names: HashMap<String, u32>,
26
}
27

            
28
impl CompileContext {
29
12113
    pub fn new() -> Self {
30
12113
        debug!("initializing compile context");
31
12113
        let mut ctx = Self {
32
12113
            types: TypeSection::new(),
33
12113
            imports: ImportSection::new(),
34
12113
            functions: FunctionSection::new(),
35
12113
            memories: MemorySection::new(),
36
12113
            exports: ExportSection::new(),
37
12113
            data: DataSection::new(),
38
12113
            codes: CodeSection::new(),
39
12113
            data_count: 0,
40
12113
            type_count: 0,
41
12113
            import_func_count: 0,
42
12113
            local_func_count: 0,
43
12113
            type_cache: HashMap::new(),
44
12113
            type_names: HashMap::new(),
45
12113
            func_names: HashMap::new(),
46
12113
        };
47

            
48
12113
        ctx.register_type("i8_array");
49
12113
        ctx.register_import("env", "get_output_offset", &[], &[ValType::I32]);
50
12113
        ctx.register_import("env", "symbol_resolve", &[ValType::I32, ValType::I32], &[]);
51
12113
        ctx.register_import(
52
12113
            "env",
53
12113
            "log",
54
12113
            &[ValType::I32, ValType::I32, ValType::I32],
55
12113
            &[],
56
        );
57
12113
        ctx.register_function("should_apply", &[], &[ValType::I32]);
58
12113
        ctx.register_function("process", &[], &[]);
59
12113
        ctx.export_func("should_apply");
60
12113
        ctx.export_func("process");
61

            
62
12113
        ctx.memories.memory(MemoryType {
63
12113
            minimum: 1,
64
12113
            maximum: None,
65
12113
            memory64: false,
66
12113
            shared: false,
67
12113
            page_size_log2: None,
68
12113
        });
69
12113
        ctx.exports.export("memory", ExportKind::Memory, 0);
70

            
71
12113
        debug!("compile context initialized");
72
12113
        ctx
73
12113
    }
74

            
75
12113
    pub fn register_type(&mut self, name: &str) -> u32 {
76
12113
        let idx = self.type_count;
77
12113
        self.types.ty().subtype(&SubType {
78
12113
            is_final: true,
79
12113
            supertype_idx: None,
80
12113
            composite_type: CompositeType {
81
12113
                inner: CompositeInnerType::Array(ArrayType(FieldType {
82
12113
                    element_type: StorageType::I8,
83
12113
                    mutable: false,
84
12113
                })),
85
12113
                shared: false,
86
12113
                describes: None,
87
12113
                descriptor: None,
88
12113
            },
89
12113
        });
90
12113
        self.type_count += 1;
91
12113
        self.type_names.insert(name.to_string(), idx);
92
12113
        idx
93
12113
    }
94

            
95
36339
    pub fn register_import(
96
36339
        &mut self,
97
36339
        module: &str,
98
36339
        name: &str,
99
36339
        params: &[ValType],
100
36339
        results: &[ValType],
101
36339
    ) -> u32 {
102
36339
        let type_idx = self.get_or_create_func_type(params, results);
103
36339
        self.imports
104
36339
            .import(module, name, WasmEntityType::Function(type_idx));
105
36339
        let func_idx = self.import_func_count;
106
36339
        self.import_func_count += 1;
107
36339
        self.func_names.insert(name.to_string(), func_idx);
108
36339
        func_idx
109
36339
    }
110

            
111
24226
    pub fn register_function(&mut self, name: &str, params: &[ValType], results: &[ValType]) {
112
24226
        let type_idx = self.get_or_create_func_type(params, results);
113
24226
        self.functions.function(type_idx);
114
24226
        let func_idx = self.import_func_count + self.local_func_count;
115
24226
        self.local_func_count += 1;
116
24226
        self.func_names.insert(name.to_string(), func_idx);
117
24226
    }
118

            
119
24226
    pub fn export_func(&mut self, name: &str) {
120
24226
        let idx = self.func_names[name];
121
24226
        self.exports.export(name, ExportKind::Func, idx);
122
24226
    }
123

            
124
12317
    pub fn func(&self, name: &str) -> u32 {
125
12317
        self.func_names[name]
126
12317
    }
127

            
128
18984
    pub fn type_idx(&self, name: &str) -> u32 {
129
18984
        self.type_names[name]
130
18984
    }
131

            
132
60565
    fn get_or_create_func_type(&mut self, params: &[ValType], results: &[ValType]) -> u32 {
133
60565
        if let Some(inner) = self.type_cache.get(params)
134
24226
            && let Some(&idx) = inner.get(results)
135
        {
136
12113
            return idx;
137
48452
        }
138
48452
        let idx = self.type_count;
139
48452
        self.types
140
48452
            .ty()
141
48452
            .function(params.iter().copied(), results.iter().copied());
142
48452
        self.type_count += 1;
143
48452
        self.type_cache
144
48452
            .entry(params.to_vec())
145
48452
            .or_default()
146
48452
            .insert(results.to_vec(), idx);
147
48452
        idx
148
60565
    }
149

            
150
7109
    pub fn add_data(&mut self, bytes: &[u8]) -> u32 {
151
7109
        let idx = self.data_count;
152
7109
        debug!(idx, len = bytes.len(), "adding data segment");
153
7109
        self.data.passive(bytes.iter().copied());
154
7109
        self.data_count += 1;
155
7109
        idx
156
7109
    }
157

            
158
12113
    pub fn add_should_apply(&mut self) {
159
12113
        debug!("emitting should_apply function");
160
12113
        let mut f = Function::new([]);
161
12113
        f.instruction(&Instruction::I32Const(1));
162
12113
        f.instruction(&Instruction::End);
163
12113
        self.codes.function(&f);
164
12113
    }
165

            
166
11058
    pub fn add_process(&mut self, f: Function) {
167
11058
        debug!("emitting process function");
168
11058
        self.codes.function(&f);
169
11058
    }
170

            
171
11058
    pub fn finish(self) -> Vec<u8> {
172
11058
        debug!(data_segments = self.data_count, "assembling WASM module");
173
11058
        let mut module = Module::new();
174
11058
        module.section(&self.types);
175
11058
        module.section(&self.imports);
176
11058
        module.section(&self.functions);
177
11058
        module.section(&self.memories);
178
11058
        module.section(&self.exports);
179
11058
        module.section(&DataCountSection {
180
11058
            count: self.data_count,
181
11058
        });
182
11058
        module.section(&self.codes);
183
11058
        module.section(&self.data);
184
11058
        module.finish()
185
11058
    }
186
}
187

            
188
impl Default for CompileContext {
189
    fn default() -> Self {
190
        Self::new()
191
    }
192
}