rune/
arith.rs

1//! Arithmetic operators.
2use crate::core::object::{Gc, IntoObject, Number, NumberType, ObjectType};
3use float_cmp::ApproxEq;
4use num_bigint::BigInt;
5use num_traits::{FromPrimitive, ToPrimitive, Zero};
6use rune_macros::defun;
7use std::cmp::PartialEq;
8use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
9
/// Largest integer kept as a fixnum: `i64::MAX` narrowed by 8 bits, i.e.
/// 2^55 - 1. (Presumably the 8 bits hold the object tag — confirm.)
pub(crate) const MAX_FIXNUM: i64 = i64::MAX >> 8;
/// Smallest integer kept as a fixnum: `i64::MIN` narrowed by 8 bits, i.e.
/// -2^55 (`>>` on a signed integer is an arithmetic shift, so the sign stays).
pub(crate) const MIN_FIXNUM: i64 = i64::MIN >> 8;
12
13/// Similar to the object type [NumberType], but contains a float instead of a
14/// reference to a float. This makes it easier to construct and mutate.
15#[derive(Debug, PartialEq, Clone)]
16pub(crate) enum NumberValue {
17    Int(i64),
18    Float(f64),
19    Big(BigInt),
20}
21
22impl Number<'_> {
23    pub(crate) fn val(self) -> NumberValue {
24        match self.untag() {
25            NumberType::Int(x) => NumberValue::Int(x),
26            NumberType::Float(x) => NumberValue::Float(**x),
27            NumberType::Big(x) => NumberValue::Big((**x).clone()),
28        }
29    }
30}
31
32impl IntoObject for NumberValue {
33    type Out<'ob> = ObjectType<'ob>;
34
35    fn into_obj<const C: bool>(self, block: &crate::core::gc::Block<C>) -> Gc<Self::Out<'_>> {
36        match self {
37            NumberValue::Int(x) => x.into(),
38            NumberValue::Float(x) => block.add(x),
39            NumberValue::Big(x) => block.add(x),
40        }
41    }
42}
43
44impl NumberValue {
45    pub fn coerce_integer(self) -> NumberValue {
46        match self {
47            NumberValue::Float(x) => {
48                if x.is_finite() && x <= MAX_FIXNUM as f64 && x >= MIN_FIXNUM as f64 {
49                    NumberValue::Int(x as i64)
50                } else {
51                    NumberValue::Big(BigInt::from_f64(x).unwrap_or_else(BigInt::zero))
52                }
53            }
54            NumberValue::Big(x) => x
55                .to_i64()
56                .filter(|&n| (MIN_FIXNUM..=MAX_FIXNUM).contains(&n))
57                .map(NumberValue::Int)
58                .unwrap_or_else(|| NumberValue::Big(x)),
59            other => other,
60        }
61    }
62}
63
64pub(crate) fn arith(
65    cur: NumberValue,
66    next: NumberValue,
67    int_fn: fn(i64, i64) -> i64,
68    float_fn: fn(f64, f64) -> f64,
69    big_fn: fn(BigInt, BigInt) -> BigInt,
70) -> NumberValue {
71    use NumberValue as N;
72    match (cur, next) {
73        (N::Int(l), N::Int(r)) => N::Int(int_fn(l, r)),
74        (N::Int(l), N::Float(r)) => N::Float(float_fn(l as f64, r)),
75        (N::Float(l), N::Int(r)) => N::Float(float_fn(l, r as f64)),
76        (N::Float(l), N::Float(r)) => N::Float(float_fn(l, r)),
77        (N::Int(l), N::Big(r)) => N::Big(big_fn(l.into(), r)),
78        (N::Big(l), N::Int(r)) => N::Big(big_fn(l, r.into())),
79        (N::Big(l), N::Big(r)) => N::Big(big_fn(l, r)),
80        (N::Float(l), N::Big(r)) => N::Float(float_fn(l, r.to_f64().unwrap())), // TODO: Should round to nearest float on error
81        (N::Big(l), N::Float(r)) => N::Float(float_fn(l.to_f64().unwrap(), r)), // TODO: Should round to nearest float on error
82    }
83}
84
85//////////////////////////
86// Arithmetic operators //
87//////////////////////////
88
89impl Zero for NumberValue {
90    fn zero() -> Self {
91        NumberValue::Int(0)
92    }
93    fn is_zero(&self) -> bool {
94        match self {
95            NumberValue::Int(x) => *x == 0,
96            NumberValue::Float(x) => *x == 0.0,
97            NumberValue::Big(x) => x.is_zero(),
98        }
99    }
100}
101
102impl Neg for NumberValue {
103    type Output = Self;
104    fn neg(self) -> Self::Output {
105        match self {
106            NumberValue::Int(x) => NumberValue::Int(-x),
107            NumberValue::Float(x) => NumberValue::Float(-x),
108            NumberValue::Big(x) => NumberValue::Big(-x),
109        }
110    }
111}
112
113impl Add for NumberValue {
114    type Output = Self;
115    fn add(self, rhs: Self) -> Self::Output {
116        arith(self, rhs, Add::add, Add::add, Add::add)
117    }
118}
119
120impl Sub for NumberValue {
121    type Output = Self;
122    fn sub(self, rhs: Self) -> Self::Output {
123        arith(self, rhs, Sub::sub, Sub::sub, Sub::sub)
124    }
125}
126
127impl Mul for NumberValue {
128    type Output = Self;
129    fn mul(self, rhs: Self) -> Self::Output {
130        arith(self, rhs, Mul::mul, Mul::mul, Mul::mul)
131    }
132}
133
134impl Div for NumberValue {
135    type Output = Self;
136    fn div(self, rhs: Self) -> Self::Output {
137        arith(self, rhs, Div::div, Div::div, Div::div)
138    }
139}
140
141impl Rem for NumberValue {
142    type Output = Self;
143    fn rem(self, rhs: Self) -> Self::Output {
144        arith(self, rhs, Rem::rem, Rem::rem, Rem::rem)
145    }
146}
147
148impl PartialEq<i64> for Number<'_> {
149    fn eq(&self, other: &i64) -> bool {
150        match self.val() {
151            NumberValue::Int(num) => num == *other,
152            NumberValue::Float(num) => num == *other as f64,
153            NumberValue::Big(num) => num == BigInt::from(*other),
154        }
155    }
156}
157
158impl PartialEq<f64> for Number<'_> {
159    fn eq(&self, other: &f64) -> bool {
160        match self.val() {
161            NumberValue::Int(num) => num as f64 == *other,
162            NumberValue::Float(num) => num.approx_eq(*other, (f64::EPSILON, 2)),
163            NumberValue::Big(num) => {
164                num.to_f64().is_some_and(|n| n.approx_eq(*other, (f64::EPSILON, 2)))
165            } // TODO: Check behavior when conversion fails
166        }
167    }
168}
169
170impl PartialEq<BigInt> for Number<'_> {
171    fn eq(&self, other: &BigInt) -> bool {
172        match self.val() {
173            NumberValue::Int(num) => BigInt::from(num) == *other,
174            NumberValue::Float(num) => {
175                other.to_f64().is_some_and(|n| n.approx_eq(num, (f64::EPSILON, 2)))
176            } // TODO: Check
177            NumberValue::Big(num) => num == *other,
178        }
179    }
180}
181
182impl PartialOrd for NumberValue {
183    fn partial_cmp(&self, other: &NumberValue) -> Option<std::cmp::Ordering> {
184        match self {
185            NumberValue::Int(lhs) => match other {
186                NumberValue::Int(rhs) => lhs.partial_cmp(rhs),
187                NumberValue::Float(rhs) => (*lhs as f64).partial_cmp(rhs),
188                NumberValue::Big(rhs) => BigInt::from(*lhs).partial_cmp(rhs),
189            },
190            NumberValue::Float(lhs) => match other {
191                NumberValue::Int(rhs) => lhs.partial_cmp(&(*rhs as f64)),
192                NumberValue::Float(rhs) => lhs.partial_cmp(rhs),
193                NumberValue::Big(rhs) => {
194                    lhs.partial_cmp(&rhs.to_f64().unwrap_or(f64::NAN)) // TODO: Handle conversion failure
195                }
196            },
197            NumberValue::Big(lhs) => match other {
198                NumberValue::Int(rhs) => lhs.partial_cmp(&BigInt::from(*rhs)),
199                NumberValue::Float(rhs) => lhs.to_f64().and_then(|n| n.partial_cmp(rhs)),
200                NumberValue::Big(rhs) => lhs.partial_cmp(rhs),
201            },
202        }
203    }
204}
205
206#[defun(name = "+")]
207pub(crate) fn add(vars: &[Number]) -> NumberValue {
208    vars.iter().fold(NumberValue::Int(0), |acc, x| acc + x.val())
209}
210
211#[defun(name = "-")]
212pub(crate) fn sub(number: Option<Number>, numbers: &[Number]) -> NumberValue {
213    match number {
214        Some(num) => {
215            let num = num.val();
216            if numbers.is_empty() {
217                -num
218            } else {
219                numbers.iter().fold(num, |acc, x| acc - x.val())
220            }
221        }
222        None => NumberValue::Int(0),
223    }
224}
225
226#[defun(name = "*")]
227pub(crate) fn mul(numbers: &[Number]) -> NumberValue {
228    numbers.iter().fold(NumberValue::Int(1), |acc, x| acc * x.val())
229}
230
231#[defun(name = "/")]
232pub(crate) fn div(number: Number, divisors: &[Number]) -> NumberValue {
233    divisors.iter().fold(number.val(), |acc, x| acc / x.val())
234}
235
236#[defun(name = "1+")]
237pub(crate) fn add_one(number: Number) -> NumberValue {
238    number.val() + NumberValue::Int(1)
239}
240
241#[defun(name = "1-")]
242pub(crate) fn sub_one(number: Number) -> NumberValue {
243    number.val() - NumberValue::Int(1)
244}
245
246#[defun(name = "=")]
247pub(crate) fn num_eq(number: Number, numbers: &[Number]) -> bool {
248    match number.val() {
249        NumberValue::Int(num) => numbers.iter().all(|&x| x == num),
250        NumberValue::Float(num) => numbers.iter().all(|&x| x == num),
251        NumberValue::Big(num) => numbers.iter().all(|&x| x == num),
252    }
253}
254
255#[defun(name = "/=")]
256pub(crate) fn num_ne(number: Number, numbers: &[Number]) -> bool {
257    match number.val() {
258        NumberValue::Int(num) => numbers.iter().all(|&x| x != num),
259        NumberValue::Float(num) => numbers.iter().all(|&x| x != num),
260        NumberValue::Big(num) => numbers.iter().all(|&x| x != num),
261    }
262}
263
264fn cmp(number: Number, numbers: &[Number], cmp: fn(&NumberValue, &NumberValue) -> bool) -> bool {
265    numbers
266        .iter()
267        .try_fold(number.val(), |acc, &x| cmp(&acc, &x.val()).then_some(NumberValue::Int(0)))
268        .is_some()
269}
270
271#[defun(name = "<")]
272pub(crate) fn less_than(number: Number, numbers: &[Number]) -> bool {
273    cmp(number, numbers, NumberValue::lt)
274}
275
276#[defun(name = "<=")]
277pub(crate) fn less_than_or_eq(number: Number, numbers: &[Number]) -> bool {
278    cmp(number, numbers, NumberValue::le)
279}
280
281#[defun(name = ">")]
282pub(crate) fn greater_than(number: Number, numbers: &[Number]) -> bool {
283    cmp(number, numbers, NumberValue::gt)
284}
285
286#[defun(name = ">=")]
287pub(crate) fn greater_than_or_eq(number: Number, numbers: &[Number]) -> bool {
288    cmp(number, numbers, NumberValue::ge)
289}
290
291#[defun]
292pub(crate) fn logior(ints_or_markers: &[Gc<i64>]) -> i64 {
293    ints_or_markers.iter().fold(0, |acc, x| acc | x.untag())
294}
295
296#[defun]
297fn logand(int_or_markers: &[Gc<i64>]) -> i64 {
298    int_or_markers.iter().fold(-1, |accum, x| accum & x.untag())
299}
300
301#[defun(name = "mod")]
302pub(crate) fn modulo(x: Number, y: Number) -> NumberValue {
303    x.val() % y.val()
304}
305
306#[defun(name = "%")]
307pub(crate) fn remainder(x: i64, y: i64) -> i64 {
308    // TODO: Handle markers
309    x % y
310}
311
312#[expect(clippy::trivially_copy_pass_by_ref)]
313fn max_val(x: NumberValue, y: &Number) -> NumberValue {
314    let y = y.val();
315    if x > y { x } else { y }
316}
317
318#[expect(clippy::trivially_copy_pass_by_ref)]
319fn min_val(x: NumberValue, y: &Number) -> NumberValue {
320    let y = y.val();
321    if x < y { x } else { y }
322}
323
324#[defun]
325pub(crate) fn max(number_or_marker: Number, number_or_markers: &[Number]) -> NumberValue {
326    number_or_markers.iter().fold(number_or_marker.val(), max_val)
327}
328
329#[defun]
330pub(crate) fn min(number_or_marker: Number, number_or_markers: &[Number]) -> NumberValue {
331    number_or_markers.iter().fold(number_or_marker.val(), min_val)
332}
333
// Unit tests for the builtins in this file. Integer arguments come from
// `.into()` on literals. Float arguments are created with `cx.add_as`, which is
// why most tests build a `RootSet` and a `Context` first.
#[cfg(test)]
mod test {
    use super::*;
    use crate::core::gc::{Context, RootSet};

    #[test]
    fn test_add() {
        let roots = &RootSet::default();
        let cx = &Context::new(roots);
        // The empty sum is 0.
        assert_eq!(add(&[]), NumberValue::Int(0));
        assert_eq!(add(&[7.into(), 13.into()]), NumberValue::Int(20));
        // An int plus a float gives a float.
        assert_eq!(add(&[1.into(), cx.add_as(2.5)]), NumberValue::Float(3.5));
        assert_eq!(add(&[0.into(), (-1).into()]), NumberValue::Int(-1));
    }

    #[test]
    fn test_sub() {
        // No arguments: 0. One argument: negation. Otherwise: left-to-right subtraction.
        assert_eq!(sub(None, &[]), NumberValue::Int(0));
        assert_eq!(sub(Some(7.into()), &[]), NumberValue::Int(-7));
        assert_eq!(sub(Some(7.into()), &[13.into()]), NumberValue::Int(-6));
        assert_eq!(sub(Some(0.into()), &[(-1).into()]), NumberValue::Int(1));
    }

    #[test]
    fn test_mul() {
        // The empty product is 1.
        assert_eq!(mul(&[]), NumberValue::Int(1));
        assert_eq!(mul(&[7.into(), 13.into()]), NumberValue::Int(91));
        assert_eq!(mul(&[(-1).into(), 1.into()]), NumberValue::Int(-1));
    }

    #[test]
    fn test_div() {
        let roots = &RootSet::default();
        let cx = &Context::new(roots);

        // NOTE(review): Emacs evaluates `(/ 12.0)` as 1/12.0. This assertion pins
        // the current behavior of returning the lone argument unchanged — confirm.
        assert_eq!(div(cx.add_as(12.0), &[]), NumberValue::Float(12.0));
        // Integer division truncates at each step: 12 / 5 = 2, then 2 / 2 = 1.
        assert_eq!(div(12.into(), &[5.into(), 2.into()]), NumberValue::Int(1));
    }

    #[test]
    fn test_eq() {
        let roots = &RootSet::default();
        let cx = &Context::new(roots);
        let int1 = 1.into();
        let float1 = cx.add_as(1.0);
        let float1_1 = cx.add_as(1.1);

        // A single argument is trivially equal to itself.
        assert!(num_eq(int1, &[]));
        // Ints and floats with the same value compare equal in both directions.
        assert!(num_eq(int1, &[cx.add_as(1.0)]));
        assert!(num_eq(float1, &[1.into()]));
        // One differing element makes the whole call false.
        assert!(!num_eq(float1, &[1.into(), 1.into(), float1_1]));
    }

    #[test]
    fn test_cmp() {
        let roots = &RootSet::default();
        let cx = &Context::new(roots);
        // A single argument is trivially ordered.
        assert!(less_than(1.into(), &[]));
        assert!(less_than(1.into(), &[cx.add_as(1.1)]));
        // Equal values are not strictly less.
        assert!(!less_than(cx.add_as(1.0), &[1.into()]));
        // NOTE(review): every chain here is increasing, so none of these cases
        // covers a chain that fails only at a later pair (e.g. `(< 1 5 3)`).
        assert!(less_than(cx.add_as(1.0), &[cx.add_as(1.1), 2.into(), cx.add_as(2.1)]));
    }

    #[test]
    fn test_max_min() {
        let roots = &RootSet::default();
        let cx = &Context::new(roots);
        // The expected values come from freshly allocated floats via `val()`.
        assert_eq!(
            max(cx.add_as(1.0), &[cx.add_as(2.1), cx.add_as(1.1), cx.add_as(1.0)]),
            cx.add_as(2.1).val()
        );
        assert_eq!(
            min(cx.add_as(1.1), &[cx.add_as(1.0), cx.add_as(2.1), cx.add_as(1.0)]),
            cx.add_as(1.0).val()
        );
    }

    #[test]
    fn test_other() {
        let roots = &RootSet::default();
        let cx = &Context::new(roots);
        // 258 = 0b1_0000_0010 and 255 = 0b1111_1111, so only bit 1 survives: 2.
        assert_eq!(logand(&[258.into_obj(cx), 255.into_obj(cx)]), 2);
    }
}