1
use anodized::spec;
2

            
3
use super::super::context::CompileContext;
4
use super::super::emit::FunctionEmitter;
5
use super::super::expr::{
6
    compile_for_stack_ratio, compile_number, eval_value, format_expr, push_ratio,
7
    serialize_stack_to_output,
8
};
9
use crate::ast::{Expr, Fraction, WasmType};
10
use crate::error::{Error, Result};
11
use crate::runtime::SymbolTable;
12

            
13
6350
fn try_fold(symbols: &mut SymbolTable, args: &[Expr]) -> Option<Vec<Fraction>> {
14
6350
    args.iter()
15
9386
        .map(|arg| {
16
9386
            eval_value(symbols, arg).ok().and_then(|e| match e {
17
5456
                Expr::Number(n) => Some(n),
18
3842
                _ => None,
19
9298
            })
20
9386
        })
21
6350
        .collect()
22
6350
}
23

            
24
308
fn validate_ratio_args(resolved: &[Expr], name: &str) -> Result<()> {
25
616
    for r in resolved {
26
        match r {
27
308
            Expr::Number(_) => {}
28
308
            _ if r.wasm_type() == Some(WasmType::Ratio) => {}
29
            _ if r.wasm_type() == Some(WasmType::I32) => {
30
                return Err(Error::Compile(format!(
31
                    "{name} requires ratio values, not indices"
32
                )));
33
            }
34
            other => {
35
                return Err(Error::Compile(format!(
36
                    "{name} expects number arguments, got {}",
37
                    format_expr(other)
38
                )));
39
            }
40
        }
41
    }
42
308
    Ok(())
43
308
}
44

            
45
#[spec(ensures: [output.as_ref().map_or(true, |v| v.len() == args.len())])]
46
pub(super) fn resolve_numbers(
47
    symbols: &mut SymbolTable,
48
    args: &[Expr],
49
    name: &str,
50
) -> Result<Vec<Fraction>> {
51
    args.iter()
52
        .map(|arg| {
53
            let resolved = eval_value(symbols, arg)?;
54
            match resolved {
55
                Expr::Number(n) => Ok(n),
56
                other => Err(Error::Compile(format!(
57
                    "{name} expects number arguments, got {}",
58
                    format_expr(&other)
59
                ))),
60
            }
61
        })
62
        .collect()
63
}
64

            
65
// --- Eval functions (compile-time evaluation, WasmRuntime-aware) ---
66

            
67
25388
pub(super) fn add(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
68
25388
    let resolved: Vec<Expr> = args
69
25388
        .iter()
70
50556
        .map(|arg| eval_value(symbols, arg))
71
25388
        .collect::<Result<_>>()?;
72
25388
    if let Some(nums) = try_fold_resolved(&resolved) {
73
25212
        let sum = nums
74
25212
            .into_iter()
75
50204
            .fold(Fraction::from_integer(0), |a, b| a + b);
76
25212
        return Ok(Expr::Number(sum));
77
176
    }
78
176
    validate_ratio_args(&resolved, "+")?;
79
176
    Ok(Expr::WasmRuntime(WasmType::Ratio))
80
25388
}
81

            
82
1716
pub(super) fn sub(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
83
1716
    if args.is_empty() {
84
        return Err(Error::Compile("- requires at least 1 argument".to_string()));
85
1716
    }
86
1716
    let resolved: Vec<Expr> = args
87
1716
        .iter()
88
3432
        .map(|arg| eval_value(symbols, arg))
89
1716
        .collect::<Result<_>>()?;
90
1716
    if let Some(nums) = try_fold_resolved(&resolved) {
91
1716
        let result = if nums.len() == 1 {
92
            -nums[0]
93
        } else {
94
1716
            nums[1..].iter().fold(nums[0], |a, b| a - b)
95
        };
96
1716
        return Ok(Expr::Number(result));
97
    }
98
    validate_ratio_args(&resolved, "-")?;
99
    Ok(Expr::WasmRuntime(WasmType::Ratio))
100
1716
}
101

            
102
792
pub(super) fn mul(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
103
792
    let resolved: Vec<Expr> = args
104
792
        .iter()
105
1584
        .map(|arg| eval_value(symbols, arg))
106
792
        .collect::<Result<_>>()?;
107
792
    if let Some(nums) = try_fold_resolved(&resolved) {
108
660
        let product = nums
109
660
            .into_iter()
110
1320
            .fold(Fraction::from_integer(1), |a, b| a * b);
111
660
        return Ok(Expr::Number(product));
112
132
    }
113
132
    validate_ratio_args(&resolved, "*")?;
114
132
    Ok(Expr::WasmRuntime(WasmType::Ratio))
115
792
}
116

            
117
220
pub(super) fn div(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
118
220
    if args.is_empty() {
119
        return Err(Error::Compile("/ requires at least 1 argument".to_string()));
120
220
    }
121
220
    let resolved: Vec<Expr> = args
122
220
        .iter()
123
440
        .map(|arg| eval_value(symbols, arg))
124
220
        .collect::<Result<_>>()?;
125
220
    if let Some(nums) = try_fold_resolved(&resolved) {
126
220
        let result = if nums.len() == 1 {
127
            if *nums[0].numer() == 0 {
128
                return Err(Error::Compile("division by zero".to_string()));
129
            }
130
            nums[0].recip()
131
        } else {
132
220
            let mut acc = nums[0];
133
220
            for &n in &nums[1..] {
134
220
                if *n.numer() == 0 {
135
                    return Err(Error::Compile("division by zero".to_string()));
136
220
                }
137
220
                acc /= n;
138
            }
139
220
            acc
140
        };
141
220
        return Ok(Expr::Number(result));
142
    }
143
    validate_ratio_args(&resolved, "/")?;
144
    Ok(Expr::WasmRuntime(WasmType::Ratio))
145
220
}
146

            
147
pub(super) fn modulo(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
148
    if args.len() != 2 {
149
        return Err(Error::Arity {
150
            name: "MOD".to_string(),
151
            expected: 2,
152
            actual: args.len(),
153
        });
154
    }
155
    let resolved: Vec<Expr> = args
156
        .iter()
157
        .map(|arg| eval_value(symbols, arg))
158
        .collect::<Result<_>>()?;
159
    if let Some(nums) = try_fold_resolved(&resolved) {
160
        if *nums[1].numer() == 0 {
161
            return Err(Error::Compile("division by zero in MOD".to_string()));
162
        }
163
        let quotient = (nums[0] / nums[1]).floor();
164
        let result = nums[0] - nums[1] * quotient;
165
        return Ok(Expr::Number(result));
166
    }
167
    validate_ratio_args(&resolved, "MOD")?;
168
    Ok(Expr::WasmRuntime(WasmType::Ratio))
169
}
170

            
171
28116
fn try_fold_resolved(resolved: &[Expr]) -> Option<Vec<Fraction>> {
172
28116
    resolved
173
28116
        .iter()
174
55704
        .map(|e| match e {
175
55396
            Expr::Number(n) => Some(*n),
176
308
            _ => None,
177
55704
        })
178
28116
        .collect()
179
28116
}
180

            
181
// --- Compile functions (serializing, for top-level output) ---
182

            
183
1892
pub(super) fn compile_add(
184
1892
    ctx: &mut CompileContext,
185
1892
    emit: &mut FunctionEmitter,
186
1892
    symbols: &mut SymbolTable,
187
1892
    args: &[Expr],
188
1892
) -> Result<()> {
189
1892
    if let Some(nums) = try_fold(symbols, args) {
190
1452
        let sum = nums
191
1452
            .into_iter()
192
3300
            .fold(Fraction::from_integer(0), |a, b| a + b);
193
1452
        compile_number(ctx, emit, *sum.numer(), *sum.denom());
194
1452
        return Ok(());
195
440
    }
196
440
    let ty = compile_add_to_stack(ctx, emit, symbols, args)?;
197
396
    serialize_stack_to_output(ctx, emit, ty);
198
396
    Ok(())
199
1892
}
200

            
201
528
pub(super) fn compile_sub(
202
528
    ctx: &mut CompileContext,
203
528
    emit: &mut FunctionEmitter,
204
528
    symbols: &mut SymbolTable,
205
528
    args: &[Expr],
206
528
) -> Result<()> {
207
528
    if args.is_empty() {
208
44
        return Err(Error::Compile("- requires at least 1 argument".to_string()));
209
484
    }
210
484
    if let Some(nums) = try_fold(symbols, args) {
211
220
        let result = if nums.len() == 1 {
212
44
            -nums[0]
213
        } else {
214
308
            nums[1..].iter().fold(nums[0], |a, b| a - b)
215
        };
216
220
        compile_number(ctx, emit, *result.numer(), *result.denom());
217
220
        return Ok(());
218
264
    }
219
264
    let ty = compile_sub_to_stack(ctx, emit, symbols, args)?;
220
264
    serialize_stack_to_output(ctx, emit, ty);
221
264
    Ok(())
222
528
}
223

            
224
440
pub(super) fn compile_mul(
225
440
    ctx: &mut CompileContext,
226
440
    emit: &mut FunctionEmitter,
227
440
    symbols: &mut SymbolTable,
228
440
    args: &[Expr],
229
440
) -> Result<()> {
230
440
    if let Some(nums) = try_fold(symbols, args) {
231
352
        let product = nums
232
352
            .into_iter()
233
660
            .fold(Fraction::from_integer(1), |a, b| a * b);
234
352
        compile_number(ctx, emit, *product.numer(), *product.denom());
235
352
        return Ok(());
236
88
    }
237
88
    let ty = compile_mul_to_stack(ctx, emit, symbols, args)?;
238
44
    serialize_stack_to_output(ctx, emit, ty);
239
44
    Ok(())
240
440
}
241

            
242
308
pub(super) fn compile_div(
243
308
    ctx: &mut CompileContext,
244
308
    emit: &mut FunctionEmitter,
245
308
    symbols: &mut SymbolTable,
246
308
    args: &[Expr],
247
308
) -> Result<()> {
248
308
    if args.is_empty() {
249
44
        return Err(Error::Compile("/ requires at least 1 argument".to_string()));
250
264
    }
251
264
    if let Some(nums) = try_fold(symbols, args) {
252
220
        let result = if nums.len() == 1 {
253
88
            if *nums[0].numer() == 0 {
254
44
                return Err(Error::Compile("division by zero".to_string()));
255
44
            }
256
44
            nums[0].recip()
257
        } else {
258
132
            let mut acc = nums[0];
259
220
            for &n in &nums[1..] {
260
220
                if *n.numer() == 0 {
261
44
                    return Err(Error::Compile("division by zero".to_string()));
262
176
                }
263
176
                acc /= n;
264
            }
265
88
            acc
266
        };
267
132
        compile_number(ctx, emit, *result.numer(), *result.denom());
268
132
        return Ok(());
269
44
    }
270
44
    let ty = compile_div_to_stack(ctx, emit, symbols, args)?;
271
44
    serialize_stack_to_output(ctx, emit, ty);
272
44
    Ok(())
273
308
}
274

            
275
176
pub(super) fn compile_mod(
276
176
    ctx: &mut CompileContext,
277
176
    emit: &mut FunctionEmitter,
278
176
    symbols: &mut SymbolTable,
279
176
    args: &[Expr],
280
176
) -> Result<()> {
281
176
    if args.len() != 2 {
282
88
        return Err(Error::Arity {
283
88
            name: "MOD".to_string(),
284
88
            expected: 2,
285
88
            actual: args.len(),
286
88
        });
287
88
    }
288
88
    if let Some(nums) = try_fold(symbols, args) {
289
88
        if *nums[1].numer() == 0 {
290
            return Err(Error::Compile("division by zero in MOD".to_string()));
291
88
        }
292
88
        let quotient = (nums[0] / nums[1]).floor();
293
88
        let result = nums[0] - nums[1] * quotient;
294
88
        compile_number(ctx, emit, *result.numer(), *result.denom());
295
88
        return Ok(());
296
    }
297
    let ty = compile_mod_to_stack(ctx, emit, symbols, args)?;
298
    serialize_stack_to_output(ctx, emit, ty);
299
    Ok(())
300
176
}
301

            
302
// --- Stack compile functions (for sub-expressions) ---
303

            
304
2566
pub(super) fn compile_add_to_stack(
305
2566
    ctx: &mut CompileContext,
306
2566
    emit: &mut FunctionEmitter,
307
2566
    symbols: &mut SymbolTable,
308
2566
    args: &[Expr],
309
2566
) -> Result<WasmType> {
310
2566
    if let Some(nums) = try_fold(symbols, args) {
311
        let sum = nums
312
            .into_iter()
313
            .fold(Fraction::from_integer(0), |a, b| a + b);
314
        push_ratio(ctx, emit, *sum.numer(), *sum.denom());
315
        return Ok(WasmType::Ratio);
316
2566
    }
317
2566
    if args.is_empty() {
318
        push_ratio(ctx, emit, 0, 1);
319
        return Ok(WasmType::Ratio);
320
2566
    }
321
2566
    compile_for_stack_ratio(ctx, emit, symbols, &args[0])?;
322
2522
    for arg in &args[1..] {
323
2522
        compile_for_stack_ratio(ctx, emit, symbols, arg)?;
324
2522
        emit.call(ctx.func("ratio_add"));
325
    }
326
2522
    Ok(WasmType::Ratio)
327
2566
}
328

            
329
308
pub(super) fn compile_sub_to_stack(
330
308
    ctx: &mut CompileContext,
331
308
    emit: &mut FunctionEmitter,
332
308
    symbols: &mut SymbolTable,
333
308
    args: &[Expr],
334
308
) -> Result<WasmType> {
335
308
    if args.is_empty() {
336
        return Err(Error::Compile("- requires at least 1 argument".to_string()));
337
308
    }
338
308
    if let Some(nums) = try_fold(symbols, args) {
339
        let result = if nums.len() == 1 {
340
            -nums[0]
341
        } else {
342
            nums[1..].iter().fold(nums[0], |a, b| a - b)
343
        };
344
        push_ratio(ctx, emit, *result.numer(), *result.denom());
345
        return Ok(WasmType::Ratio);
346
308
    }
347
308
    if args.len() == 1 {
348
        push_ratio(ctx, emit, 0, 1);
349
        compile_for_stack_ratio(ctx, emit, symbols, &args[0])?;
350
        emit.call(ctx.func("ratio_sub"));
351
    } else {
352
308
        compile_for_stack_ratio(ctx, emit, symbols, &args[0])?;
353
308
        for arg in &args[1..] {
354
308
            compile_for_stack_ratio(ctx, emit, symbols, arg)?;
355
308
            emit.call(ctx.func("ratio_sub"));
356
        }
357
    }
358
308
    Ok(WasmType::Ratio)
359
308
}
360

            
361
176
pub(super) fn compile_mul_to_stack(
362
176
    ctx: &mut CompileContext,
363
176
    emit: &mut FunctionEmitter,
364
176
    symbols: &mut SymbolTable,
365
176
    args: &[Expr],
366
176
) -> Result<WasmType> {
367
176
    if let Some(nums) = try_fold(symbols, args) {
368
        let product = nums
369
            .into_iter()
370
            .fold(Fraction::from_integer(1), |a, b| a * b);
371
        push_ratio(ctx, emit, *product.numer(), *product.denom());
372
        return Ok(WasmType::Ratio);
373
176
    }
374
176
    if args.is_empty() {
375
        push_ratio(ctx, emit, 1, 1);
376
        return Ok(WasmType::Ratio);
377
176
    }
378
176
    compile_for_stack_ratio(ctx, emit, symbols, &args[0])?;
379
220
    for arg in &args[1..] {
380
220
        compile_for_stack_ratio(ctx, emit, symbols, arg)?;
381
176
        emit.call(ctx.func("ratio_mul"));
382
    }
383
132
    Ok(WasmType::Ratio)
384
176
}
385

            
386
132
pub(super) fn compile_div_to_stack(
387
132
    ctx: &mut CompileContext,
388
132
    emit: &mut FunctionEmitter,
389
132
    symbols: &mut SymbolTable,
390
132
    args: &[Expr],
391
132
) -> Result<WasmType> {
392
132
    if args.is_empty() {
393
        return Err(Error::Compile("/ requires at least 1 argument".to_string()));
394
132
    }
395
132
    if let Some(nums) = try_fold(symbols, args) {
396
88
        let result = if nums.len() == 1 {
397
            if *nums[0].numer() == 0 {
398
                return Err(Error::Compile("division by zero".to_string()));
399
            }
400
            nums[0].recip()
401
        } else {
402
88
            let mut acc = nums[0];
403
88
            for &n in &nums[1..] {
404
88
                if *n.numer() == 0 {
405
                    return Err(Error::Compile("division by zero".to_string()));
406
88
                }
407
88
                acc /= n;
408
            }
409
88
            acc
410
        };
411
88
        push_ratio(ctx, emit, *result.numer(), *result.denom());
412
88
        return Ok(WasmType::Ratio);
413
44
    }
414
44
    if args.len() == 1 {
415
        push_ratio(ctx, emit, 1, 1);
416
        compile_for_stack_ratio(ctx, emit, symbols, &args[0])?;
417
        emit.call(ctx.func("ratio_div"));
418
    } else {
419
44
        compile_for_stack_ratio(ctx, emit, symbols, &args[0])?;
420
44
        for arg in &args[1..] {
421
44
            compile_for_stack_ratio(ctx, emit, symbols, arg)?;
422
44
            emit.call(ctx.func("ratio_div"));
423
        }
424
    }
425
44
    Ok(WasmType::Ratio)
426
132
}
427

            
428
pub(super) fn compile_mod_to_stack(
429
    ctx: &mut CompileContext,
430
    emit: &mut FunctionEmitter,
431
    symbols: &mut SymbolTable,
432
    args: &[Expr],
433
) -> Result<WasmType> {
434
    if args.len() != 2 {
435
        return Err(Error::Arity {
436
            name: "MOD".to_string(),
437
            expected: 2,
438
            actual: args.len(),
439
        });
440
    }
441
    if let Some(nums) = try_fold(symbols, args) {
442
        if *nums[1].numer() == 0 {
443
            return Err(Error::Compile("division by zero in MOD".to_string()));
444
        }
445
        let quotient = (nums[0] / nums[1]).floor();
446
        let result = nums[0] - nums[1] * quotient;
447
        push_ratio(ctx, emit, *result.numer(), *result.denom());
448
        return Ok(WasmType::Ratio);
449
    }
450
    // MOD runtime: a - b * floor(a/b)
451
    // For simplicity, use: a/b -> floor -> *b -> a - result
452
    // But we don't have floor for ratios yet.
453
    // For now, emit the components using ratio arithmetic
454
    compile_for_stack_ratio(ctx, emit, symbols, &args[0])?;
455
    compile_for_stack_ratio(ctx, emit, symbols, &args[1])?;
456
    // We need: a - b * floor(a / b)
457
    // This requires a floor function we don't have yet for runtime.
458
    // For Phase 0, MOD with runtime args is not supported.
459
    Err(Error::Compile(
460
        "MOD with runtime arguments is not yet supported".to_string(),
461
    ))
462
}