Skip to main content

rpc/
framing.rs

1use thiserror::Error;
2
3#[derive(Debug, Error, PartialEq, Eq)]
4pub enum FrameError {
5    #[error("frame contains invalid UTF-8")]
6    InvalidUtf8,
7    #[error("malformed frame: {0}")]
8    Malformed(String),
9}
10
11#[derive(Debug, Default)]
12pub struct FrameDecoder {
13    buf: String,
14}
15
16impl FrameDecoder {
17    #[must_use]
18    pub fn new() -> Self {
19        Self::default()
20    }
21
22    pub fn feed(&mut self, bytes: &[u8]) -> Result<(), FrameError> {
23        let s = std::str::from_utf8(bytes).map_err(|_| FrameError::InvalidUtf8)?;
24        self.buf.push_str(s);
25        Ok(())
26    }
27
28    pub fn next_frame(&mut self) -> Option<Result<String, FrameError>> {
29        let trimmed_offset = self.leading_skip_len();
30        if trimmed_offset >= self.buf.len() {
31            self.buf.drain(..trimmed_offset);
32            return None;
33        }
34        match scan_form_end(&self.buf[trimmed_offset..]) {
35            Scan::Complete(rel_end) => {
36                let abs_end = trimmed_offset + rel_end;
37                let frame: String = self.buf[trimmed_offset..abs_end].to_string();
38                self.buf.drain(..abs_end);
39                Some(Ok(frame))
40            }
41            Scan::Incomplete => {
42                self.buf.drain(..trimmed_offset);
43                None
44            }
45            Scan::Invalid(msg) => {
46                self.buf.clear();
47                Some(Err(FrameError::Malformed(msg)))
48            }
49        }
50    }
51
52    fn leading_skip_len(&self) -> usize {
53        let bytes = self.buf.as_bytes();
54        let mut i = 0;
55        while i < bytes.len() {
56            let c = bytes[i];
57            if c == b' ' || c == b'\t' || c == b'\n' || c == b'\r' {
58                i += 1;
59            } else if c == b';' {
60                while i < bytes.len() && bytes[i] != b'\n' {
61                    i += 1;
62                }
63            } else {
64                break;
65            }
66        }
67        i
68    }
69}
70
/// Outcome of scanning for the end of one form in a `&str`.
#[derive(Debug, PartialEq, Eq)]
enum Scan {
    /// The form ends at this byte offset (exclusive) into the scanned `&str`.
    Complete(usize),
    /// More input is needed before the end of the form can be determined.
    Incomplete,
    /// The input cannot be scanned as a form; carries a human-readable reason.
    Invalid(String),
}
76
77fn scan_form_end(s: &str) -> Scan {
78    let bytes = s.as_bytes();
79    if bytes.is_empty() {
80        return Scan::Incomplete;
81    }
82    match bytes[0] {
83        b'(' => scan_balanced_list(s, 0),
84        b'\'' | b'`' => scan_after_prefix(s, 1),
85        b',' if bytes.get(1) == Some(&b'@') => scan_after_prefix(s, 2),
86        b',' => scan_after_prefix(s, 1),
87        b'"' => scan_string(s, 0),
88        b'#' if bytes.get(1) == Some(&b'"') => scan_hash_string(s, 0),
89        b'#' if bytes.get(1) == Some(&b'u')
90            && bytes.get(2) == Some(&b'8')
91            && bytes.get(3) == Some(&b'(') =>
92        {
93            scan_balanced_list(s, 3)
94        }
95        _ => scan_atom(s, 0),
96    }
97}
98
99fn scan_after_prefix(s: &str, start: usize) -> Scan {
100    if start >= s.len() {
101        return Scan::Incomplete;
102    }
103    match scan_form_end(&s[start..]) {
104        Scan::Complete(rel) => Scan::Complete(start + rel),
105        other => other,
106    }
107}
108
109fn scan_balanced_list(s: &str, start: usize) -> Scan {
110    let bytes = s.as_bytes();
111    let mut i = start;
112    if bytes.get(i) != Some(&b'(') {
113        return Scan::Invalid("expected '(' at list start".into());
114    }
115    let mut depth: usize = 1;
116    i += 1;
117    while i < bytes.len() {
118        match bytes[i] {
119            b'(' => depth += 1,
120            b')' => {
121                depth -= 1;
122                if depth == 0 {
123                    return Scan::Complete(i + 1);
124                }
125            }
126            b'"' => match scan_string(s, i) {
127                Scan::Complete(end) => {
128                    i = end;
129                    continue;
130                }
131                Scan::Incomplete => return Scan::Incomplete,
132                Scan::Invalid(m) => return Scan::Invalid(m),
133            },
134            b'#' if bytes.get(i + 1) == Some(&b'"') => match scan_hash_string(s, i) {
135                Scan::Complete(end) => {
136                    i = end;
137                    continue;
138                }
139                Scan::Incomplete => return Scan::Incomplete,
140                Scan::Invalid(m) => return Scan::Invalid(m),
141            },
142            b';' => {
143                while i < bytes.len() && bytes[i] != b'\n' {
144                    i += 1;
145                }
146                continue;
147            }
148            _ => {}
149        }
150        i += 1;
151    }
152    Scan::Incomplete
153}
154
155fn scan_string(s: &str, start: usize) -> Scan {
156    let bytes = s.as_bytes();
157    if bytes.get(start) != Some(&b'"') {
158        return Scan::Invalid("expected '\"' at string start".into());
159    }
160    if bytes.get(start + 1) == Some(&b'"') && bytes.get(start + 2) == Some(&b'"') {
161        return scan_triple_string(s, start);
162    }
163    let mut i = start + 1;
164    while i < bytes.len() {
165        match bytes[i] {
166            b'\\' => {
167                if i + 1 >= bytes.len() {
168                    return Scan::Incomplete;
169                }
170                i += 2;
171            }
172            b'"' => return Scan::Complete(i + 1),
173            _ => i += 1,
174        }
175    }
176    Scan::Incomplete
177}
178
179fn scan_triple_string(s: &str, start: usize) -> Scan {
180    let bytes = s.as_bytes();
181    let mut i = start + 3;
182    while i + 2 < bytes.len() {
183        if bytes[i] == b'"' && bytes[i + 1] == b'"' && bytes[i + 2] == b'"' {
184            return Scan::Complete(i + 3);
185        }
186        i += 1;
187    }
188    Scan::Incomplete
189}
190
191fn scan_hash_string(s: &str, start: usize) -> Scan {
192    let bytes = s.as_bytes();
193    if bytes.get(start) != Some(&b'#') || bytes.get(start + 1) != Some(&b'"') {
194        return Scan::Invalid("expected '#\"' at hash-string start".into());
195    }
196    let mut i = start + 2;
197    while i < bytes.len() {
198        if bytes[i] == b'"' {
199            return Scan::Complete(i + 1);
200        }
201        i += 1;
202    }
203    Scan::Incomplete
204}
205
206fn scan_atom(s: &str, start: usize) -> Scan {
207    let bytes = s.as_bytes();
208    let mut i = start;
209    while i < bytes.len() {
210        let c = bytes[i];
211        if c == b' '
212            || c == b'\t'
213            || c == b'\n'
214            || c == b'\r'
215            || c == b'('
216            || c == b')'
217            || c == b';'
218        {
219            break;
220        }
221        i += 1;
222    }
223    if i == start {
224        Scan::Incomplete
225    } else {
226        Scan::Complete(i)
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Pulls frames from `decoder` until it reports that none is ready,
    /// panicking if any of them is an error.
    fn collect_frames(decoder: &mut FrameDecoder) -> Vec<String> {
        std::iter::from_fn(|| decoder.next_frame())
            .map(|result| result.unwrap_or_else(|e| panic!("unexpected frame error: {e}")))
            .collect()
    }

    /// Feeds `input` to a fresh decoder in one chunk and returns every frame it yields.
    fn decode(input: &[u8]) -> Vec<String> {
        let mut decoder = FrameDecoder::new();
        decoder.feed(input).unwrap();
        collect_frames(&mut decoder)
    }

    /// Asserts that `input` decodes to exactly one frame equal to `input` itself.
    fn assert_single_frame(input: &str) {
        assert_eq!(decode(input.as_bytes()), [input]);
    }

    #[test]
    fn returns_none_when_buffer_empty() {
        assert!(FrameDecoder::new().next_frame().is_none());
    }

    #[test]
    fn returns_none_when_buffer_only_whitespace() {
        let mut decoder = FrameDecoder::new();
        decoder.feed(b"   \n\t  ").unwrap();
        assert!(decoder.next_frame().is_none());
    }

    #[test]
    fn yields_atom_frame() {
        assert_eq!(decode(b"42 "), ["42"]);
    }

    #[test]
    fn yields_simple_list_frame() {
        assert_single_frame("(foo bar)");
    }

    #[test]
    fn yields_multiple_frames_streamed() {
        assert_eq!(decode(b"(a)\n(b)\n(c)\n"), ["(a)", "(b)", "(c)"]);
    }

    #[test]
    fn defers_when_form_incomplete() {
        let mut decoder = FrameDecoder::new();
        decoder.feed(b"(foo ").unwrap();
        assert!(decoder.next_frame().is_none());
        decoder.feed(b"bar)").unwrap();
        assert_eq!(collect_frames(&mut decoder), ["(foo bar)"]);
    }

    #[test]
    fn handles_nested_lists() {
        assert_single_frame("(a (b (c d) e) f)");
    }

    #[test]
    fn parens_inside_string_do_not_affect_depth() {
        assert_single_frame("(foo \"a)b(c\" bar)");
    }

    #[test]
    fn handles_escaped_quote_in_string() {
        assert_single_frame("(say \"he\\\"llo\")");
    }

    #[test]
    fn handles_triple_quoted_string_across_lines() {
        assert_single_frame("(doc \"\"\"line one\n).\nline two\"\"\")");
    }

    #[test]
    fn handles_base64_literal_with_parens_inside() {
        assert_single_frame("(blob #\"abc())def\")");
    }

    #[test]
    fn handles_byte_vector_literal_with_inner_parens_handled_by_balance() {
        assert_single_frame("(blob #u8(1 2 3))");
    }

    #[test]
    fn handles_quote_prefix() {
        assert_single_frame("'(a b)");
    }

    #[test]
    fn handles_quasiquote_with_unquote_inside() {
        assert_single_frame("`(a ,b ,@c)");
    }

    #[test]
    fn skips_top_level_comments_between_frames() {
        assert_eq!(
            decode(b"; comment\n(foo)\n; another\n(bar)\n"),
            ["(foo)", "(bar)"]
        );
    }

    #[test]
    fn comment_inside_list_does_not_split_frame() {
        assert_single_frame("(foo ; comment with )\n  bar)");
    }

    #[test]
    fn invalid_utf8_returns_error() {
        let mut decoder = FrameDecoder::new();
        assert_eq!(
            decoder.feed(&[0xFF, 0xFE, 0xFD]),
            Err(FrameError::InvalidUtf8)
        );
    }

    #[test]
    fn standalone_string_is_a_frame() {
        assert_single_frame("\"hello\"");
    }

    #[test]
    fn standalone_byte_vector_is_a_frame() {
        assert_single_frame("#u8(1 2 3)");
    }

    #[test]
    fn standalone_base64_is_a_frame() {
        assert_single_frame("#\"abcd\"");
    }
}