1
//! Constant-folding eval handlers for `+ - * / MOD` over the ADR-0028
2
//! lattice. An all-literal call folds to an `Expr::Number` (integer
3
//! operands fold with Index semantics — `/` truncates, `MOD` is `rem` —
4
//! so the eval surface agrees with the `i32.div_s` / `i32.rem_s` codegen);
5
//! otherwise the handler returns `Expr::WasmRuntime(result_dim)`. Real
6
//! cross-strata refusal happens at codegen time in `compile::*`.
7

            
8
use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub};
9

            
10
use crate::ast::{Expr, Fraction, WasmType};
11
use crate::compiler::expr::{eval_value, format_expr};
12
use crate::error::{Error, Result};
13
use crate::runtime::SymbolTable;
14

            
15
28653
pub(super) fn try_fold(symbols: &mut SymbolTable, args: &[Expr]) -> Option<Vec<Fraction>> {
16
    // Probe on a CLONE: this only decides whether the call const-folds; it must
17
    // not apply an operand's compile-time effects to the live table (the actual
18
    // emission re-walks the operands on the live path).
19
28653
    args.iter()
20
39398
        .map(|arg| {
21
39398
            eval_value(&mut symbols.clone(), arg)
22
39398
                .ok()
23
39398
                .and_then(|e| match e {
24
20198
                    Expr::Number(n) => Some(n),
25
18996
                    _ => None,
26
39194
                })
27
39398
        })
28
28653
        .collect()
29
28653
}
30

            
31
95284
pub(super) fn try_fold_resolved(resolved: &[Expr]) -> Option<Vec<Fraction>> {
32
95284
    resolved
33
95284
        .iter()
34
167908
        .map(|e| match e {
35
145384
            Expr::Number(n) => Some(*n),
36
22524
            _ => None,
37
167908
        })
38
95284
        .collect()
39
95284
}
40

            
41
5644
fn all_integers(nums: &[Fraction]) -> bool {
42
10812
    nums.iter().all(|n| *n.denom() == 1)
43
5644
}
44

            
45
/// The runtime dimension of a resolved operand, or `None` for a numeric literal
46
/// (a dimension-flexible token).
47
44640
fn operand_dim(r: &Expr) -> Result<Option<WasmType>> {
48
44640
    match r {
49
19872
        Expr::Number(_) => Ok(None),
50
24768
        _ => match r.wasm_type() {
51
24768
            Some(t @ (WasmType::I32 | WasmType::Ratio | WasmType::Commodity)) => Ok(Some(t)),
52
            _ => Err(Error::Compile(format!(
53
                "expected number arguments, got {}",
54
                format_expr(r)
55
            ))),
56
        },
57
    }
58
44640
}
59

            
60
/// Mirror codegen's refusal (`emit_literal` / `additive_dim`) to mix a runtime
61
/// Index with any Scalar — a runtime Ratio OR a fractional literal. Integer
62
/// literals stay flexible (they coerce to the Index). Without this the eval
63
/// surface would optimistically predict `I32` for `(+ IDX 1/2)` while codegen
64
/// rejects it — eval↔codegen drift. The only legal crossing is an explicit
65
/// `index->scalar` / `scalar->index` bridge.
66
22524
fn reject_index_scalar_mix(resolved: &[Expr]) -> Result<()> {
67
22524
    let has_runtime_index = resolved
68
22524
        .iter()
69
35376
        .any(|r| !matches!(r, Expr::Number(_)) && r.wasm_type() == Some(WasmType::I32));
70
22524
    if !has_runtime_index {
71
12852
        return Ok(());
72
9672
    }
73
19344
    let has_scalar = resolved.iter().any(|r| match r {
74
9536
        Expr::Number(n) => *n.denom() != 1,
75
9808
        other => other.wasm_type() == Some(WasmType::Ratio),
76
19344
    });
77
9672
    if has_scalar {
78
        return Err(Error::Compile(
79
            "cannot mix an index (count) with a scalar; bridge with \
80
             index->scalar / scalar->index"
81
                .to_string(),
82
        ));
83
9672
    }
84
9672
    Ok(())
85
22524
}
86

            
87
/// The dominant runtime dimension across operands: Money > Scalar > Index;
88
/// all-literal defaults to Scalar. Used for `+ - * MOD`, whose result is the
89
/// single shared dimension in the legal (non-mixed) case.
90
22116
fn dominant_result_type(resolved: &[Expr]) -> Result<WasmType> {
91
22116
    reject_index_scalar_mix(resolved)?;
92
22116
    let mut money = false;
93
22116
    let mut scalar = false;
94
22116
    let mut index = false;
95
44232
    for r in resolved {
96
44232
        match operand_dim(r)? {
97
544
            Some(WasmType::Commodity) => money = true,
98
14008
            Some(WasmType::Ratio) => scalar = true,
99
9808
            Some(WasmType::I32) => index = true,
100
19872
            _ => {}
101
        }
102
    }
103
22116
    Ok(if money {
104
272
        WasmType::Commodity
105
21844
    } else if scalar {
106
12172
        WasmType::Ratio
107
9672
    } else if index {
108
9672
        WasmType::I32
109
    } else {
110
        WasmType::Ratio
111
    })
112
22116
}
113

            
114
/// The result dimension of `/`, mirroring codegen's LEFT-associative fold
115
/// (`combine_div`): Index÷Index→Index, Scalar÷Scalar→Scalar, Money÷Scalar→Money,
116
/// Money÷Money→Scalar. A dominant-type shortcut would drift from codegen for
117
/// chains like `(/ money scalar money)` (→ Scalar, not Money).
118
408
fn div_result_type(resolved: &[Expr]) -> Result<WasmType> {
119
408
    reject_index_scalar_mix(resolved)?;
120
    // The result dimension is the LEFT operand's dimension, mirroring codegen's
121
    // left-associative `combine_div`: Index÷Index→Index, Scalar÷*→Scalar, and
122
    // (ADR-0028 E2) Money÷anything→Money — money ÷ money no longer collapses to
123
    // Ratio, it stays Money carrying a dimensionless/compound unit term, in
124
    // lockstep with `commodity_div`. A leading bare literal seeds Index only
125
    // when the dominant runtime dim is Index, else Scalar.
126
408
    let leading = operand_dim(&resolved[0])?;
127
    Ok(match leading {
128
408
        Some(t) => t,
129
        None if dominant_result_type(resolved)? == WasmType::I32 => WasmType::I32,
130
        None => WasmType::Ratio,
131
    })
132
408
}
133

            
134
/// The `Fraction` (`Ratio<i64>`) operators panic (debug / `overflow-checks`) or
135
/// wrap (release) when a cross-multiply exceeds i64 — reachable from any
136
/// all-literal `(* huge huge)` / `(+ a/b c/d)` whose reduced terms overflow. So
137
/// const-fold through the `Checked*` traits and surface a structured
138
/// `Error::Compile` instead, per CLAUDE.md (never a panic / SIGABRT on input).
139
340
fn overflow() -> Error {
140
340
    Error::Compile("arithmetic overflow in constant expression".to_string())
141
340
}
142

            
143
58617
pub(super) fn fold_add(nums: &[Fraction]) -> Result<Fraction> {
144
117710
    nums.iter().try_fold(Fraction::from_integer(0), |a, b| {
145
117710
        a.checked_add(b).ok_or_else(overflow)
146
117710
    })
147
58617
}
148

            
149
12852
pub(super) fn fold_sub(nums: &[Fraction]) -> Result<Fraction> {
150
12852
    if nums.len() == 1 {
151
204
        return Fraction::from_integer(0)
152
204
            .checked_sub(&nums[0])
153
204
            .ok_or_else(overflow);
154
12648
    }
155
12648
    nums[1..]
156
12648
        .iter()
157
12988
        .try_fold(nums[0], |a, b| a.checked_sub(b).ok_or_else(overflow))
158
12852
}
159

            
160
5032
pub(super) fn fold_mul(nums: &[Fraction]) -> Result<Fraction> {
161
9996
    nums.iter().try_fold(Fraction::from_integer(1), |a, b| {
162
9996
        a.checked_mul(b).ok_or_else(overflow)
163
9996
    })
164
5032
}
165

            
166
/// Const-fold `/`: all-integer operands divide as Index (truncating toward
167
/// zero, matching `i32.div_s`); any fractional operand divides rationally.
168
5304
pub(super) fn fold_div(nums: &[Fraction]) -> Result<Fraction> {
169
5304
    if nums.len() == 1 {
170
204
        if *nums[0].numer() == 0 {
171
68
            return Err(Error::Compile("division by zero".to_string()));
172
136
        }
173
136
        return Ok(if all_integers(nums) {
174
            // `1 / numer` can't overflow, but keep the lattice's truncating
175
            // Index semantics.
176
68
            Fraction::from_integer(1 / *nums[0].numer())
177
        } else {
178
            // `recip` swaps numer/denom — infallible for a non-zero ratio.
179
68
            nums[0].recip()
180
        });
181
5100
    }
182
5100
    if all_integers(nums) {
183
4692
        let mut acc = *nums[0].numer();
184
4828
        for n in &nums[1..] {
185
4828
            if *n.numer() == 0 {
186
136
                return Err(Error::Compile("division by zero".to_string()));
187
4692
            }
188
            // `i64::MIN / -1` overflows — `checked_div` catches it.
189
4692
            acc = acc.checked_div(*n.numer()).ok_or_else(overflow)?;
190
        }
191
4556
        return Ok(Fraction::from_integer(acc));
192
408
    }
193
408
    let mut acc = nums[0];
194
408
    for n in &nums[1..] {
195
408
        if *n.numer() == 0 {
196
            return Err(Error::Compile("division by zero".to_string()));
197
408
        }
198
        // Rational division cross-multiplies — can overflow i64.
199
408
        acc = acc.checked_div(n).ok_or_else(overflow)?;
200
    }
201
340
    Ok(acc)
202
5304
}
203

            
204
/// Const-fold `MOD`: all-integer operands use `rem` (sign of the dividend,
205
/// matching `i32.rem_s`); fractional operands use floored modulo.
206
408
pub(super) fn fold_mod(nums: &[Fraction]) -> Result<Fraction> {
207
408
    if *nums[1].numer() == 0 {
208
        return Err(Error::Compile("division by zero in MOD".to_string()));
209
408
    }
210
408
    if all_integers(nums) {
211
        // Raw `%` panics on `i64::MIN % -1` (overflow). wasm `i32.rem_s` defines
212
        // that case as 0 (it does NOT trap, unlike `i32.div_s`), so `checked_rem`
213
        // returning `None` there folds to 0 — keeping the eval surface in lockstep
214
        // with `i32.rem_s` codegen rather than panicking the compiler.
215
340
        let r = nums[0].numer().checked_rem(*nums[1].numer()).unwrap_or(0);
216
340
        return Ok(Fraction::from_integer(r));
217
68
    }
218
    // Floored modulo `a - b*floor(a/b)` cross-multiplies twice — guard both.
219
68
    let quotient = nums[0].checked_div(&nums[1]).ok_or_else(overflow)?.floor();
220
    let scaled = nums[1].checked_mul(&quotient).ok_or_else(overflow)?;
221
    nums[0].checked_sub(&scaled).ok_or_else(overflow)
222
408
}
223

            
224
98420
fn resolve_all(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Vec<Expr>> {
225
193500
    args.iter().map(|arg| eval_value(symbols, arg)).collect()
226
98420
}
227

            
228
67412
pub(super) fn add(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
229
67412
    let resolved = resolve_all(symbols, args)?;
230
64276
    if let Some(nums) = try_fold_resolved(&resolved) {
231
53516
        return Ok(Expr::Number(fold_add(&nums)?));
232
10760
    }
233
10760
    Ok(Expr::WasmRuntime(dominant_result_type(&resolved)?))
234
67412
}
235

            
236
20740
pub(super) fn sub(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
237
20740
    if args.is_empty() {
238
        return Err(Error::Compile("- requires at least 1 argument".to_string()));
239
20740
    }
240
20740
    let resolved = resolve_all(symbols, args)?;
241
20740
    if let Some(nums) = try_fold_resolved(&resolved) {
242
12308
        return Ok(Expr::Number(fold_sub(&nums)?));
243
8432
    }
244
8432
    Ok(Expr::WasmRuntime(dominant_result_type(&resolved)?))
245
20740
}
246

            
247
6052
pub(super) fn mul(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
248
6052
    let resolved = resolve_all(symbols, args)?;
249
6052
    if let Some(nums) = try_fold_resolved(&resolved) {
250
3128
        return Ok(Expr::Number(fold_mul(&nums)?));
251
2924
    }
252
2924
    Ok(Expr::WasmRuntime(dominant_result_type(&resolved)?))
253
6052
}
254

            
255
4148
pub(super) fn div(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
256
4148
    if args.is_empty() {
257
        return Err(Error::Compile("/ requires at least 1 argument".to_string()));
258
4148
    }
259
4148
    let resolved = resolve_all(symbols, args)?;
260
4148
    if let Some(nums) = try_fold_resolved(&resolved) {
261
3740
        return Ok(Expr::Number(fold_div(&nums)?));
262
408
    }
263
408
    Ok(Expr::WasmRuntime(div_result_type(&resolved)?))
264
4148
}
265

            
266
68
pub(super) fn modulo(symbols: &mut SymbolTable, args: &[Expr]) -> Result<Expr> {
267
68
    if args.len() != 2 {
268
        return Err(Error::Arity {
269
            name: "MOD".to_string(),
270
            expected: 2,
271
            actual: args.len(),
272
        });
273
68
    }
274
68
    let resolved = resolve_all(symbols, args)?;
275
68
    if let Some(nums) = try_fold_resolved(&resolved) {
276
68
        return Ok(Expr::Number(fold_mod(&nums)?));
277
    }
278
    Ok(Expr::WasmRuntime(dominant_result_type(&resolved)?))
279
68
}