1
use wasm_encoder::BlockType;
2

            
3
use crate::ast::{Expr, WasmType};
4
use crate::error::{Error, Result};
5
use crate::runtime::SymbolTable;
6

            
7
use super::super::context::CompileContext;
8
use super::super::emit::FunctionEmitter;
9
use super::super::expr::{
10
    compile_body, compile_body_for_stack, compile_expr, compile_for_effect, compile_for_stack,
11
    compile_nil, compile_quoted_expr, eval_value, serialize_stack_to_output,
12
};
13

            
14
88
pub(super) fn compile_quote(
15
88
    ctx: &mut CompileContext,
16
88
    emit: &mut FunctionEmitter,
17
88
    args: &[Expr],
18
88
) -> Result<()> {
19
88
    if args.len() != 1 {
20
        return Err(Error::Arity {
21
            name: "quote".to_string(),
22
            expected: 1,
23
            actual: args.len(),
24
        });
25
88
    }
26
88
    compile_quoted_expr(ctx, emit, &args[0])
27
88
}
28

            
29
1584
pub(super) fn compile_if(
30
1584
    ctx: &mut CompileContext,
31
1584
    emit: &mut FunctionEmitter,
32
1584
    symbols: &mut SymbolTable,
33
1584
    args: &[Expr],
34
1584
) -> Result<()> {
35
1584
    if args.len() < 2 || args.len() > 3 {
36
88
        return Err(Error::Compile(
37
88
            "IF requires a test, a then-form, and an optional else-form".to_string(),
38
88
        ));
39
1496
    }
40
1496
    let test = eval_value(symbols, &args[0])?;
41
1188
    if matches!(test, Expr::WasmRuntime(WasmType::I32)) {
42
        // Runtime test — compile condition to stack, then emit WASM if/else
43
308
        compile_for_stack(ctx, emit, symbols, &args[0])?;
44
308
        emit.if_block(BlockType::Empty);
45
308
        compile_expr(ctx, emit, symbols, &args[1])?;
46
308
        if args.len() == 3 {
47
264
            emit.else_block();
48
264
            compile_expr(ctx, emit, symbols, &args[2])?;
49
44
        }
50
308
        emit.block_end();
51
308
        return Ok(());
52
1188
    }
53
1188
    if is_truthy(&test) {
54
528
        compile_expr(ctx, emit, symbols, &args[1])
55
660
    } else if args.len() == 3 {
56
572
        compile_expr(ctx, emit, symbols, &args[2])
57
    } else {
58
88
        compile_nil(ctx, emit);
59
88
        Ok(())
60
    }
61
1584
}
62

            
63
88
pub(super) fn compile_begin(
64
88
    ctx: &mut CompileContext,
65
88
    emit: &mut FunctionEmitter,
66
88
    symbols: &mut SymbolTable,
67
88
    args: &[Expr],
68
88
) -> Result<()> {
69
88
    if args.is_empty() {
70
        compile_nil(ctx, emit);
71
        return Ok(());
72
88
    }
73
88
    compile_body(ctx, emit, symbols, args)
74
88
}
75

            
76
pub(super) fn compile_begin_for_stack(
77
    ctx: &mut CompileContext,
78
    emit: &mut FunctionEmitter,
79
    symbols: &mut SymbolTable,
80
    args: &[Expr],
81
) -> Result<WasmType> {
82
    if args.is_empty() {
83
        emit.i32_const(0);
84
        return Ok(WasmType::I32);
85
    }
86
    compile_body_for_stack(ctx, emit, symbols, args)
87
}
88

            
89
88
pub(super) fn compile_and(
90
88
    ctx: &mut CompileContext,
91
88
    emit: &mut FunctionEmitter,
92
88
    symbols: &mut SymbolTable,
93
88
    args: &[Expr],
94
88
) -> Result<()> {
95
88
    if args.is_empty() {
96
        return compile_expr(ctx, emit, symbols, &Expr::Bool(true));
97
88
    }
98

            
99
88
    let mut last = Expr::Bool(true);
100
88
    for (i, arg) in args.iter().enumerate() {
101
88
        let value = eval_value(symbols, arg)?;
102
88
        if matches!(value, Expr::WasmRuntime(WasmType::I32)) {
103
            // Runtime AND: remaining args compiled as runtime short-circuit
104
88
            return compile_and_runtime(ctx, emit, symbols, &args[i..]);
105
        }
106
        if !is_truthy(&value) {
107
            return compile_expr(ctx, emit, symbols, &value);
108
        }
109
        last = value;
110
    }
111
    compile_expr(ctx, emit, symbols, &last)
112
88
}
113

            
114
44
pub(super) fn compile_or(
115
44
    ctx: &mut CompileContext,
116
44
    emit: &mut FunctionEmitter,
117
44
    symbols: &mut SymbolTable,
118
44
    args: &[Expr],
119
44
) -> Result<()> {
120
44
    if args.is_empty() {
121
        compile_nil(ctx, emit);
122
        return Ok(());
123
44
    }
124

            
125
44
    let mut last = Expr::Nil;
126
44
    for (i, arg) in args.iter().enumerate() {
127
44
        let value = eval_value(symbols, arg)?;
128
44
        if matches!(value, Expr::WasmRuntime(WasmType::I32)) {
129
44
            return compile_or_runtime(ctx, emit, symbols, &args[i..]);
130
        }
131
        if is_truthy(&value) {
132
            return compile_expr(ctx, emit, symbols, &value);
133
        }
134
        last = value;
135
    }
136
    compile_expr(ctx, emit, symbols, &last)
137
44
}
138

            
139
396
pub(super) fn compile_cond(
140
396
    ctx: &mut CompileContext,
141
396
    emit: &mut FunctionEmitter,
142
396
    symbols: &mut SymbolTable,
143
396
    args: &[Expr],
144
396
) -> Result<()> {
145
484
    for (i, clause) in args.iter().enumerate() {
146
484
        let elems = clause.as_list().ok_or_else(|| {
147
            Error::Compile(format!("COND: clause must be a list, got {clause:?}"))
148
        })?;
149
484
        if elems.is_empty() {
150
            return Err(Error::Compile("COND: empty clause".to_string()));
151
484
        }
152
484
        let test = eval_value(symbols, &elems[0])?;
153
440
        if matches!(test, Expr::WasmRuntime(WasmType::I32)) {
154
44
            return compile_cond_runtime(ctx, emit, symbols, &args[i..]);
155
440
        }
156
440
        if is_truthy(&test) {
157
264
            if elems.len() == 1 {
158
44
                return compile_expr(ctx, emit, symbols, &test);
159
220
            }
160
220
            return compile_body(ctx, emit, symbols, &elems[1..]);
161
176
        }
162
    }
163
88
    compile_nil(ctx, emit);
164
88
    Ok(())
165
396
}
166

            
167
/// Runtime AND: first arg is already on WASM stack as i32.
168
/// Emit short-circuit: if first is false, skip rest.
169
88
fn compile_and_runtime(
170
88
    ctx: &mut CompileContext,
171
88
    emit: &mut FunctionEmitter,
172
88
    symbols: &mut SymbolTable,
173
88
    args: &[Expr],
174
88
) -> Result<()> {
175
    // Compile first arg for stack (already known to be runtime I32)
176
88
    compile_for_stack(ctx, emit, symbols, &args[0])?;
177
88
    if args.len() == 1 {
178
        serialize_stack_to_output(ctx, emit, WasmType::I32);
179
        return Ok(());
180
88
    }
181
    // if first-arg != 0, evaluate rest; otherwise skip
182
88
    emit.if_block(BlockType::Empty);
183
88
    for arg in &args[1..args.len() - 1] {
184
        compile_for_stack(ctx, emit, symbols, arg)?;
185
        // Check intermediate: if false, skip to end
186
        emit.if_block(BlockType::Empty);
187
    }
188
    // Last arg — compile for side effects
189
88
    compile_expr(ctx, emit, symbols, args.last().unwrap())?;
190
    // Close nested if blocks
191
88
    for _ in 0..args.len() - 1 {
192
88
        emit.block_end();
193
88
    }
194
88
    Ok(())
195
88
}
196

            
197
/// Runtime OR: first arg is already on WASM stack as i32.
198
/// Emit short-circuit: if first is true, skip rest.
199
44
fn compile_or_runtime(
200
44
    ctx: &mut CompileContext,
201
44
    emit: &mut FunctionEmitter,
202
44
    symbols: &mut SymbolTable,
203
44
    args: &[Expr],
204
44
) -> Result<()> {
205
    // Compile first arg for stack
206
44
    compile_for_stack(ctx, emit, symbols, &args[0])?;
207
44
    if args.len() == 1 {
208
        serialize_stack_to_output(ctx, emit, WasmType::I32);
209
        return Ok(());
210
44
    }
211
    // if first-arg == 0 (false), evaluate rest
212
44
    emit.i32_eqz();
213
44
    emit.if_block(BlockType::Empty);
214
44
    for arg in &args[1..args.len() - 1] {
215
        compile_for_stack(ctx, emit, symbols, arg)?;
216
        emit.i32_eqz();
217
        emit.if_block(BlockType::Empty);
218
    }
219
44
    compile_expr(ctx, emit, symbols, args.last().unwrap())?;
220
44
    for _ in 0..args.len() - 1 {
221
44
        emit.block_end();
222
44
    }
223
44
    Ok(())
224
44
}
225

            
226
/// Runtime COND: compile remaining clauses as nested if/else.
227
44
fn compile_cond_runtime(
228
44
    ctx: &mut CompileContext,
229
44
    emit: &mut FunctionEmitter,
230
44
    symbols: &mut SymbolTable,
231
44
    args: &[Expr],
232
44
) -> Result<()> {
233
44
    let depth = args.len();
234
88
    for (i, clause) in args.iter().enumerate() {
235
88
        let elems = clause.as_list().ok_or_else(|| {
236
            Error::Compile(format!("COND: clause must be a list, got {clause:?}"))
237
        })?;
238
88
        if elems.is_empty() {
239
            return Err(Error::Compile("COND: empty clause".to_string()));
240
88
        }
241
88
        let test = eval_value(symbols, &elems[0])?;
242
88
        let is_last = i == depth - 1;
243
88
        if matches!(test, Expr::WasmRuntime(WasmType::I32)) {
244
            // Runtime test on stack
245
88
            compile_for_stack(ctx, emit, symbols, &elems[0])?;
246
88
            emit.if_block(BlockType::Empty);
247
88
            compile_body(ctx, emit, symbols, &elems[1..])?;
248
88
            if !is_last {
249
44
                emit.else_block();
250
44
            }
251
        } else if is_truthy(&test) {
252
            // Constant true — always taken, compile body and stop
253
            compile_body(ctx, emit, symbols, &elems[1..])?;
254
            // Close all open if blocks
255
            for _ in 0..i {
256
                emit.block_end();
257
            }
258
            return Ok(());
259
        }
260
        // Constant false — skip this clause
261
    }
262
    // Close all if blocks
263
88
    for _ in 0..depth {
264
88
        emit.block_end();
265
88
    }
266
44
    Ok(())
267
44
}
268

            
269
/// Compile IF for stack — both branches produce a value on the WASM stack.
270
88
pub(super) fn compile_if_for_stack(
271
88
    ctx: &mut CompileContext,
272
88
    emit: &mut FunctionEmitter,
273
88
    symbols: &mut SymbolTable,
274
88
    args: &[Expr],
275
88
) -> Result<WasmType> {
276
88
    if args.len() < 2 || args.len() > 3 {
277
        return Err(Error::Compile(
278
            "IF requires a test, a then-form, and an optional else-form".to_string(),
279
        ));
280
88
    }
281
88
    let test = eval_value(symbols, &args[0])?;
282
88
    if matches!(test, Expr::WasmRuntime(WasmType::I32)) {
283
        // Runtime test on stack — emit if/else that produces a value
284
88
        compile_for_stack(ctx, emit, symbols, &args[0])?;
285
88
        let then_ty = eval_value(symbols, &args[1])?;
286
88
        let result_ty = match then_ty {
287
88
            Expr::WasmRuntime(t) => t,
288
            _ => WasmType::I32,
289
        };
290
88
        let val_type = ctx.wasm_val_type(result_ty);
291
88
        emit.if_block(BlockType::Result(val_type));
292
88
        compile_for_stack(ctx, emit, symbols, &args[1])?;
293
88
        emit.else_block();
294
88
        if args.len() == 3 {
295
88
            compile_for_stack(ctx, emit, symbols, &args[2])?;
296
        } else {
297
            emit.i32_const(0);
298
        }
299
88
        emit.block_end();
300
88
        return Ok(result_ty);
301
    }
302
    if is_truthy(&test) {
303
        compile_for_stack(ctx, emit, symbols, &args[1])
304
    } else if args.len() == 3 {
305
        compile_for_stack(ctx, emit, symbols, &args[2])
306
    } else {
307
        emit.i32_const(0);
308
        Ok(WasmType::I32)
309
    }
310
88
}
311

            
312
/// Compile AND for stack — returns I32 (boolean result).
313
1206
pub(super) fn compile_and_for_stack(
314
1206
    ctx: &mut CompileContext,
315
1206
    emit: &mut FunctionEmitter,
316
1206
    symbols: &mut SymbolTable,
317
1206
    args: &[Expr],
318
1206
) -> Result<WasmType> {
319
1206
    if args.is_empty() {
320
        emit.i32_const(1);
321
        return Ok(WasmType::I32);
322
1206
    }
323
1206
    if args.len() == 1 {
324
        return compile_for_stack(ctx, emit, symbols, &args[0]);
325
1206
    }
326
    // Compile first arg
327
1206
    compile_for_stack(ctx, emit, symbols, &args[0])?;
328
1340
    for arg in &args[1..] {
329
        // Short-circuit: if top is false (0), skip rest
330
1340
        emit.if_block(BlockType::Result(wasm_encoder::ValType::I32));
331
1340
        compile_for_stack(ctx, emit, symbols, arg)?;
332
1340
        emit.else_block();
333
1340
        emit.i32_const(0);
334
1340
        emit.block_end();
335
    }
336
1206
    Ok(WasmType::I32)
337
1206
}
338

            
339
/// Compile OR for stack — returns I32 (boolean result).
340
pub(super) fn compile_or_for_stack(
341
    ctx: &mut CompileContext,
342
    emit: &mut FunctionEmitter,
343
    symbols: &mut SymbolTable,
344
    args: &[Expr],
345
) -> Result<WasmType> {
346
    if args.is_empty() {
347
        emit.i32_const(0);
348
        return Ok(WasmType::I32);
349
    }
350
    if args.len() == 1 {
351
        return compile_for_stack(ctx, emit, symbols, &args[0]);
352
    }
353
    // Compile first arg
354
    compile_for_stack(ctx, emit, symbols, &args[0])?;
355
    for arg in &args[1..] {
356
        // Short-circuit: if top is true (nonzero), keep it; otherwise try next
357
        emit.i32_eqz();
358
        emit.if_block(BlockType::Result(wasm_encoder::ValType::I32));
359
        compile_for_stack(ctx, emit, symbols, arg)?;
360
        emit.else_block();
361
        emit.i32_const(1);
362
        emit.block_end();
363
    }
364
    Ok(WasmType::I32)
365
}
366

            
367
134
pub(super) fn compile_cond_for_effect(
368
134
    ctx: &mut CompileContext,
369
134
    emit: &mut FunctionEmitter,
370
134
    symbols: &mut SymbolTable,
371
134
    args: &[Expr],
372
134
) -> Result<()> {
373
134
    for (i, clause) in args.iter().enumerate() {
374
134
        let elems = clause.as_list().ok_or_else(|| {
375
            Error::Compile(format!("COND: clause must be a list, got {clause:?}"))
376
        })?;
377
134
        if elems.is_empty() {
378
            return Err(Error::Compile("COND: empty clause".to_string()));
379
134
        }
380
134
        let test = eval_value(symbols, &elems[0])?;
381
134
        if test.is_wasm_runtime() {
382
134
            return compile_cond_runtime_for_effect(ctx, emit, symbols, &args[i..]);
383
        }
384
        if is_truthy(&test) {
385
            for expr in &elems[1..] {
386
                compile_for_effect(ctx, emit, symbols, expr)?;
387
            }
388
            return Ok(());
389
        }
390
    }
391
    Ok(())
392
134
}
393

            
394
134
fn compile_cond_runtime_for_effect(
395
134
    ctx: &mut CompileContext,
396
134
    emit: &mut FunctionEmitter,
397
134
    symbols: &mut SymbolTable,
398
134
    args: &[Expr],
399
134
) -> Result<()> {
400
134
    let depth = args.len();
401
402
    for (i, clause) in args.iter().enumerate() {
402
402
        let elems = clause.as_list().ok_or_else(|| {
403
            Error::Compile(format!("COND: clause must be a list, got {clause:?}"))
404
        })?;
405
402
        if elems.is_empty() {
406
            return Err(Error::Compile("COND: empty clause".to_string()));
407
402
        }
408
402
        let test = eval_value(symbols, &elems[0])?;
409
402
        let is_last = i == depth - 1;
410
402
        if test.is_wasm_runtime() {
411
402
            compile_for_stack(ctx, emit, symbols, &elems[0])?;
412
402
            emit.if_block(BlockType::Empty);
413
402
            for expr in &elems[1..] {
414
402
                compile_for_effect(ctx, emit, symbols, expr)?;
415
            }
416
402
            if !is_last {
417
268
                emit.else_block();
418
268
            }
419
        } else if is_truthy(&test) {
420
            for expr in &elems[1..] {
421
                compile_for_effect(ctx, emit, symbols, expr)?;
422
            }
423
            for _ in 0..i {
424
                emit.block_end();
425
            }
426
            return Ok(());
427
        }
428
    }
429
402
    for _ in 0..depth {
430
402
        emit.block_end();
431
402
    }
432
134
    Ok(())
433
134
}
434

            
435
pub(super) fn compile_and_for_effect(
436
    ctx: &mut CompileContext,
437
    emit: &mut FunctionEmitter,
438
    symbols: &mut SymbolTable,
439
    args: &[Expr],
440
) -> Result<()> {
441
    for (i, arg) in args.iter().enumerate() {
442
        let value = eval_value(symbols, arg)?;
443
        if value.is_wasm_runtime() {
444
            compile_for_stack(ctx, emit, symbols, arg)?;
445
            emit.if_block(BlockType::Empty);
446
            for remaining in &args[i + 1..] {
447
                compile_for_effect(ctx, emit, symbols, remaining)?;
448
            }
449
            emit.block_end();
450
            return Ok(());
451
        }
452
        if !is_truthy(&value) {
453
            return Ok(());
454
        }
455
    }
456
    Ok(())
457
}
458

            
459
pub(super) fn compile_or_for_effect(
460
    ctx: &mut CompileContext,
461
    emit: &mut FunctionEmitter,
462
    symbols: &mut SymbolTable,
463
    args: &[Expr],
464
) -> Result<()> {
465
    for (i, arg) in args.iter().enumerate() {
466
        let value = eval_value(symbols, arg)?;
467
        if value.is_wasm_runtime() {
468
            compile_for_stack(ctx, emit, symbols, arg)?;
469
            emit.i32_eqz();
470
            emit.if_block(BlockType::Empty);
471
            for remaining in &args[i + 1..] {
472
                compile_for_effect(ctx, emit, symbols, remaining)?;
473
            }
474
            emit.block_end();
475
            return Ok(());
476
        }
477
        if is_truthy(&value) {
478
            return Ok(());
479
        }
480
    }
481
    Ok(())
482
}
483

            
484
15840
pub(in crate::compiler) fn is_truthy(expr: &Expr) -> bool {
485
15840
    !matches!(expr, Expr::Nil | Expr::Bool(false))
486
15840
}
487

            
488
794
pub(super) fn if_form(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
489
794
    if args.len() < 2 || args.len() > 3 {
490
        return Err(Error::Compile(
491
            "IF requires a test, a then-form, and an optional else-form".to_string(),
492
        ));
493
794
    }
494
794
    let test = eval_value(symbols, &args[0])?;
495
794
    if matches!(test, Expr::WasmRuntime(_)) {
496
        // Runtime test — both branches may produce runtime values
497
134
        let then_ty = eval_value(symbols, &args[1])?;
498
134
        let else_ty = if args.len() == 3 {
499
134
            eval_value(symbols, &args[2])?
500
        } else {
501
            Expr::Nil
502
        };
503
134
        return match (&then_ty, &else_ty) {
504
134
            (Expr::WasmRuntime(t), _) | (_, Expr::WasmRuntime(t)) => Ok(Expr::WasmRuntime(*t)),
505
            _ => Ok(Expr::WasmRuntime(WasmType::I32)),
506
        };
507
660
    }
508
660
    if is_truthy(&test) {
509
660
        eval_value(symbols, &args[1])
510
    } else if args.len() == 3 {
511
        eval_value(symbols, &args[2])
512
    } else {
513
        Ok(Expr::Nil)
514
    }
515
794
}
516

            
517
178
pub(super) fn begin_form(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
518
178
    if args.is_empty() {
519
        return Ok(Expr::Nil);
520
178
    }
521
178
    super::binding::eval_body(symbols, args)
522
178
}
523

            
524
1742
pub(super) fn and_form(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
525
1742
    if args.is_empty() {
526
        return Ok(Expr::Bool(true));
527
1742
    }
528

            
529
1742
    let mut last = Expr::Bool(true);
530
1742
    for arg in args {
531
1742
        let value = eval_value(symbols, arg)?;
532
1742
        if matches!(value, Expr::WasmRuntime(_)) {
533
1742
            return Ok(Expr::WasmRuntime(WasmType::I32));
534
        }
535
        if !is_truthy(&value) {
536
            return Ok(value);
537
        }
538
        last = value;
539
    }
540
    Ok(last)
541
1742
}
542

            
543
pub(super) fn or_form(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
544
    if args.is_empty() {
545
        return Ok(Expr::Nil);
546
    }
547

            
548
    let mut last = Expr::Nil;
549
    for arg in args {
550
        let value = eval_value(symbols, arg)?;
551
        if matches!(value, Expr::WasmRuntime(_)) {
552
            return Ok(Expr::WasmRuntime(WasmType::I32));
553
        }
554
        if is_truthy(&value) {
555
            return Ok(value);
556
        }
557
        last = value;
558
    }
559
    Ok(last)
560
}
561

            
562
402
pub(super) fn cond_form(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
563
402
    for clause in args {
564
402
        let elems = clause.as_list().ok_or_else(|| {
565
            Error::Compile(format!("COND: clause must be a list, got {clause:?}"))
566
        })?;
567
402
        if elems.is_empty() {
568
            return Err(Error::Compile("COND: empty clause".to_string()));
569
402
        }
570
402
        let test = eval_value(symbols, &elems[0])?;
571
402
        if matches!(test, Expr::WasmRuntime(_)) {
572
402
            return Ok(Expr::WasmRuntime(WasmType::I32));
573
        }
574
        if is_truthy(&test) {
575
            if elems.len() == 1 {
576
                return Ok(test);
577
            }
578
            return super::binding::eval_body(symbols, &elems[1..]);
579
        }
580
    }
581
    Ok(Expr::Nil)
582
402
}
583

            
584
132
pub(super) fn quote(args: &[Expr]) -> Result<Expr> {
585
132
    if args.len() != 1 {
586
        return Err(Error::Arity {
587
            name: "quote".to_string(),
588
            expected: 1,
589
            actual: args.len(),
590
        });
591
132
    }
592
132
    Ok(args[0].clone())
593
132
}