//! `MAP` eval + compile.
//!
//! Three paths:
//! - **Constant-fold**: when the function and every list arg resolve at
//!   compile time and every per-element call yields a constant, `map_fn`
//!   walks element-by-element via the eval `call` pipeline and returns the
//!   resulting list literal.
//! - **Literal list(s), runtime call**: when the function is a runtime
//!   closure, or a static function whose per-element result is only known
//!   at runtime, emit a per-row FUNCALL over the zipped constant lists,
//!   stash each result in a fresh local, then walk the locals in reverse
//!   and prepend each car onto the accumulator via `pair_new`.
//! - **Runtime `PairRef` list** (only as the sole list argument): walk the
//!   input chain at runtime, calling the function on each car, prepend onto
//!   a reversed accumulator, and run `emit_reverse_loop` once at the end so
//!   the output preserves input order.
14

            
15
use crate::ast::{Expr, PairElement, WasmType};
16
use crate::compiler::context::CompileContext;
17
use crate::compiler::emit::FunctionEmitter;
18
use crate::compiler::expr::{call, compile_expr, compile_for_stack, eval_value, format_expr};
19
use crate::error::{Error, Result};
20
use crate::runtime::SymbolTable;
21

            
22
use super::reverse::emit_reverse_loop;
23

            
24
952
pub(super) fn map_fn(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
25
952
    if args.len() < 2 {
26
68
        return Err(Error::Arity {
27
68
            name: "MAP".to_string(),
28
68
            expected: 2,
29
68
            actual: args.len(),
30
68
        });
31
884
    }
32
884
    let fn_resolved = eval_value(symbols, &args[0])?;
33
884
    if let Some(WasmType::Closure(sig)) = fn_resolved.wasm_type() {
34
        // MAP's OUTPUT element is the closure's RESULT type (it transforms each
35
        // element) — NOT the input element. The closure-emit site records the
36
        // result so this eval prediction agrees with codegen (`compile_map_*`
37
        // uses the per-row closure result). Fall back to the input element only
38
        // when the result isn't recorded or isn't a pair-cell type (e.g. a
39
        // nested-pair result, which codegen rejects anyway).
40
136
        let elem = symbols
41
136
            .closure_result(sig)
42
136
            .and_then(PairElement::from_wasm_type)
43
136
            .or(runtime_pair_input_element(symbols, &args[1..])?)
44
136
            .unwrap_or(PairElement::Ratio);
45
136
        return Ok(Expr::WasmRuntime(WasmType::PairRef(elem)));
46
748
    }
47
748
    if list_args_have_runtime_pair(symbols, &args[1..])? {
48
        let elem = runtime_pair_input_element(symbols, &args[1..])?
49
            .ok_or_else(|| Error::Compile("MAP: runtime list element type unknown".to_string()))?;
50
        return Ok(Expr::WasmRuntime(WasmType::PairRef(elem)));
51
748
    }
52
    // Static fn + constant list(s): try to fully reduce. If the body produces
53
    // a runtime value per element (a side-effecting native, or a runtime
54
    // result), MAP ISN'T constant-foldable — surface a runtime placeholder
55
    // typed by the body's per-element result so the compile path lowers each
56
    // element call at runtime instead of baking placeholders into a list.
57
748
    let folded = constant_fold_map(symbols, &args[0], &args[1..])?;
58
680
    if let Some(elem) = mapped_runtime_element(&folded) {
59
        return Ok(Expr::WasmRuntime(WasmType::PairRef(elem)));
60
680
    }
61
680
    Ok(folded)
62
952
}
63

            
64
/// If a folded MAP result is a list whose elements are runtime placeholders
65
/// (the body produced runtime values), returns the unified `PairElement` of
66
/// those results — the signal that MAP must lower at runtime. `None` when the
67
/// fold produced a genuine constant list.
68
1156
fn mapped_runtime_element(folded: &Expr) -> Option<PairElement> {
69
1156
    let elems = match folded {
70
1156
        Expr::Quote(inner) => match inner.as_ref() {
71
1156
            Expr::List(elems) => elems,
72
            _ => return None,
73
        },
74
        Expr::List(elems) => elems,
75
        _ => return None,
76
    };
77
1156
    if !elems.iter().any(Expr::is_wasm_runtime) {
78
884
        return None;
79
272
    }
80
272
    let mut elem: Option<PairElement> = None;
81
612
    for e in elems {
82
612
        let pe = e
83
612
            .wasm_type()
84
612
            .and_then(PairElement::from_wasm_type)
85
612
            .or_else(|| super::infer::literal_pair_element(e))
86
612
            .unwrap_or(PairElement::AnyRef);
87
612
        elem = Some(match elem {
88
340
            Some(prev) => prev.widen(pe),
89
272
            None => pe,
90
        });
91
    }
92
272
    Some(elem.unwrap_or(PairElement::AnyRef))
93
1156
}
94

            
95
1496
fn constant_fold_map(
96
1496
    symbols: &mut SymbolTable,
97
1496
    function_arg: &Expr,
98
1496
    list_args: &[Expr],
99
1496
) -> Result<Expr> {
100
1496
    let lists: Vec<Vec<Expr>> = list_args
101
1496
        .iter()
102
1836
        .map(|arg| {
103
1836
            let resolved = eval_value(symbols, arg)?;
104
1836
            extract_list_elements(&resolved)
105
1836
        })
106
1496
        .collect::<Result<_>>()?;
107

            
108
1428
    let min_len = lists.iter().map(std::vec::Vec::len).min().unwrap_or(0);
109

            
110
1428
    let mut results = Vec::with_capacity(min_len);
111
3060
    for i in 0..min_len {
112
3060
        let mut call_args = vec![function_arg.clone()];
113
3808
        for list in &lists {
114
3808
            call_args.push(as_literal_arg(&list[i]));
115
3808
        }
116
3060
        let result = call(symbols, &call_args)?;
117
2788
        let resolved_result = eval_value(symbols, &result)?;
118
2788
        results.push(resolved_result);
119
    }
120
1156
    Ok(Expr::Quote(Box::new(Expr::List(results))))
121
1496
}
122

            
123
5236
fn as_literal_arg(expr: &Expr) -> Expr {
124
5236
    match expr {
125
476
        Expr::Symbol(_) | Expr::List(_) | Expr::Cons(_, _) => Expr::Quote(Box::new(expr.clone())),
126
4760
        _ => expr.clone(),
127
    }
128
5236
}
129

            
130
1904
fn list_args_have_runtime_pair(symbols: &mut SymbolTable, list_args: &[Expr]) -> Result<bool> {
131
2244
    for arg in list_args {
132
2244
        let resolved = eval_value(symbols, arg)?;
133
2244
        if matches!(resolved.wasm_type(), Some(WasmType::PairRef(_))) {
134
204
            return Ok(true);
135
2040
        }
136
    }
137
1700
    Ok(false)
138
1904
}
139

            
140
136
fn runtime_pair_input_element(
141
136
    symbols: &mut SymbolTable,
142
136
    list_args: &[Expr],
143
136
) -> Result<Option<PairElement>> {
144
136
    for arg in list_args {
145
136
        let resolved = eval_value(symbols, arg)?;
146
136
        if let Some(WasmType::PairRef(elem)) = resolved.wasm_type() {
147
            return Ok(Some(elem));
148
136
        }
149
    }
150
136
    Ok(None)
151
136
}
152

            
153
4692
pub(super) fn extract_list_elements(expr: &Expr) -> Result<Vec<Expr>> {
154
4692
    match expr {
155
        Expr::List(elems) => Ok(elems.clone()),
156
272
        Expr::Nil => Ok(vec![]),
157
4352
        Expr::Quote(inner) => match inner.as_ref() {
158
4352
            Expr::List(elems) => Ok(elems.clone()),
159
            Expr::Nil => Ok(vec![]),
160
            other => Err(Error::Compile(format!(
161
                "MAP expects list arguments, got quoted {}",
162
                format_expr(other)
163
            ))),
164
        },
165
68
        other => Err(Error::Compile(format!(
166
68
            "MAP expects list arguments, got {}",
167
68
            format_expr(other)
168
68
        ))),
169
    }
170
4692
}
171

            
172
1156
pub(super) fn compile_map(
173
1156
    ctx: &mut CompileContext,
174
1156
    emit: &mut FunctionEmitter,
175
1156
    symbols: &mut SymbolTable,
176
1156
    args: &[Expr],
177
1156
) -> Result<()> {
178
1156
    if args.len() < 2 {
179
        return Err(Error::Arity {
180
            name: "MAP".to_string(),
181
            expected: 2,
182
            actual: args.len(),
183
        });
184
1156
    }
185
1156
    let fn_resolved = eval_value(symbols, &args[0])?;
186
1156
    let runtime_closure = matches!(fn_resolved.wasm_type(), Some(WasmType::Closure(_)));
187
1156
    let runtime_list = list_args_have_runtime_pair(symbols, &args[1..])?;
188

            
189
1156
    if runtime_closure || runtime_list {
190
408
        let ty = compile_map_to_stack(ctx, emit, symbols, args)?;
191
408
        return crate::compiler::expr::serialize_stack_to_output(ctx, emit, ty);
192
748
    }
193
748
    let result = constant_fold_map(symbols, &args[0], &args[1..])?;
194
476
    if mapped_runtime_element(&result).is_some() {
195
        // Body isn't constant-foldable — lower the static fn + constant list
196
        // element-by-element via the runtime path.
197
272
        let ty = compile_map_to_stack(ctx, emit, symbols, args)?;
198
272
        return crate::compiler::expr::serialize_stack_to_output(ctx, emit, ty);
199
204
    }
200
204
    compile_expr(ctx, emit, symbols, &result)
201
1156
}
202

            
203
748
pub(super) fn compile_map_to_stack(
204
748
    ctx: &mut CompileContext,
205
748
    emit: &mut FunctionEmitter,
206
748
    symbols: &mut SymbolTable,
207
748
    args: &[Expr],
208
748
) -> Result<WasmType> {
209
748
    if args.len() < 2 {
210
        return Err(Error::Compile(
211
            "MAP: stack-position lowering requires a function and at least one list".to_string(),
212
        ));
213
748
    }
214
    // Single runtime `PairRef` list: walk the chain at runtime.
215
748
    if args.len() == 2
216
680
        && let Some(WasmType::PairRef(elem)) = eval_value(symbols, &args[1])?.wasm_type()
217
    {
218
204
        return compile_map_runtime_list(ctx, emit, symbols, &args[0], &args[1], elem);
219
544
    }
220
    // One or more constant lists with a callable whose body isn't constant-
221
    // foldable (runtime/side-effecting). Zip the constant lists and apply the
222
    // function per row via FUNCALL — serves a closure OR a bare lambda, and any
223
    // arity of lists. (A multi-list MAP over RUNTIME pair chains is not
224
    // supported here — only the single-list runtime chain is.)
225
544
    let lists: Vec<Vec<Expr>> = args[1..]
226
544
        .iter()
227
612
        .map(|arg| {
228
612
            let resolved = eval_value(symbols, arg)?;
229
612
            if matches!(resolved.wasm_type(), Some(WasmType::PairRef(_))) {
230
                return Err(Error::Compile(
231
                    "MAP: a multi-list mapping requires constant lists; a runtime list \
232
                     is only supported as the sole list argument"
233
                        .to_string(),
234
                ));
235
612
            }
236
612
            extract_list_elements(&resolved)
237
612
        })
238
544
        .collect::<Result<_>>()?;
239
544
    compile_map_literal(ctx, emit, symbols, &args[0], &lists)
240
748
}
241

            
/// Lower `(map <fn> '(a0 a1 ...) '(b0 b1 ...) ...)` row-by-row over one or more
/// constant lists (zipped to the shortest). Each per-row call rides FUNCALL (→
/// `call_ref` for a closure, inline for a bare lambda); results land in fresh
/// per-row locals. We then walk the locals in reverse and prepend each car onto
/// the accumulator via `pair_new` so the output keeps input order without an
/// extra reverse pass.
///
/// When the shortest list is empty (or `lists` is empty), this emits a null pair
/// reference and reports `PairRef(Ratio)`. Otherwise the chain element is the
/// widening of every row's result element, and `PairRef` of that is returned.
///
/// NOTE(review): `fn_arg` is cloned into every row's FUNCALL, so it is lowered
/// once per row. Any side effects of evaluating it presumably repeat per row;
/// confirm this is acceptable.
///
/// # Errors
/// Fails when a row's call result has no `PairElement` representation
/// (`PairElement::from_wasm_type` returns `None`). Also propagates errors from
/// compiling a row's FUNCALL or allocating a local.
fn compile_map_literal(
    ctx: &mut CompileContext,
    emit: &mut FunctionEmitter,
    symbols: &mut SymbolTable,
    fn_arg: &Expr,
    lists: &[Vec<Expr>],
) -> Result<WasmType> {
    let rows = lists.iter().map(Vec::len).min().unwrap_or(0);
    if rows == 0 {
        emit.ref_null(ctx.ids.ty_pair);
        return Ok(WasmType::PairRef(PairElement::Ratio));
    }
    // One (local, element type) per row, in row order.
    let mut element_locals: Vec<(u32, PairElement)> = Vec::with_capacity(rows);
    let mut shared_elem: Option<PairElement> = None;
    for row in 0..rows {
        // Build `(FUNCALL fn_arg row0 row1 ...)`, quoting data-like elements.
        let mut call = vec![Expr::Symbol("FUNCALL".to_string()), fn_arg.clone()];
        call.extend(lists.iter().map(|list| as_literal_arg(&list[row])));
        let funcall = Expr::List(call);
        let ty = compile_for_stack(ctx, emit, symbols, &funcall)?;
        let elem = PairElement::from_wasm_type(ty).ok_or_else(|| {
            Error::Compile(format!(
                "MAP: closure result type {ty} can't ride a typed pair; \
                 flatten via let-bind first"
            ))
        })?;
        // Heterogeneous per-element results widen the chain to `AnyRef` (the
        // ADR-0025 escape hatch) instead of erroring — each car is boxed to
        // anyref at prepend. Matches the eval-side `mapped_runtime_element`.
        shared_elem = Some(match shared_elem {
            Some(prev) => prev.widen(elem),
            None => elem,
        });
        // Park the row's result (left on the stack by the call) in its own local.
        let local = ctx.alloc_local(elem.as_wasm_type())?;
        emit.local_set(local);
        element_locals.push((local, elem));
    }
    // Defensive: `rows > 0` means the loop ran at least once and set this.
    let Some(chain_elem) = shared_elem else {
        return Err(Error::Compile(
            "MAP: empty literal-list closure mapping reached prepend phase".to_string(),
        ));
    };
    let pair_idx = ctx.ids.ty_pair;
    // Accumulator starts as the null (empty) chain.
    let acc_local = ctx.alloc_local(WasmType::PairRef(chain_elem))?;
    emit.ref_null(pair_idx);
    emit.local_set(acc_local);
    // Prepending last-to-first leaves the chain in original row order.
    for (local, elem_ty) in element_locals.iter().rev() {
        emit.local_get(*local);
        // Box each car for the CHAIN's element slot: an AnyRef chain needs
        // every car widened to anyref (i31 for i32/bool); a homogeneous chain
        // boxes per the element's own type.
        box_for_pair_car(emit, chain_car_box(chain_elem, *elem_ty));
        emit.local_get(acc_local);
        emit.call(ctx.ids.pair_new);
        emit.local_set(acc_local);
    }
    emit.local_get(acc_local);
    Ok(WasmType::PairRef(chain_elem))
}
306

            
307
/// The `PairElement` to box a car AS when prepending into a chain whose slot is
308
/// `chain_elem`. For an `AnyRef` chain, an i32/bool car must still be i31-boxed
309
/// (so `box_for_pair_car` sees its own value type); ref-typed cars are anyref
310
/// subtypes already. For a homogeneous chain the car boxes per the chain type.
311
1292
fn chain_car_box(chain_elem: PairElement, car_elem: PairElement) -> PairElement {
312
1292
    match chain_elem {
313
136
        PairElement::AnyRef => car_elem,
314
1156
        other => other,
315
    }
316
1292
}
317

            
/// Walks a runtime `PairRef(elem)` input. Builds the result reversed
/// (one prepend per car) by calling the function on each car, then
/// reverses once at the end so the output preserves input order.
///
/// The mapped element must equal the input element `elem`; the returned type is
/// `PairRef(elem)`.
///
/// # Errors
/// Fails when:
/// - the function's result type has no `PairElement` representation,
/// - its element differs from `elem` (heterogeneous mapping is rejected here),
/// - compiling the list, the per-car call, or the reverse loop fails.
fn compile_map_runtime_list(
    ctx: &mut CompileContext,
    emit: &mut FunctionEmitter,
    symbols: &mut SymbolTable,
    fn_arg: &Expr,
    list_expr: &Expr,
    elem: PairElement,
) -> Result<WasmType> {
    // The output chain is typed like the input chain (checked after the call below).
    let result_elem = elem;
    let pair_idx = ctx.ids.ty_pair;
    // Cursor over the input chain.
    let pair_local = ctx.alloc_local(WasmType::PairRef(result_elem))?;
    // Reversed output accumulator.
    let acc_local = ctx.alloc_local(WasmType::PairRef(result_elem))?;
    // Current (unboxed) car.
    let car_local = ctx.alloc_local(elem.as_wasm_type())?;

    // The list expression is lowered once, before the loop.
    compile_for_stack(ctx, emit, symbols, list_expr)?;
    emit.local_set(pair_local);

    emit.ref_null(pair_idx);
    emit.local_set(acc_local);

    emit.block_start();
    emit.loop_start();

    // Exit the outer block (depth 1) once the cursor reaches the null terminator.
    emit.local_get(pair_local);
    emit.ref_is_null();
    emit.br_if(1);

    // Load field 0 (the car) and downcast it to `elem`'s value type.
    emit.local_get(pair_local);
    emit.struct_get(pair_idx, 0);
    crate::compiler::native::list::emit_pair_car_downcast(ctx, emit, elem);
    emit.local_set(car_local);

    // Apply the function to the car; the mapped value is left on the stack.
    let mapped_ty = compile_map_call_with_local(ctx, emit, symbols, fn_arg, car_local, elem)?;
    let actual_elem = PairElement::from_wasm_type(mapped_ty).ok_or_else(|| {
        Error::Compile(format!(
            "MAP: closure result type {mapped_ty} can't ride a typed pair; \
             flatten via let-bind first"
        ))
    })?;
    if actual_elem != result_elem {
        return Err(Error::Compile(format!(
            "MAP: closure result element {actual_elem} doesn't match input element {result_elem}; \
             heterogeneous mapping isn't supported yet"
        )));
    }
    // acc = pair_new(boxed mapped value, acc)
    box_for_pair_car(emit, actual_elem);
    emit.local_get(acc_local);
    emit.call(ctx.ids.pair_new);
    emit.local_set(acc_local);

    // Advance the cursor to field 1 (the next cell).
    emit.local_get(pair_local);
    emit.struct_get(pair_idx, 1);
    emit.local_set(pair_local);

    // Branch back to the loop header (depth 0), then close the loop and block.
    emit.br(0);
    emit.block_end();
    emit.block_end();

    // The accumulator is in reverse order — walk it once more through
    // the standard reverse loop so the output preserves input order.
    // NOTE(review): presumably `emit_reverse_loop` leaves the reversed chain on
    // the stack; confirm in `super::reverse`.
    let acc_expr = Expr::WasmLocal(acc_local, WasmType::PairRef(result_elem));
    emit_reverse_loop(ctx, emit, symbols, &acc_expr, result_elem)?;
    Ok(WasmType::PairRef(result_elem))
}
385

            
386
1496
fn box_for_pair_car(emit: &mut FunctionEmitter, elem: PairElement) {
387
    // I32 and Bool share the i31-boxed car representation.
388
1496
    if matches!(elem, PairElement::I32 | PairElement::Bool) {
389
408
        emit.ref_i31();
390
1088
    }
391
1496
}
392

            
393
/// Emit a per-iteration call of `fn_arg` against the value held in
394
/// `car_local`. Mirrors the FUNCALL stack-position emit but inlined so
395
/// we don't have to allocate a temporary `Expr` for the local.
396
204
fn compile_map_call_with_local(
397
204
    ctx: &mut CompileContext,
398
204
    emit: &mut FunctionEmitter,
399
204
    symbols: &mut SymbolTable,
400
204
    fn_arg: &Expr,
401
204
    car_local: u32,
402
204
    elem: PairElement,
403
204
) -> Result<WasmType> {
404
204
    let car_expr = Expr::WasmLocal(car_local, elem.as_wasm_type());
405
204
    let funcall = Expr::List(vec![
406
204
        Expr::Symbol("FUNCALL".to_string()),
407
204
        fn_arg.clone(),
408
204
        car_expr,
409
204
    ]);
410
204
    compile_for_stack(ctx, emit, symbols, &funcall)
411
204
}