1
//! `FOLD` — left fold over a list. `(fold f init list)` returns the
//! accumulator after applying `f(acc, item)` left-to-right.
//!
//! Three paths matching MAP / FILTER:
//! - **Constant-fold**: function + list both resolve at compile time and
//!   every application reduces to a constant; walk element-by-element
//!   calling `f` via the eval `call` pipeline, threading `init` through.
//! - **Constant list, runtime application**: a runtime closure (or a static
//!   function whose body is not constant-foldable) over a literal list;
//!   emit FUNCALL per element, threading the accumulator through a fresh
//!   local each iteration.
//! - **Runtime `PairRef` list** (with any callable): walk the input chain at
//!   runtime, calling `f` on each car with the accumulator.
12

            
13
use crate::ast::{Expr, PairElement, WasmType};
14
use crate::compiler::context::CompileContext;
15
use crate::compiler::emit::FunctionEmitter;
16
use crate::compiler::expr::{
17
    call, compile_expr, compile_for_effect, compile_for_stack, compile_for_stack_as,
18
    emit_nil_default, eval_value, format_expr,
19
};
20
use crate::error::{Error, Result};
21
use crate::runtime::SymbolTable;
22

            
23
1428
pub(super) fn fold(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
24
1428
    if args.len() != 3 {
25
        return Err(Error::Arity {
26
            name: "FOLD".to_string(),
27
            expected: 3,
28
            actual: args.len(),
29
        });
30
1428
    }
31
1428
    let fn_resolved = eval_value(symbols, &args[0])?;
32
1428
    let init_resolved = eval_value(symbols, &args[1])?;
33
1428
    let list_resolved = eval_value(symbols, &args[2])?;
34
1428
    let runtime_closure = matches!(fn_resolved.wasm_type(), Some(WasmType::Closure(_)));
35
1428
    let runtime_list = matches!(list_resolved.wasm_type(), Some(WasmType::PairRef(_)));
36
1428
    if runtime_closure || runtime_list {
37
1428
        let acc_ty = eval_fold_result_type(
38
1428
            symbols,
39
1428
            &fn_resolved,
40
1428
            &args[0],
41
1428
            &init_resolved,
42
1428
            &list_resolved,
43
        )?;
44
1428
        return Ok(Expr::WasmRuntime(acc_ty));
45
    }
46
    // Static fn + constant list: try to fully reduce. If the body produces a
47
    // runtime value (a side-effecting native like `print`, or any host-fn
48
    // result), the fold ISN'T constant-foldable — surface a runtime placeholder
49
    // so the compile path takes the per-element FUNCALL runtime lowering
50
    // instead of baking a bare `WasmRuntime` placeholder into the result.
51
    let folded = constant_fold_fold(symbols, &args[0], &init_resolved, &list_resolved)?;
52
    if folded.is_wasm_runtime() {
53
        let acc_ty = eval_fold_result_type(
54
            symbols,
55
            &fn_resolved,
56
            &args[0],
57
            &init_resolved,
58
            &list_resolved,
59
        )?;
60
        return Ok(Expr::WasmRuntime(acc_ty));
61
    }
62
    Ok(folded)
63
1428
}
64

            
65
/// Eval-path fold result type, kept in lockstep with `compile_fold_to_stack`:
66
/// a runtime closure's accumulator type IS its signature's result type. The
67
/// eval surface has no `CompileContext`, so it reads the result the closure-emit
68
/// site recorded into the symbol table (`SymbolTable::closure_result`); without
69
/// it (closure not recorded) it falls back to the body/seed probe. This stops
70
/// the eval mirror sizing a fold accumulator from the seed (e.g. `0` → Index)
71
/// while codegen emits the closure's real (e.g. Ratio) result.
72
1428
fn eval_fold_result_type(
73
1428
    symbols: &mut SymbolTable,
74
1428
    fn_resolved: &Expr,
75
1428
    fn_arg: &Expr,
76
1428
    init_resolved: &Expr,
77
1428
    list_resolved: &Expr,
78
1428
) -> Result<WasmType> {
79
1428
    if let Some(WasmType::Closure(sig)) = fn_resolved.wasm_type()
80
1428
        && let Some(result) = symbols.closure_result(sig)
81
    {
82
1428
        return Ok(result);
83
    }
84
    accumulator_type(symbols, fn_arg, init_resolved, list_resolved)
85
1428
}
86

            
87
/// The accumulator/result `WasmType` of a fold. A concrete (non-nil) init
88
/// fixes it. A `nil` init is POLYMORPHIC — its `Bool` type is just "empty" —
89
/// so the real accumulator type is whatever the fold body returns; probe it by
90
/// classifying one application `f(init, <sample-elem>)` on a clone (no
91
/// live-table mutation). Falls back to the init's own type when the body can't
92
/// be probed. Shared by the eval mirror and the codegen path so they agree.
93
748
fn accumulator_type(
94
748
    symbols: &mut SymbolTable,
95
748
    fn_arg: &Expr,
96
748
    init_resolved: &Expr,
97
748
    list_resolved: &Expr,
98
748
) -> Result<WasmType> {
99
    // A concrete runtime-typed init (a `WasmLocal`/`WasmRuntime`, NOT nil)
100
    // fixes the accumulator type directly.
101
748
    if let Some(ty) = init_resolved.wasm_type()
102
        && !matches!(init_resolved, Expr::Nil)
103
    {
104
        return Ok(ty);
105
748
    }
106
    // Otherwise (a `nil` polymorphic seed, or a literal init like `0` whose
107
    // own classified type may still be refined by the body) the accumulator
108
    // type is what the body returns — probe one application.
109
748
    if let Some(elem) = sample_element(list_resolved)?
110
748
        && let Some(ty) = probe_body_type(symbols, fn_arg, init_resolved, &elem)?
111
    {
112
748
        return Ok(ty);
113
    }
114
    // Empty list or un-probeable body: fall back to the init's own stack type
115
    // (a literal → its classified type, e.g. `0` → Ratio; nil → Bool).
116
    Ok(crate::compiler::expr::classify_stack_type(init_resolved).unwrap_or(WasmType::Ratio))
117
748
}
118

            
119
/// A representative element value for probing the fold body's return type:
120
/// the runtime element's placeholder for a `PairRef` list, else the first
121
/// constant element. `None` for an empty list (nothing to probe).
122
748
fn sample_element(list_resolved: &Expr) -> Result<Option<Expr>> {
123
748
    if let Some(WasmType::PairRef(elem)) = list_resolved.wasm_type() {
124
612
        return Ok(Some(Expr::WasmRuntime(elem.as_wasm_type())));
125
136
    }
126
136
    Ok(extract_list_elements(list_resolved)?.into_iter().next())
127
748
}
128

            
129
/// Classify the type of one fold application `f(acc, elem)` without emitting —
130
/// evaluated on a CLONE so no live-table mutation, then run through the shared
131
/// stack-type classifier so the prediction matches codegen.
132
748
fn probe_body_type(
133
748
    symbols: &mut SymbolTable,
134
748
    fn_arg: &Expr,
135
748
    acc: &Expr,
136
748
    elem: &Expr,
137
748
) -> Result<Option<WasmType>> {
138
748
    let call_args = vec![fn_arg.clone(), as_literal_arg(acc), as_literal_arg(elem)];
139
748
    let mut probe = symbols.clone();
140
748
    let Ok(result) = call(&mut probe, &call_args) else {
141
        return Ok(None);
142
    };
143
748
    Ok(crate::compiler::expr::classify_stack_type(&result))
144
748
}
145

            
146
204
fn constant_fold_fold(
147
204
    symbols: &mut SymbolTable,
148
204
    fn_arg: &Expr,
149
204
    init_resolved: &Expr,
150
204
    list_resolved: &Expr,
151
204
) -> Result<Expr> {
152
204
    let elements = extract_list_elements(list_resolved)?;
153
204
    let mut acc = init_resolved.clone();
154
476
    for elem in elements {
155
476
        let call_args = vec![fn_arg.clone(), as_literal_arg(&acc), as_literal_arg(&elem)];
156
476
        let result = call(symbols, &call_args)?;
157
476
        acc = eval_value(symbols, &result)?;
158
    }
159
204
    Ok(acc)
160
204
}
161

            
162
4012
fn as_literal_arg(expr: &Expr) -> Expr {
163
4012
    match expr {
164
        Expr::Symbol(_) | Expr::List(_) | Expr::Cons(_, _) => Expr::Quote(Box::new(expr.clone())),
165
4012
        _ => expr.clone(),
166
    }
167
4012
}
168

            
169
/// Seed the accumulator local from `init`, coercing the init to the resolved
170
/// accumulator type. For a `PairRef` accumulator a constant-list seed (`nil`,
171
/// or a quoted literal list) is materialized into a runtime `$pair` chain —
172
/// `nil` is the empty chain (`ref.null pair`). For any other `acc_ty` the init
173
/// must already match.
174
1564
fn emit_fold_init(
175
1564
    ctx: &mut CompileContext,
176
1564
    emit: &mut FunctionEmitter,
177
1564
    symbols: &mut SymbolTable,
178
1564
    init_arg: &Expr,
179
1564
    acc_ty: WasmType,
180
1564
) -> Result<()> {
181
    // Probe the seed's resolved value on a CLONE so a seed that mutates the
182
    // table (SETF/DEFVAR) isn't applied here AND again by the emission path.
183
1564
    let resolved = eval_value(&mut symbols.clone(), init_arg)?;
184
    // A nil seed (for ANY accumulator type, including PairRef where it is the
185
    // empty `$pair` chain) must first emit the effects of a non-literal seed
186
    // that merely *resolves* to nil, then push the typed nil default. Checked
187
    // before the PairRef constant-list materialization so a `(progn …effects… nil)`
188
    // seed isn't silently collapsed to `ref.null`.
189
1564
    if matches!(resolved, Expr::Nil) {
190
408
        if !matches!(init_arg, Expr::Nil) {
191
136
            compile_for_effect(ctx, emit, symbols, init_arg)?;
192
272
        }
193
408
        return emit_nil_default(ctx, emit, acc_ty);
194
1156
    }
195
1156
    if matches!(acc_ty, WasmType::PairRef(_))
196
68
        && !matches!(resolved.wasm_type(), Some(WasmType::PairRef(_)))
197
    {
198
        // A constant-list seed for a runtime pair accumulator: materialize it
199
        // into a fresh `$pair` chain (the cell type is monomorphic), so a fold
200
        // forced onto the runtime path by a side-effecting body can still seed
201
        // a non-nil literal-list accumulator. See `append::push_list_arg`.
202
68
        let elements = super::map::extract_list_elements(&resolved)?;
203
68
        if elements.is_empty() {
204
            emit.ref_null(ctx.ids.ty_pair);
205
            return Ok(());
206
68
        }
207
68
        let members: Vec<Expr> = elements
208
68
            .into_iter()
209
68
            .map(|e| match e {
210
68
                Expr::Number(_) | Expr::String(_) | Expr::Bool(_) | Expr::Nil => e,
211
                other => Expr::Quote(Box::new(other)),
212
68
            })
213
68
            .collect();
214
68
        super::cons::compile_pair_chain(ctx, emit, symbols, &members)?;
215
68
        return Ok(());
216
1088
    }
217
    // Scalar / Index / Money / nil seed: coerce to the accumulator type (an
218
    // integer/fractional literal crosses the sanctioned Index↔Scalar boundary).
219
1088
    compile_for_stack_as(ctx, emit, symbols, init_arg, acc_ty).map_err(|_| {
220
        Error::Compile(format!(
221
            "FOLD: init does not match accumulator type {acc_ty}"
222
        ))
223
    })
224
1564
}
225

            
226
1020
fn extract_list_elements(expr: &Expr) -> Result<Vec<Expr>> {
227
1020
    match expr {
228
        Expr::List(elems) => Ok(elems.clone()),
229
        Expr::Nil => Ok(vec![]),
230
1020
        Expr::Quote(inner) => match inner.as_ref() {
231
1020
            Expr::List(elems) => Ok(elems.clone()),
232
            Expr::Nil => Ok(vec![]),
233
            other => Err(Error::Compile(format!(
234
                "FOLD expects a list, got quoted {}",
235
                format_expr(other)
236
            ))),
237
        },
238
        other => Err(Error::Compile(format!(
239
            "FOLD expects a list, got {}",
240
            format_expr(other)
241
        ))),
242
    }
243
1020
}
244

            
245
884
pub(super) fn compile_fold(
246
884
    ctx: &mut CompileContext,
247
884
    emit: &mut FunctionEmitter,
248
884
    symbols: &mut SymbolTable,
249
884
    args: &[Expr],
250
884
) -> Result<()> {
251
884
    if args.len() != 3 {
252
        return Err(Error::Arity {
253
            name: "FOLD".to_string(),
254
            expected: 3,
255
            actual: args.len(),
256
        });
257
884
    }
258
884
    let fn_resolved = eval_value(symbols, &args[0])?;
259
884
    let init_resolved = eval_value(symbols, &args[1])?;
260
884
    let list_resolved = eval_value(symbols, &args[2])?;
261
884
    let runtime_closure = matches!(fn_resolved.wasm_type(), Some(WasmType::Closure(_)));
262
884
    let runtime_list = matches!(list_resolved.wasm_type(), Some(WasmType::PairRef(_)));
263
884
    if runtime_closure || runtime_list {
264
680
        let ty = compile_fold_to_stack(ctx, emit, symbols, args)?;
265
680
        return crate::compiler::expr::serialize_stack_to_output(ctx, emit, ty);
266
204
    }
267
204
    let result = constant_fold_fold(symbols, &args[0], &init_resolved, &list_resolved)?;
268
204
    if result.is_wasm_runtime() {
269
        // Body produced a runtime value — not constant-foldable. Lower the
270
        // static fn + constant list element-by-element via the runtime path.
271
136
        let ty = compile_fold_to_stack(ctx, emit, symbols, args)?;
272
136
        return crate::compiler::expr::serialize_stack_to_output(ctx, emit, ty);
273
68
    }
274
68
    compile_expr(ctx, emit, symbols, &result)
275
884
}
276

            
277
1564
pub(super) fn compile_fold_to_stack(
278
1564
    ctx: &mut CompileContext,
279
1564
    emit: &mut FunctionEmitter,
280
1564
    symbols: &mut SymbolTable,
281
1564
    args: &[Expr],
282
1564
) -> Result<WasmType> {
283
1564
    if args.len() != 3 {
284
        return Err(Error::Compile(
285
            "FOLD: stack-position lowering needs (fn init list)".to_string(),
286
        ));
287
1564
    }
288
1564
    let fn_resolved = eval_value(symbols, &args[0])?;
289
1564
    let init_resolved = eval_value(symbols, &args[1])?;
290
1564
    let list_resolved = eval_value(symbols, &args[2])?;
291
    // A runtime closure declares its accumulator type in its signature (the
292
    // result type the body returns under its parameter types). Use that rather
293
    // than the eval-side body probe, which sees literal-typed sample args and
294
    // would mis-infer Index where the closure params are Scalar/Money.
295
1564
    let acc_ty = match fn_resolved.wasm_type() {
296
816
        Some(WasmType::Closure(sig)) => ctx.closure_sig(sig).result,
297
748
        _ => accumulator_type(symbols, &args[0], &init_resolved, &list_resolved)?,
298
    };
299

            
300
    // A let-bound closure has a FIXED signature; applying it via `call_ref` over
301
    // a non-Ratio list mismatches its iteration param. If its source body is
302
    // recoverable, inline it per element instead — `compile_lambda_call` binds
303
    // the iteration param to the actual element type. The accumulator type stays
304
    // the closure's declared result (it threads through unchanged).
305
1564
    let call_fn = super::inline_closure_fn(ctx, symbols, &args[0])?;
306

            
307
1564
    if let Some(WasmType::PairRef(elem)) = list_resolved.wasm_type() {
308
884
        return compile_fold_runtime_list(
309
884
            ctx,
310
884
            emit,
311
884
            symbols,
312
884
            FoldRuntimeArgs {
313
884
                fn_arg: &call_fn,
314
884
                init_arg: &args[1],
315
884
                list_expr: &args[2],
316
884
                elem,
317
884
                acc_ty,
318
884
            },
319
        );
320
680
    }
321
    // Any callable over a constant list — runtime closure OR a bare lambda /
322
    // defun whose body isn't constant-foldable (const-fold bailed here). Each
323
    // element is applied via FUNCALL, which serves both shapes.
324
680
    compile_fold_literal(
325
680
        ctx,
326
680
        emit,
327
680
        symbols,
328
680
        &call_fn,
329
680
        &args[1],
330
680
        &list_resolved,
331
680
        acc_ty,
332
    )
333
1564
}
334

            
335
/// Lower `(fold <fn> init '(e0 e1 ...))` element-by-element. Each per-element
336
/// call rides FUNCALL (→ `call_ref` for a closure, inline for a bare lambda)
337
/// and threads the accumulator through a fresh local.
338
680
fn compile_fold_literal(
339
680
    ctx: &mut CompileContext,
340
680
    emit: &mut FunctionEmitter,
341
680
    symbols: &mut SymbolTable,
342
680
    fn_arg: &Expr,
343
680
    init_arg: &Expr,
344
680
    list_resolved: &Expr,
345
680
    acc_ty: WasmType,
346
680
) -> Result<WasmType> {
347
680
    let elements = extract_list_elements(list_resolved)?;
348
680
    let acc_local = ctx.alloc_local(acc_ty)?;
349
680
    emit_fold_init(ctx, emit, symbols, init_arg, acc_ty)?;
350
680
    emit.local_set(acc_local);
351

            
352
1564
    for elem_expr in &elements {
353
1564
        let funcall = Expr::List(vec![
354
1564
            Expr::Symbol("FUNCALL".to_string()),
355
1564
            fn_arg.clone(),
356
1564
            Expr::WasmLocal(acc_local, acc_ty),
357
1564
            as_literal_arg(elem_expr),
358
1564
        ]);
359
1564
        let result_ty = compile_for_stack(ctx, emit, symbols, &funcall)?;
360
1564
        if result_ty != acc_ty {
361
            return Err(Error::Compile(format!(
362
                "FOLD: closure return type {result_ty} doesn't match accumulator type {acc_ty}"
363
            )));
364
1564
        }
365
1564
        emit.local_set(acc_local);
366
    }
367
680
    emit.local_get(acc_local);
368
680
    Ok(acc_ty)
369
680
}
370

            
371
struct FoldRuntimeArgs<'a> {
372
    fn_arg: &'a Expr,
373
    init_arg: &'a Expr,
374
    list_expr: &'a Expr,
375
    elem: PairElement,
376
    acc_ty: WasmType,
377
}
378

            
379
/// Walks a runtime `PairRef(elem)` input. Calls `f(acc, car)` per cell,
380
/// threading the accumulator through a fresh local.
381
884
fn compile_fold_runtime_list(
382
884
    ctx: &mut CompileContext,
383
884
    emit: &mut FunctionEmitter,
384
884
    symbols: &mut SymbolTable,
385
884
    args: FoldRuntimeArgs<'_>,
386
884
) -> Result<WasmType> {
387
    let FoldRuntimeArgs {
388
884
        fn_arg,
389
884
        init_arg,
390
884
        list_expr,
391
884
        elem,
392
884
        acc_ty,
393
884
    } = args;
394
884
    let pair_idx = ctx.ids.ty_pair;
395
884
    let pair_local = ctx.alloc_local(WasmType::PairRef(elem))?;
396
884
    let acc_local = ctx.alloc_local(acc_ty)?;
397
884
    let car_local = ctx.alloc_local(elem.as_wasm_type())?;
398

            
399
884
    emit_fold_init(ctx, emit, symbols, init_arg, acc_ty)?;
400
884
    emit.local_set(acc_local);
401

            
402
884
    compile_for_stack(ctx, emit, symbols, list_expr)?;
403
884
    emit.local_set(pair_local);
404

            
405
884
    emit.block_start();
406
884
    emit.loop_start();
407

            
408
884
    emit.local_get(pair_local);
409
884
    emit.ref_is_null();
410
884
    emit.br_if(1);
411

            
412
884
    emit.local_get(pair_local);
413
884
    emit.struct_get(pair_idx, 0);
414
884
    crate::compiler::native::list::emit_pair_car_downcast(ctx, emit, elem);
415
884
    emit.local_set(car_local);
416

            
417
884
    let funcall = Expr::List(vec![
418
884
        Expr::Symbol("FUNCALL".to_string()),
419
884
        fn_arg.clone(),
420
884
        Expr::WasmLocal(acc_local, acc_ty),
421
884
        Expr::WasmLocal(car_local, elem.as_wasm_type()),
422
884
    ]);
423
884
    let result_ty = compile_for_stack(ctx, emit, symbols, &funcall)?;
424
884
    if result_ty != acc_ty {
425
        return Err(Error::Compile(format!(
426
            "FOLD: closure return type {result_ty} doesn't match accumulator type {acc_ty}"
427
        )));
428
884
    }
429
884
    emit.local_set(acc_local);
430

            
431
884
    emit.local_get(pair_local);
432
884
    emit.struct_get(pair_idx, 1);
433
884
    emit.local_set(pair_local);
434

            
435
884
    emit.br(0);
436
884
    emit.block_end();
437
884
    emit.block_end();
438

            
439
884
    emit.local_get(acc_local);
440
884
    Ok(acc_ty)
441
884
}