1
//! Closure parameter-type inference from body usage.
2
//!
3
//! A value-position `(lambda …)` lifted to a closure needs user-visible param
4
//! types. Defaulting every param to `Ratio` mis-types a list accumulator (a
5
//! param used as a `(cons CAR param)` cdr is a pair, not a scalar), so FOLD /
6
//! MAP over such a closure rejected it. This module infers each required
7
//! param's `WasmType` from how the body uses it; the closure's result type then
8
//! follows from compiling the body under those param types.
9
//!
10
//! Inferred shapes today: a param used as the CDR of `(cons CAR param)`
//! → `PairRef(<car cell element>)`, and a param passed as an argument to a
//! string native (`string=`, `upcase-string`) → `StringRef`. A param with no
//! such use stays `Ratio` (identical to the old default — no existing scalar
//! closure regresses). The
13
//! scan is SCOPE-AWARE: a binder (`lambda`/`let`/`let*`/`do`/`do*`/`dolist`)
14
//! that rebinds the name shadows the param in its BODY, so a cons-cdr use of the
15
//! inner variable must not retype the outer param — but the binder's INIT
16
//! expressions (and a `dolist` list-expr) are in the OUTER scope and ARE
17
//! scanned.
18

            
19
use crate::ast::{Expr, LambdaParams, PairElement, WasmType};
20

            
21
/// Infers each required param's user-visible `WasmType` from `body`.
22
2856
pub(super) fn infer_param_types(params: &LambdaParams, body: &Expr) -> Vec<WasmType> {
23
2856
    params
24
2856
        .required
25
2856
        .iter()
26
3876
        .map(|name| pinned_param_type(body, name).unwrap_or(WasmType::Ratio))
27
2856
        .collect()
28
2856
}
29

            
30
/// The `WasmType` `name` is pinned to by a type-constraining use in `body`,
31
/// anywhere `name` still refers to the searched param (not inside a scope that
32
/// rebinds it). Two pins today:
33
/// - CDR of a `(cons CAR name)` → `PairRef(<car cell element>)` (list accumulator);
34
/// - argument to a string native (`string=`, `upcase-string`) → `StringRef`.
35
///
36
/// `None` if no pinning use is found (caller defaults to `Ratio`).
37
11722
fn pinned_param_type(body: &Expr, name: &str) -> Option<WasmType> {
38
11722
    let Expr::List(elems) = body else {
39
7622
        return None;
40
    };
41
4100
    if let [Expr::Symbol(head), car, Expr::Symbol(cdr)] = elems.as_slice()
42
1843
        && head.eq_ignore_ascii_case("cons")
43
141
        && cdr == name
44
    {
45
72
        return Some(WasmType::PairRef(car_cell_element(car)));
46
4028
    }
47
4028
    if let Some(Expr::Symbol(head)) = elems.first()
48
4027
        && is_string_native(head)
49
69
        && elems[1..]
50
69
            .iter()
51
69
            .any(|a| matches!(a, Expr::Symbol(s) if s == name))
52
    {
53
69
        return Some(WasmType::StringRef);
54
3959
    }
55
3959
    match head_upper(elems).as_deref() {
56
        // Quoted forms are data, not evaluated — a param mentioned inside must
57
        // not pin its type. (The `Expr::Quote`/`Expr::Quasiquote` AST variants
58
        // already return `None` above by not being a `List`; this guards the
59
        // `(quote …)` / `(quasiquote …)` list spellings.)
60
3958
        Some("QUOTE") | Some("QUASIQUOTE") => None,
61
3956
        Some("LET") => scan_let(elems, name, false),
62
3953
        Some("LET*") => scan_let(elems, name, true),
63
3952
        Some("DO") | Some("DO*") => scan_do(elems, name),
64
3949
        Some("DOLIST") => scan_dolist(elems, name),
65
3947
        Some("LAMBDA") => {
66
            // The body is shadowed iff `name` is one of the lambda's params; a
67
            // param list carries no scannable outer-scope subform.
68
1
            if lambda_binds(elems, name) {
69
                None
70
            } else {
71
1
                scan_each(tail(elems, 1), name)
72
            }
73
        }
74
3947
        _ => scan_each(tail(elems, 1), name),
75
    }
76
11722
}
77

            
78
/// Whether `head` names a native that REQUIRES a `StringRef` argument, compared
/// ASCII-case-insensitively; a param passed to one is therefore a string.
/// (Only these two exist today; extend alongside new string natives.)
fn is_string_native(head: &str) -> bool {
    const STRING_NATIVES: [&str; 2] = ["STRING=", "UPCASE-STRING"];
    STRING_NATIVES
        .iter()
        .any(|native| head.eq_ignore_ascii_case(native))
}
86

            
87
/// The `PairElement` a car expression contributes. Literal cars pin their slot;
88
/// a sibling-param / captured / computed car has no statically-known cell type,
89
/// so it widens to `AnyRef` — keeping the cdr param's pair element consistent
90
/// with whatever `compile_cons_to_stack` actually produces for that car.
91
76
fn car_cell_element(car: &Expr) -> PairElement {
92
76
    match car {
93
4
        Expr::Number(_) => PairElement::Ratio,
94
1
        Expr::Bool(_) => PairElement::Bool,
95
1
        Expr::String(_) => PairElement::StringRef,
96
70
        _ => PairElement::AnyRef,
97
    }
98
76
}
99

            
100
3959
fn head_upper(elems: &[Expr]) -> Option<String> {
101
3959
    match elems.first() {
102
3958
        Some(Expr::Symbol(s)) => Some(s.to_ascii_uppercase()),
103
1
        _ => None,
104
    }
105
3959
}
106

            
107
3954
fn scan_each(forms: &[Expr], name: &str) -> Option<WasmType> {
108
7830
    forms.iter().find_map(|e| pinned_param_type(e, name))
109
3954
}
110

            
111
/// `forms[from..]`, but never panics on a malformed/short form: a list with
112
/// fewer than `from` elements (e.g. a bare `()`, `(let)`, `(do)`) yields an
113
/// empty slice rather than an out-of-range slice index. Param-type inference is
114
/// a best-effort hint, so a too-short form simply contributes no pin; the
115
/// malformed form is rejected (or compiled) by the real path later.
116
3954
fn tail(forms: &[Expr], from: usize) -> &[Expr] {
117
3954
    forms.get(from..).unwrap_or(&[])
118
3954
}
119

            
120
/// `(let ((v init)…) body…)` / `(let* …)`. Binding INITs are in the outer scope
121
/// (for LET* up to the binding that introduces `name`); the body is shadowed
122
/// once any binding (re)binds `name`.
123
4
fn scan_let(elems: &[Expr], name: &str, sequential: bool) -> Option<WasmType> {
124
4
    let mut shadowed = false;
125
4
    if let Some(Expr::List(bindings)) = elems.get(1) {
126
2
        for binding in bindings {
127
2
            let Expr::List(parts) = binding else { continue };
128
2
            let binds_here = matches!(parts.first(), Some(Expr::Symbol(s)) if s == name);
129
            // For LET* a binding shadows `name` for SUBSEQUENT inits; for LET
130
            // every init is in the outer scope regardless.
131
2
            if !(sequential && shadowed)
132
2
                && let Some(init) = parts.get(1)
133
2
                && let Some(ty) = pinned_param_type(init, name)
134
            {
135
1
                return Some(ty);
136
1
            }
137
1
            if binds_here {
138
1
                shadowed = true;
139
1
            }
140
        }
141
2
    }
142
3
    if shadowed {
143
1
        None
144
    } else {
145
2
        scan_each(tail(elems, 2), name)
146
    }
147
4
}
148

            
149
/// `(do ((v init step)…) (end res…) body…)`. Inits are outer scope; steps, the
150
/// end clause, and the body see the loop vars, so they're shadowed once any var
151
/// is `name`.
152
3
fn scan_do(elems: &[Expr], name: &str) -> Option<WasmType> {
153
3
    let mut shadowed = false;
154
3
    if let Some(Expr::List(vars)) = elems.get(1) {
155
        for var in vars {
156
            let Expr::List(parts) = var else { continue };
157
            if let Some(init) = parts.get(1)
158
                && let Some(ty) = pinned_param_type(init, name)
159
            {
160
                return Some(ty);
161
            }
162
            if matches!(parts.first(), Some(Expr::Symbol(s)) if s == name) {
163
                shadowed = true;
164
            }
165
        }
166
3
    }
167
3
    if shadowed {
168
        None
169
    } else {
170
3
        scan_each(tail(elems, 2), name)
171
    }
172
3
}
173

            
174
/// `(dolist (var list-expr [result]) body…)`. The list-expr is outer scope; the
175
/// result form and body see the loop var.
176
2
fn scan_dolist(elems: &[Expr], name: &str) -> Option<WasmType> {
177
2
    let Some(Expr::List(spec)) = elems.get(1) else {
178
1
        return scan_each(tail(elems, 1), name);
179
    };
180
1
    if let Some(list_expr) = spec.get(1)
181
1
        && let Some(ty) = pinned_param_type(list_expr, name)
182
    {
183
1
        return Some(ty);
184
    }
185
    let var_is_name = matches!(spec.first(), Some(Expr::Symbol(s)) if s == name);
186
    if var_is_name {
187
        None
188
    } else {
189
        scan_each(tail(elems, 2), name)
190
    }
191
2
}
192

            
193
1
fn lambda_binds(elems: &[Expr], name: &str) -> bool {
194
1
    matches!(elems.get(1), Some(Expr::List(params))
195
        if params.iter().any(|p| matches!(p, Expr::Symbol(s) if s == name)))
196
1
}
197

            
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Fraction;

    /// A symbol atom.
    fn sym(name: &str) -> Expr {
        Expr::Symbol(String::from(name))
    }

    /// A list form built from already-constructed elements.
    fn list(items: Vec<Expr>) -> Expr {
        Expr::List(items)
    }

    /// An integer number literal.
    fn num(value: i64) -> Expr {
        let ratio = Fraction::from_integer(value);
        Expr::Number(ratio)
    }

    /// A string literal.
    fn text(s: &'static str) -> Expr {
        Expr::String(s.into())
    }

    /// Expected result for "pinned to a pair whose cells are `elem`".
    fn pair(elem: PairElement) -> Option<WasmType> {
        Some(WasmType::PairRef(elem))
    }

    /// `(CONS car cdr)` — the list-accumulator shape the pin tests build on.
    fn cons_onto(car: Expr, cdr: &str) -> Expr {
        list(vec![sym("CONS"), car, sym(cdr)])
    }

    /// `(quoter (string= x "a"))` — a string-native call wrapped in a quote form.
    fn quoted_string_cmp(quoter: &str) -> Expr {
        let cmp = list(vec![sym("string="), sym("x"), text("a")]);
        list(vec![sym(quoter), cmp])
    }

    #[test]
    fn cons_cdr_literal_car_is_that_cell() {
        let body = cons_onto(num(1), "acc");
        assert_eq!(pinned_param_type(&body, "acc"), pair(PairElement::Ratio));
    }

    #[test]
    fn cons_cdr_nonliteral_car_widens_to_anyref() {
        let body = cons_onto(sym("x"), "acc");
        // The cdr param widens to AnyRef; the car param is not pinned at all.
        assert_eq!(pinned_param_type(&body, "acc"), pair(PairElement::AnyRef));
        assert_eq!(pinned_param_type(&body, "x"), None);
    }

    #[test]
    fn no_cons_cdr_use_is_unconstrained() {
        let scaled = list(vec![sym("*"), sym("x"), num(2)]);
        assert_eq!(pinned_param_type(&scaled, "x"), None);
    }

    #[test]
    fn string_native_arg_pins_stringref() {
        // (string= s "x") → s is a string.
        let cmp = list(vec![sym("STRING="), sym("s"), text("x")]);
        assert_eq!(pinned_param_type(&cmp, "s"), Some(WasmType::StringRef));
    }

    #[test]
    fn quoted_string_native_does_not_pin() {
        // (quote (string= x "a")) — quoted data must NOT pin `x` to StringRef.
        let form = quoted_string_cmp("quote");
        assert_eq!(pinned_param_type(&form, "x"), None);
    }

    #[test]
    fn quasiquoted_string_native_does_not_pin() {
        let form = quoted_string_cmp("quasiquote");
        assert_eq!(pinned_param_type(&form, "x"), None);
    }

    #[test]
    fn shadowing_let_body_does_not_retype_outer_param() {
        // (let ((x nil)) (cons 1 x)) — the cdr is the INNER `x`; the outer stays None.
        let bindings = list(vec![list(vec![sym("x"), Expr::Nil])]);
        let form = list(vec![sym("LET"), bindings, cons_onto(num(1), "x")]);
        assert_eq!(pinned_param_type(&form, "x"), None);
    }

    #[test]
    fn let_initializer_cons_cdr_is_still_outer_scope() {
        // (let ((acc (cons 1 acc))) acc) — the init refers to the OUTER `acc`
        // (a binding is not in scope for its own init), so `acc` IS a list
        // accumulator even though the let rebinds it.
        let bindings = list(vec![list(vec![sym("acc"), cons_onto(num(1), "acc")])]);
        let form = list(vec![sym("LET"), bindings, sym("acc")]);
        assert_eq!(pinned_param_type(&form, "acc"), pair(PairElement::Ratio));
    }

    #[test]
    fn dolist_list_expr_cons_cdr_is_outer_scope() {
        // (dolist (e (cons 1 acc)) e) — the list expr refers to the outer `acc`.
        let spec = list(vec![sym("e"), cons_onto(num(1), "acc")]);
        let form = list(vec![sym("DOLIST"), spec, sym("e")]);
        assert_eq!(pinned_param_type(&form, "acc"), pair(PairElement::Ratio));
    }

    #[test]
    fn lambda_binder_shadows_outer_param() {
        // ((lambda (x) (cons 1 x)) y) — the lambda's own `x` shadows the outer one.
        let lambda = list(vec![
            sym("LAMBDA"),
            list(vec![sym("x")]),
            cons_onto(num(1), "x"),
        ]);
        let call = list(vec![lambda, sym("y")]);
        assert_eq!(pinned_param_type(&call, "x"), None);
    }

    #[test]
    fn short_binder_body_does_not_panic_on_slice() {
        // Regression (AFL): a too-short binder form — `(do)`, `(let)`,
        // `(lambda)` — used to slice out of range and panic. Inference is
        // best-effort, so a malformed form contributes no pin, never a crash.
        for head in ["DO", "DO*", "LET", "LET*", "DOLIST", "LAMBDA"] {
            let bare = list(vec![sym(head)]);
            assert_eq!(pinned_param_type(&bare, "x"), None, "bare ({head})");
        }
        // The fuzz-found shape `(lambda (s) (do))`, checked at its inner `(do)`.
        let nested = list(vec![sym("DO")]);
        assert_eq!(pinned_param_type(&nested, "s"), None);
    }

    #[test]
    fn literal_car_cells_map_to_their_slots() {
        let cases = [
            (num(3), PairElement::Ratio),
            (Expr::Bool(true), PairElement::Bool),
            (text("s"), PairElement::StringRef),
            (sym("v"), PairElement::AnyRef),
        ];
        for (car, expected) in cases {
            assert_eq!(car_cell_element(&car), expected);
        }
    }
}