1
use std::sync::{Arc, Mutex};
2
use std::time::{SystemTime, UNIX_EPOCH};
3

            
4
use wasmtime::{Caller, Linker, Memory};
5

            
6
pub struct HostState {
7
    pub input_offset: u32,
8
    pub output_offset: u32,
9
    pub strings_offset: u32,
10
    pub output_strings_offset: Arc<Mutex<u32>>,
11
    pub memory: Option<Memory>,
12
}
13

            
14
impl HostState {
15
    #[must_use]
16
49
    pub fn new(input_offset: u32, output_offset: u32, strings_offset: u32) -> Self {
17
49
        Self {
18
49
            input_offset,
19
49
            output_offset,
20
49
            strings_offset,
21
49
            output_strings_offset: Arc::new(Mutex::new(0)),
22
49
            memory: None,
23
49
        }
24
49
    }
25
}
26

            
27
48
pub fn define_host_functions(linker: &mut Linker<HostState>) -> wasmtime::Result<()> {
28
48
    linker.func_wrap(
29
48
        "env",
30
48
        "get_input_offset",
31
64
        |caller: Caller<HostState>| -> u32 { caller.data().input_offset },
32
    )?;
33

            
34
48
    linker.func_wrap(
35
48
        "env",
36
48
        "get_output_offset",
37
64
        |caller: Caller<HostState>| -> u32 { caller.data().output_offset },
38
    )?;
39

            
40
48
    linker.func_wrap(
41
48
        "env",
42
48
        "get_strings_offset",
43
        |caller: Caller<HostState>| -> u32 { caller.data().strings_offset },
44
    )?;
45

            
46
48
    linker.func_wrap(
47
48
        "env",
48
48
        "write_bytes",
49
        |mut caller: Caller<HostState>, dst: u32, src: u32, len: u32| -> u32 {
50
            let memory = match caller.data().memory {
51
                Some(mem) => mem,
52
                None => return 0,
53
            };
54
            let data = memory.data_mut(&mut caller);
55
            let src_start = src as usize;
56
            let src_end = src_start + len as usize;
57
            let dst_start = dst as usize;
58

            
59
            if src_end > data.len() || dst_start + len as usize > data.len() {
60
                return 0;
61
            }
62

            
63
            let bytes: Vec<u8> = data[src_start..src_end].to_vec();
64
            data[dst_start..dst_start + len as usize].copy_from_slice(&bytes);
65
            len
66
        },
67
    )?;
68

            
69
48
    linker.func_wrap(
70
48
        "env",
71
48
        "write_string",
72
        |mut caller: Caller<HostState>, ptr: u32, len: u32| -> u32 {
73
            let output_offset = caller.data().output_offset;
74
            let output_strings = caller.data().output_strings_offset.clone();
75

            
76
            let memory = match caller.data().memory {
77
                Some(mem) => mem,
78
                None => return 0,
79
            };
80

            
81
            let data = memory.data_mut(&mut caller);
82
            let src_start = ptr as usize;
83
            let src_end = src_start + len as usize;
84

            
85
            if src_end > data.len() {
86
                return 0;
87
            }
88

            
89
            let mut strings_offset = match output_strings.lock() {
90
                Ok(guard) => guard,
91
                Err(_) => return 0,
92
            };
93

            
94
            let current_offset = *strings_offset;
95
            let dst = output_offset as usize + current_offset as usize;
96

            
97
            if dst + len as usize > data.len() {
98
                return 0;
99
            }
100

            
101
            let bytes: Vec<u8> = data[src_start..src_end].to_vec();
102
            data[dst..dst + len as usize].copy_from_slice(&bytes);
103
            *strings_offset += len;
104

            
105
            current_offset
106
        },
107
    )?;
108

            
109
48
    linker.func_wrap(
110
48
        "env",
111
48
        "log",
112
        |caller: Caller<HostState>, level: u32, msg_ptr: u32, msg_len: u32| {
113
            let memory = match caller.data().memory {
114
                Some(mem) => mem,
115
                None => return,
116
            };
117

            
118
            let data = memory.data(&caller);
119
            let start = msg_ptr as usize;
120
            let end = start + msg_len as usize;
121

            
122
            if end > data.len() {
123
                return;
124
            }
125

            
126
            let msg = match std::str::from_utf8(&data[start..end]) {
127
                Ok(s) => s,
128
                Err(_) => return,
129
            };
130

            
131
            match level {
132
                0 => log::debug!("[script] {msg}"),
133
                1 => log::info!("[script] {msg}"),
134
                2 => log::warn!("[script] {msg}"),
135
                _ => log::error!("[script] {msg}"),
136
            }
137
        },
138
    )?;
139

            
140
48
    linker.func_wrap("env", "get_timestamp", || -> i64 {
141
        SystemTime::now()
142
            .duration_since(UNIX_EPOCH)
143
            .map(|d| d.as_millis() as i64)
144
            .unwrap_or(0)
145
    })?;
146

            
147
48
    linker.func_wrap(
148
48
        "env",
149
48
        "generate_uuid",
150
32
        |mut caller: Caller<HostState>, out_ptr: u32| {
151
32
            let memory = match caller.data().memory {
152
32
                Some(mem) => mem,
153
                None => return,
154
            };
155

            
156
32
            let uuid_bytes = uuid::Uuid::new_v4().into_bytes();
157
32
            let data = memory.data_mut(&mut caller);
158
32
            let start = out_ptr as usize;
159

            
160
32
            if start + 16 > data.len() {
161
                return;
162
32
            }
163

            
164
32
            data[start..start + 16].copy_from_slice(&uuid_bytes);
165
32
        },
166
    )?;
167

            
168
48
    Ok(())
169
48
}
170

            
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::BASE_OFFSET;

    /// `HostState::new` stores the three layout offsets verbatim, starts the
    /// `write_string` counter at zero, and leaves guest memory unattached.
    #[test]
    fn test_host_state_creation() {
        let state = HostState::new(BASE_OFFSET, BASE_OFFSET + 1024, BASE_OFFSET + 512);
        assert_eq!(state.input_offset, BASE_OFFSET);
        assert_eq!(state.output_offset, BASE_OFFSET + 1024);
        assert_eq!(state.strings_offset, BASE_OFFSET + 512);

        let counter = state
            .output_strings_offset
            .lock()
            .expect("a freshly created mutex cannot be poisoned");
        assert_eq!(*counter, 0);
        assert!(state.memory.is_none());
    }
}