//! SLYNK wire framing: a 6-hex-digit ASCII header giving the payload length in
//! bytes, followed by that many bytes of UTF-8 (one s-expression). See
//! `doc/editor/slynk-protocol-transcript.org`.

5
use std::io;
6

            
7
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8

            
/// Largest payload length `read_frame` will accept (16 MiB). SLYNK frames are
/// small control/eval messages; a header claiming more than this is treated as
/// a protocol error rather than allocating an attacker-controlled buffer.
///
/// Note: a six-hex-digit header can encode at most `0xFF_FFFF` (16 MiB - 1), so
/// with the current value the header format itself is the binding limit and the
/// explicit length check can never fire. Lowering this constant makes the check
/// effective; raising it has no effect unless the header is widened.
const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;

14
/// Reads one framed message: 6 hex length digits, then `len` UTF-8 bytes.
15
/// Returns `Ok(None)` on a clean EOF at a frame boundary (peer closed), `Err`
16
/// on a malformed header / oversize length / non-UTF-8 / truncated body — the
17
/// caller closes the connection, never panics.
18
5
pub async fn read_frame<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Option<String>> {
19
5
    let mut header = [0u8; 6];
20
5
    match reader.read_exact(&mut header).await {
21
4
        Ok(_) => {}
22
1
        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
23
        Err(e) => return Err(e),
24
    }
25
4
    let header_str = std::str::from_utf8(&header)
26
4
        .map_err(|_| protocol_err("frame length header is not ASCII"))?;
27
4
    let len = usize::from_str_radix(header_str, 16)
28
4
        .map_err(|_| protocol_err("frame length header is not hex"))?;
29
3
    if len > MAX_FRAME_LEN {
30
        return Err(protocol_err("frame length exceeds maximum"));
31
3
    }
32
3
    let mut body = vec![0u8; len];
33
3
    reader.read_exact(&mut body).await?;
34
2
    let text = String::from_utf8(body).map_err(|_| protocol_err("frame body is not UTF-8"))?;
35
2
    Ok(Some(text))
36
5
}
37

            
38
/// Writes one framed message: the 6-hex length header then the payload bytes.
39
/// Errors if the payload exceeds the 6-hex-digit ceiling (a 16 MiB-ish limit
40
/// the protocol never approaches for control/eval messages).
41
2
pub async fn write_frame<W: AsyncWrite + Unpin>(writer: &mut W, payload: &str) -> io::Result<()> {
42
2
    let len = payload.len();
43
2
    if len > 0xff_ffff {
44
        return Err(protocol_err("outgoing frame exceeds 6-hex-digit length"));
45
2
    }
46
2
    writer.write_all(format!("{len:06x}").as_bytes()).await?;
47
2
    writer.write_all(payload.as_bytes()).await?;
48
2
    writer.flush().await?;
49
2
    Ok(())
50
2
}
51

            
/// Builds the `io::ErrorKind::InvalidData` error used for every framing
/// violation, carrying `msg` as its description.
fn protocol_err(msg: &str) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, msg.to_owned())
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Encodes `payload` with `write_frame` into an in-memory byte buffer.
    async fn encode(payload: &str) -> Vec<u8> {
        let mut wire = Vec::new();
        write_frame(&mut wire, payload).await.unwrap();
        wire
    }

    /// Decodes a single frame from `bytes` with `read_frame`.
    async fn decode(bytes: Vec<u8>) -> io::Result<Option<String>> {
        read_frame(&mut Cursor::new(bytes)).await
    }

    #[tokio::test]
    async fn round_trips_a_frame() {
        let wire = encode("(:return (:ok nil) 1)").await;
        // The header is the payload's byte length in lowercase hex.
        assert_eq!(&wire[..6], b"000015");
        let decoded = decode(wire).await.unwrap();
        assert_eq!(decoded.as_deref(), Some("(:return (:ok nil) 1)"));
    }

    #[tokio::test]
    async fn clean_eof_yields_none() {
        assert_eq!(decode(Vec::new()).await.unwrap(), None);
    }

    #[tokio::test]
    async fn non_hex_header_is_error_not_panic() {
        assert!(decode(b"zzzzzz".to_vec()).await.is_err());
    }

    #[tokio::test]
    async fn truncated_body_is_error() {
        // The header promises 10 bytes, but only 3 follow.
        assert!(decode(b"00000aabc".to_vec()).await.is_err());
    }

    #[tokio::test]
    async fn utf8_payload_round_trips() {
        let payload = "(:write-string \"caf\u{e9}\")";
        let wire = encode(payload).await;
        assert_eq!(decode(wire).await.unwrap().as_deref(), Some(payload));
    }
}