//! Per-channel state for the `nomisync-eval` ssh subsystem.
//!
//! Owns the frame decoder (newline + balanced-paren), the epoch and
//! interrupt cancel handles, and a tokio mpsc to a per-channel worker
//! that drives `rpc::Session::handle_form` independently of the russh
//! handler callback. Decoupling the worker from `data()` is what lets
//! the handler scan subsequent bytes (an emacs `C-g` arrives as the ETX
//! byte `0x03`) and bump the engine epoch synchronously, tripping the
//! in-flight eval and surfacing `(:error (:code interrupted ...))`
//! on the wire.
//!
//! Constructed in `subsystem_request` once the requested subsystem
//! name matches `"nomisync-eval"` AND the session has authenticated
//! (i.e. `user_id.is_some()`). Dropped in `channel_eof` /
//! `channel_close` along with the rest of the per-channel state; the
//! drop closes the mpsc, and the worker exits the next time it polls
//! the queue and finds it closed and empty.
17

            
18
use russh::ChannelId;
19
use russh::server::Handle as RusshHandle;
20
use tokio::sync::mpsc;
21

            
22
use rpc::session::SessionError;
23
use rpc::{EpochBumper, FrameDecoder, FrameError, InterruptHandle, ScriptCtx, Session};
24

            
25
/// ASCII End-of-Text (`0x03`, the byte a terminal sends for `Ctrl-C`).
///
/// Clients emit this byte on the same channel as form data to request
/// a cooperative cancel of the in-flight `nomi-eval`; the emacs client
/// sends it for `C-g`. It is stripped from the byte stream before the
/// decoder sees it, so it never becomes part of a frame.
const ETX: u8 = 0x03;
30

            
31
/// Trait for the worker's response back-channel. The russh-backed
32
/// impl writes to an SSH channel via the server's `Handle`; the
33
/// mpsc-backed impl drains into a tokio receiver for unit tests.
34
/// Either way, the worker emits one full response (newline-
35
/// terminated) per send call.
36
#[async_trait::async_trait]
37
pub trait ResponseSink: Send + Sync + 'static {
38
    async fn send(&self, bytes: Vec<u8>);
39
}
40

            
41
pub struct RusshSink {
42
    handle: RusshHandle,
43
    channel: ChannelId,
44
}
45

            
46
impl RusshSink {
47
    pub fn new(handle: RusshHandle, channel: ChannelId) -> Self {
48
        Self { handle, channel }
49
    }
50
}
51

            
52
#[async_trait::async_trait]
53
impl ResponseSink for RusshSink {
54
    async fn send(&self, bytes: Vec<u8>) {
55
        if let Err(remaining) = self.handle.data(self.channel, bytes).await {
56
            log::warn!(
57
                "eval channel response write dropped {} bytes (channel closed?)",
58
                remaining.len()
59
            );
60
        }
61
    }
62
}
63

            
64
pub struct EvalChannelState {
65
    decoder: FrameDecoder,
66
    bumper: EpochBumper,
67
    /// Generation-counter cancel signal, shared with the worker's
68
    /// `Session`. Bumped together with the epoch on ETX so an
69
    /// interrupt arriving in the worker-pickup / compile window (before
70
    /// the eval enters Wasm, where an epoch bump would be absorbed by
71
    /// the not-yet-set deadline) is still observed by the session's
72
    /// pre-start `check_interrupt`.
73
    interrupt: InterruptHandle,
74
    forms_tx: mpsc::UnboundedSender<String>,
75
    /// Worker task drives `Session::handle_form` sequentially per
76
    /// channel; held so the worker stays alive for the channel's
77
    /// lifetime. Dropping `EvalChannelState` closes `forms_tx`, the
78
    /// worker's `recv()` returns `None`, and the task exits cleanly.
79
    _worker: tokio::task::JoinHandle<()>,
80
}
81

            
82
impl EvalChannelState {
83
7
    pub fn new(ctx: ScriptCtx, sink: Box<dyn ResponseSink>) -> Result<Self, SessionError> {
84
7
        let mut session = Session::new(ctx)?;
85
7
        let bumper = session.epoch_bumper();
86
7
        let interrupt = session.interrupt_handle();
87
7
        let (forms_tx, mut forms_rx) = mpsc::unbounded_channel::<String>();
88
7
        let worker = tokio::spawn(async move {
89
11
            while let Some(frame) = forms_rx.recv().await {
90
5
                let mut response = session.handle_form(&frame).await;
91
5
                response.push('\n');
92
5
                sink.send(response.into_bytes()).await;
93
            }
94
1
        });
95
7
        Ok(Self {
96
7
            decoder: FrameDecoder::new(),
97
7
            bumper,
98
7
            interrupt,
99
7
            forms_tx,
100
7
            _worker: worker,
101
7
        })
102
7
    }
103

            
104
    /// Feeds raw bytes received on the channel. Scans for the ETX
105
    /// interrupt token synchronously and, on finding one, signals the
106
    /// cancel both ways: it advances the interrupt generation (caught
107
    /// by `handle_form`'s pre-start `check_interrupt` if the eval has
108
    /// not entered Wasm yet) and bumps the engine epoch (traps an eval
109
    /// already running in `call_async`). Wiring both closes the
110
    /// worker-pickup / compile window, where a lone epoch bump would be
111
    /// absorbed by the not-yet-set deadline and the cancel lost. Either
112
    /// path surfaces a `(:code interrupted ...)` envelope. ETX bytes are
113
    /// stripped before the bytes reach the frame decoder.
114
    ///
115
    /// Errors only for non-UTF8 input (the ssh layer should never
116
    /// emit that for an s-expression subsystem).
117
8
    pub fn feed(&mut self, bytes: &[u8]) -> Result<(), FrameError> {
118
8
        let has_interrupt = bytes.contains(&ETX);
119
8
        if has_interrupt {
120
3
            self.interrupt.interrupt();
121
3
            self.bumper.bump();
122
5
        }
123
        let cleaned: Vec<u8>;
124
8
        let payload: &[u8] = if has_interrupt {
125
40
            cleaned = bytes.iter().copied().filter(|b| *b != ETX).collect();
126
3
            &cleaned
127
        } else {
128
5
            bytes
129
        };
130
8
        self.decoder.feed(payload)?;
131
        loop {
132
12
            match self.decoder.next_frame() {
133
7
                None => break,
134
5
                Some(Ok(frame)) => {
135
5
                    if self.forms_tx.send(frame).is_err() {
136
                        // Worker exited (channel drop or session
137
                        // error). Subsequent frames silently noop;
138
                        // the channel will be torn down by the SSH
139
                        // layer when the client notices the missing
140
                        // responses.
141
                        break;
142
5
                    }
143
                }
144
                Some(Err(err)) => {
145
                    return Err(err);
146
                }
147
            }
148
        }
149
7
        Ok(())
150
8
    }
151
}
152

            
153
#[cfg(test)]
mod tests {
    use super::*;
    use rpc::ScriptLimits;
    use sqlx::types::Uuid;
    use std::time::Duration;

    /// Receiving end of the test sink.
    type Responses = mpsc::UnboundedReceiver<Vec<u8>>;

    /// How long the "nothing should arrive" tests wait before checking.
    const SETTLE: Duration = Duration::from_millis(20);

    /// Test sink that forwards every response into an unbounded tokio
    /// channel.
    struct MpscSink {
        tx: mpsc::UnboundedSender<Vec<u8>>,
    }

    #[async_trait::async_trait]
    impl ResponseSink for MpscSink {
        async fn send(&self, bytes: Vec<u8>) {
            // A send failure (receiver already dropped) is deliberately
            // ignored.
            let _ = self.tx.send(bytes);
        }
    }

    /// State over a default context for a freshly generated id.
    fn new_state() -> (EvalChannelState, Responses) {
        state_with_ctx(ScriptCtx::new(Uuid::new_v4()))
    }

    /// State over `ctx`, wired to an [`MpscSink`] whose receiver is
    /// returned alongside it.
    fn state_with_ctx(ctx: ScriptCtx) -> (EvalChannelState, Responses) {
        let (tx, rx) = mpsc::unbounded_channel();
        let state = EvalChannelState::new(ctx, Box::new(MpscSink { tx })).expect("session init");
        (state, rx)
    }

    /// Gathers `expected` responses by awaiting each one instead of
    /// busy-polling a fixed budget, so a slow CI runner cannot make the
    /// result come up short (a real hang still fails through the
    /// per-response timeout). Once the expected number is in, it keeps
    /// listening for a short grace window so that an over-emission
    /// regression (one frame too many) is collected as well and breaks
    /// the caller's count assertion.
    async fn drain_count(rx: &mut Responses, expected: usize) -> Vec<String> {
        const PER_RESPONSE: Duration = Duration::from_secs(10);
        const GRACE: Duration = Duration::from_millis(100);

        let budget_for = |seen: usize| if seen < expected { PER_RESPONSE } else { GRACE };

        let mut collected: Vec<String> = Vec::new();
        while let Ok(Some(raw)) =
            tokio::time::timeout(budget_for(collected.len()), rx.recv()).await
        {
            collected.push(String::from_utf8(raw).unwrap());
        }
        collected
    }

    /// Waits for [`SETTLE`] and then asserts that no response is queued.
    async fn assert_silent(rx: &mut Responses) {
        tokio::time::sleep(SETTLE).await;
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn empty_feed_produces_no_responses() {
        let (mut state, mut rx) = new_state();
        state.feed(b"").unwrap();
        assert_silent(&mut rx).await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn complete_frame_yields_response_with_trailing_newline() {
        let (mut state, mut rx) = new_state();
        let input: &[u8] = b"(:id 1 :form (rpc-protocol-version))\n";
        state.feed(input).unwrap();

        let responses = drain_count(&mut rx, 1).await;
        assert_eq!(responses.len(), 1, "got {responses:?}");
        let r = &responses[0];
        assert!(r.starts_with("(:id 1"), "{r:?}");
        assert!(r.ends_with('\n'), "{r:?}");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn two_concatenated_frames_yield_two_responses() {
        let (mut state, mut rx) = new_state();
        let input: &[u8] =
            b"(:id 3 :form (rpc-protocol-version))\n(:id 4 :form (rpc-protocol-version))\n";
        state.feed(input).unwrap();

        let responses = drain_count(&mut rx, 2).await;
        assert_eq!(responses.len(), 2);
        assert!(responses[0].contains(":id 3"));
        assert!(responses[1].contains(":id 4"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn etx_byte_in_stream_is_stripped() {
        // Input is `(form)<ETX>`. The ETX must neither reach the
        // decoder (which would refuse the byte as non-paren content)
        // nor keep the frame from completing.
        let (mut state, mut rx) = new_state();
        let input = [&b"(:id 5 :form (rpc-protocol-version))\n"[..], &[ETX][..]].concat();
        state.feed(&input).unwrap();

        let responses = drain_count(&mut rx, 1).await;
        assert_eq!(responses.len(), 1, "got {responses:?}");
        assert!(responses[0].contains(":id 5"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn etx_byte_alone_bumps_epoch_without_decoding_a_frame() {
        // A bare ETX with no form around it must stay silent: the
        // bumper fires, but with no eval in flight there is nothing to
        // cancel.
        let (mut state, mut rx) = new_state();
        state.feed(&[ETX]).unwrap();
        assert_silent(&mut rx).await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn etx_byte_after_form_cancels_inflight_eval() {
        // End-to-end cancel through the eval-channel layer: feed a
        // long-running form, then an ETX byte on the same channel.
        // `feed` bumps the epoch and advances the interrupt generation,
        // so the cancel is seen whether the eval is already inside Wasm
        // (epoch trap) or still in the worker-pickup / compile window
        // (pre-start check_interrupt).
        //
        // Fuel is effectively unbounded and the loop bound sits just
        // under i32::MAX (the counter is an Index / i32), i.e. billions
        // of iterations. The form therefore cannot finish by itself
        // inside the test window — neither with a natural value nor
        // with a fuel-exhaustion `:code runtime`. `:code interrupted`
        // is the only reachable outcome, and asserting exactly that is
        // what makes this a genuine regression guard for the cancel
        // path instead of a test the fuel cap could satisfy alone.
        let limits = ScriptLimits {
            fuel: u64::MAX,
            ..ScriptLimits::default()
        };
        let ctx = ScriptCtx::new(Uuid::new_v4()).with_limits(limits);
        let (mut state, mut rx) = state_with_ctx(ctx);

        let input: &[u8] = b"(:id 9 :form (do ((i 0 (+ i 1))) ((>= i 2000000000) i)))\n";
        state.feed(input).unwrap();
        // Give the worker time to pick the form up and start running it.
        tokio::time::sleep(SETTLE).await;
        state.feed(&[ETX]).unwrap();

        let responses = drain_count(&mut rx, 1).await;
        assert_eq!(responses.len(), 1, "got {responses:?}");
        let r = &responses[0];
        assert!(r.contains(":id 9"), "{r:?}");
        assert!(r.contains(":code interrupted"), "{r:?}");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn feed_rejects_non_utf8() {
        let (mut state, _rx) = new_state();
        let outcome = state.feed(&[0xC0, 0xC1]);
        assert!(outcome.is_err());
    }
}