diff --git a/src/benchmarks/bignum_benchmarks.nr b/src/benchmarks/bignum_benchmarks.nr index cf9fe2a5..6b7139a9 100644 --- a/src/benchmarks/bignum_benchmarks.nr +++ b/src/benchmarks/bignum_benchmarks.nr @@ -87,12 +87,12 @@ comptime fn make_bench(m: Module, params: Quoted) -> Quoted { #[export] fn $udiv_mod_bench_name(a: $typ, b: $typ) -> ($typ, $typ) { - $BigNum::udiv_mod(a, b) + $BigNum::udiv_mod(&a, &b) } #[export] fn $udiv_bench_name(a: $typ, b: $typ) -> $typ { - $BigNum::udiv(a, b) + $BigNum::udiv(&a, &b) } #[export] @@ -107,37 +107,37 @@ comptime fn make_bench(m: Module, params: Quoted) -> Quoted { #[export] fn $is_zero_bench_name(a: $typ) -> bool { - $BigNum::is_zero(a) + $BigNum::is_zero(&a) } #[export] fn $is_zero_integer_bench_name(a: $typ) -> bool { - $BigNum::is_zero_integer(a) + $BigNum::is_zero_integer(&a) } #[export] fn $assert_is_not_equal_bench_name(a: $typ, b: $typ) { - $BigNum::assert_is_not_equal(a, b) + $BigNum::assert_is_not_equal(&a, &b) } #[export] fn $assert_is_not_zero_bench_name(a: $typ) { - $BigNum::assert_is_not_zero(a) + $BigNum::assert_is_not_zero(&a) } #[export] fn $assert_is_not_zero_integer_bench_name(a: $typ) { - $BigNum::assert_is_not_zero_integer(a) + $BigNum::assert_is_not_zero_integer(&a) } #[export] fn $validate_in_range_bench_name(a: $typ) { - $BigNum::validate_in_range(a) + $BigNum::validate_in_range(&a) } #[export] fn $validate_in_field_bench_name(a: $typ) { - $BigNum::validate_in_field(a) + $BigNum::validate_in_field(&a) } #[export] @@ -152,22 +152,22 @@ comptime fn make_bench(m: Module, params: Quoted) -> Quoted { #[export] fn $from_be_bytes_bench_name(a: [u8; ($MOD_BITS+7) / 8]) -> [u128; $N] { - crate::fns::serialization::from_be_bytes(a) + crate::fns::serialization::from_be_bytes::<$N, $MOD_BITS>(&a) } #[export] fn $to_be_bytes_bench_name(a: $typ) -> [u8; ($MOD_BITS+7) / 8] { - crate::fns::serialization::to_be_bytes($BigNum::get_limbs(a)) + crate::fns::serialization::to_be_bytes::<$N, $MOD_BITS>(&$BigNum::get_limbs(&a)) } #[export] fn $from_le_bytes_bench_name(a: [u8; ($MOD_BITS+7) / 8]) -> [u128; $N] { - crate::fns::serialization::from_le_bytes(a) - } + crate::fns::serialization::from_le_bytes::<$N, $MOD_BITS>(&a) + } #[export] fn $to_le_bytes_bench_name(a: $typ) -> [u8; ($MOD_BITS+7) / 8] { - crate::fns::serialization::to_le_bytes($BigNum::get_limbs(a)) + crate::fns::serialization::to_le_bytes::<$N, $MOD_BITS>(&$BigNum::get_limbs(&a)) } #[export] @@ -208,13 +208,13 @@ comptime fn make_bench(m: Module, params: Quoted) -> Quoted { #[export] fn $pow_bench_name(a: $typ, b: $typ) -> $typ{ // Safety: Benchmarking - unsafe { $BigNum::__pow(a, b) } + unsafe { $BigNum::__pow(&a, &b) } } #[export] fn $sqrt_bench_name(a: $typ) -> Option<$typ> { // Safety: Benchmarking - unsafe { $BigNum::__sqrt(a) } + unsafe { $BigNum::__sqrt(&a) } } } } diff --git a/src/bignum.nr b/src/bignum.nr index 29729fdf..4c0ba121 100644 --- a/src/bignum.nr +++ b/src/bignum.nr @@ -93,7 +93,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// Returns the raw limb representation. /// /// Limbs are in little-endian order: `limbs[0]` contains the least significant 120 bits. - fn get_limbs(self) -> [u128; N]; + fn get_limbs(&self) -> [u128; N]; /// Sets a specific limb to a new value. /// @@ -103,7 +103,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { fn set_limb(self: &mut Self, idx: u32, value: u128); /// Returns the limb at the given index. - fn get_limb(self: Self, idx: u32) -> u128; + fn get_limb(&self, idx: u32) -> u128; /// Generates a deterministic pseudorandom BigNum from a seed. /// @@ -133,7 +133,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { fn from_be_bytes(x: [u8; (MOD_BITS + 7) / 8]) -> Self; /// Converts to big-endian byte representation. - fn to_be_bytes(self) -> [u8; (MOD_BITS + 7) / 8]; + fn to_be_bytes(&self) -> [u8; (MOD_BITS + 7) / 8]; /// Constructs a BigNum from little-endian bytes. /// @@ -144,7 +144,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { fn from_le_bytes(x: [u8; (MOD_BITS + 7) / 8]) -> Self; /// Converts to little-endian byte representation. - fn to_le_bytes(self) -> [u8; (MOD_BITS + 7) / 8]; + fn to_le_bytes(&self) -> [u8; (MOD_BITS + 7) / 8]; /// Unconstrained equality check. /// @@ -152,7 +152,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Safety /// No constraints are generated. Use `==` operator for constrained equality. - unconstrained fn __eq(self: Self, other: Self) -> bool; + unconstrained fn __eq(&self, other: &Self) -> bool; /// Unconstrained zero check. /// @@ -164,7 +164,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Safety /// No constraints are generated. Use `is_zero()` for constrained check. - unconstrained fn __is_zero(self: Self) -> bool; + unconstrained fn __is_zero(&self) -> bool; /// Unconstrained negation: computes `-self (mod MOD)`. /// @@ -173,31 +173,31 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Safety /// No constraints are generated. Constrain the result using `evaluate_quadratic_expression`. - unconstrained fn __neg(self) -> Self; + unconstrained fn __neg(&self) -> Self; /// Unconstrained addition: computes `self + other (mod MOD)`. /// /// # Safety /// No constraints are generated. Constrain the result using `evaluate_quadratic_expression`. - unconstrained fn __add(self, other: Self) -> Self; + unconstrained fn __add(&self, other: &Self) -> Self; /// Unconstrained subtraction: computes `self - other (mod MOD)`. /// /// # Safety /// No constraints are generated. Constrain the result using `evaluate_quadratic_expression`. - unconstrained fn __sub(self, other: Self) -> Self; + unconstrained fn __sub(&self, other: &Self) -> Self; /// Unconstrained multiplication: computes `self * other (mod MOD)`. /// /// # Safety /// No constraints are generated. Constrain the result using `evaluate_quadratic_expression`. - unconstrained fn __mul(self, other: Self) -> Self; + unconstrained fn __mul(&self, other: &Self) -> Self; /// Unconstrained squaring: computes `self * self (mod MOD)`. /// /// # Safety /// No constraints are generated. Constrain the result using `evaluate_quadratic_expression`. - unconstrained fn __sqr(self) -> Self; + unconstrained fn __sqr(&self) -> Self; /// Unconstrained modular division: computes `self * other^{-1} (mod MOD)`. /// @@ -211,7 +211,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Safety /// No constraints are generated. Constrain the result using `evaluate_quadratic_expression`. - unconstrained fn __div(self, other: Self) -> Self; + unconstrained fn __div(&self, other: &Self) -> Self; /// Unconstrained integer division with remainder. /// @@ -224,7 +224,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Safety /// No constraints are generated. Use `udiv_mod()` for the constrained version. - unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self); + unconstrained fn __udiv_mod(&self, divisor: &Self) -> (Self, Self); /// Unconstrained modular inverse: computes `self^{-1} (mod MOD)`. /// @@ -237,7 +237,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Safety /// No constraints are generated. - unconstrained fn __invmod(self) -> Self; + unconstrained fn __invmod(&self) -> Self; /// Unconstrained modular exponentiation: computes `self^exponent (mod MOD)`. /// @@ -246,10 +246,10 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Safety /// No constraints are generated. - unconstrained fn __pow(self, exponent: Self) -> Self; + unconstrained fn __pow(&self, exponent: &Self) -> Self; /// **Deprecated**: use `__sqrt` instead. - unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option; + unconstrained fn __tonelli_shanks_sqrt(&self) -> std::option::Option; /// Unconstrained modular square root. /// @@ -264,7 +264,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Safety /// No constraints are generated. - unconstrained fn __sqrt(self) -> std::option::Option; + unconstrained fn __sqrt(&self) -> std::option::Option; /// Asserts that `self != other (mod MOD)`. /// @@ -274,7 +274,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// # Soundness /// Sound for `MOD < circuit_field_modulus`. For very large moduli approaching the circuit /// field size, there's a negligible (~3/p) probability of false positives. - fn assert_is_not_equal(self: Self, other: Self); + fn assert_is_not_equal(&self, other: &Self); /// Returns `true` if `self == 0 (mod MOD)`. /// @@ -284,20 +284,20 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Performance Note /// Cheaper than `self == BigNum::zero()`. - fn is_zero(self) -> bool; + fn is_zero(&self) -> bool; /// Returns `true` if all limbs are zero (integer zero, not modular). /// /// # Note /// Unlike `is_zero()`, this returns `false` for the modulus value even though /// `modulus == 0 (mod MOD)`. - fn is_zero_integer(self) -> bool; + fn is_zero_integer(&self) -> bool; /// Asserts that `self != 0 (mod MOD)`. /// /// # Performance Note /// Cheaper than `assert(!self.is_zero())` or `assert(self != BigNum::zero())`. - fn assert_is_not_zero(self); + fn assert_is_not_zero(&self); /// Asserts that at least one limb is non-zero (integer assertion). /// @@ -306,7 +306,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Performance Note /// Cheaper than `assert(!self.is_zero_integer())`. - fn assert_is_not_zero_integer(self); + fn assert_is_not_zero_integer(&self); /// Validates that each limb is properly range-constrained. /// @@ -323,7 +323,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Note /// This does NOT guarantee `self < MOD`. Use `validate_in_field()` for that. - fn validate_in_range(self); + fn validate_in_range(&self); /// Validates that `self <= MOD`. /// @@ -333,13 +333,13 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// - This is a STRONGER check than `validate_in_range()`. /// - Unlike `validate_in_range()`, this check is NOT deduplicated. Repeated calls will /// add redundant constraints. - fn validate_in_field(self); + fn validate_in_field(&self); /// Constrained squaring: computes `self * self (mod MOD)`. /// /// # Performance Note /// For multiple operations, prefer using `__sqr()` followed by `evaluate_quadratic_expression`. - fn sqr(self) -> Self; + fn sqr(&self) -> Self; /// Constrained integer division with remainder. /// @@ -354,7 +354,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Note /// This is INTEGER division. For modular division on prime fields, use the `/` operator. - fn udiv_mod(self, divisor: Self) -> (Self, Self); + fn udiv_mod(&self, divisor: &Self) -> (Self, Self); /// Constrained integer division: returns `floor(self / divisor)`. /// @@ -363,7 +363,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Note /// This is INTEGER division, not modular. The remainder is discarded. - fn udiv(self, divisor: Self) -> Self; + fn udiv(&self, divisor: &Self) -> Self; /// Constrained integer modulo: returns `self % divisor`. /// @@ -372,7 +372,7 @@ pub trait BigNum: Neg + Add + Sub + Mul + Div + Eq { /// /// # Note /// This is INTEGER modulo. The result is the remainder after integer division. - fn umod(self, divisor: Self) -> Self; + fn umod(&self, divisor: &Self) -> Self; } // We need macros that implement the BigNum, Default, From, Neg, Add, Sub, Mul, Div, Eq, Ord traits for each bignum type @@ -415,7 +415,7 @@ pub comptime fn derive_bignum( Self { limbs } } - fn get_limbs(self: Self) -> [u128; $N] { + fn get_limbs(&self) -> [u128; $N] { self.limbs } @@ -424,7 +424,7 @@ pub comptime fn derive_bignum( self.limbs[idx] = value; } - fn get_limb(self: Self, idx: u32) -> u128 { + fn get_limb(&self, idx: u32) -> u128 { self.limbs[idx] } @@ -440,152 +440,166 @@ pub comptime fn derive_bignum( fn derive_from_seed(seed: [u8; SeedBytes]) -> Self { let params = Self::params(); - $typ::from_limbs($crate::internal::derive_from_seed::<$N, $MOD_BITS, SeedBytes>(params, seed)) + $typ::from_limbs($crate::internal::derive_from_seed::<$N, $MOD_BITS, SeedBytes>(¶ms, seed)) } unconstrained fn __derive_from_seed(seed: [u8; SeedBytes]) -> Self { let params = Self::params(); - Self { limbs: $crate::internal::__derive_from_seed::<$N, $MOD_BITS, SeedBytes>(params, seed) } + Self { limbs: $crate::internal::__derive_from_seed::<$N, $MOD_BITS, SeedBytes>(¶ms, seed) } } fn from_be_bytes(x: [u8; ($MOD_BITS + 7) / 8]) -> Self { - Self { limbs: $crate::internal::from_be_bytes::<$N, $MOD_BITS>(x) } + Self { limbs: $crate::internal::from_be_bytes::<$N, $MOD_BITS>(&x) } } - fn to_be_bytes(self) -> [u8; ($MOD_BITS + 7) / 8] { - $crate::internal::to_be_bytes::<$N, $MOD_BITS>(self.limbs) + fn to_be_bytes(&self) -> [u8; ($MOD_BITS + 7) / 8] { + $crate::internal::to_be_bytes::<$N, $MOD_BITS>(&self.limbs) } fn from_le_bytes(x: [u8; ($MOD_BITS + 7) / 8]) -> Self { - Self { limbs: $crate::internal::from_le_bytes::<$N, $MOD_BITS>(x) } + Self { limbs: $crate::internal::from_le_bytes::<$N, $MOD_BITS>(&x) } } - fn to_le_bytes(self) -> [u8; ($MOD_BITS + 7) / 8] { - $crate::internal::to_le_bytes::<$N, $MOD_BITS>(self.limbs) + fn to_le_bytes(&self) -> [u8; ($MOD_BITS + 7) / 8] { + $crate::internal::to_le_bytes::<$N, $MOD_BITS>(&self.limbs) } - unconstrained fn __eq(self: Self, other: Self) -> bool { - $crate::internal::__eq(self.get_limbs(), other.get_limbs()) + unconstrained fn __eq(&self, other: &Self) -> bool { + $crate::internal::__eq(&self.get_limbs(), &other.get_limbs()) } - unconstrained fn __is_zero(self: Self) -> bool { - $crate::internal::__is_zero(self.get_limbs()) + unconstrained fn __is_zero(&self) -> bool { + $crate::internal::__is_zero(&self.get_limbs()) } - unconstrained fn __neg(self: Self) -> Self { + unconstrained fn __neg(&self) -> Self { let params = Self::params(); - Self {limbs: $crate::internal::__neg(params.modulus, self.get_limbs())} + Self {limbs: $crate::internal::__neg(¶ms.modulus, &self.get_limbs())} } - unconstrained fn __add(self: Self, other: Self) -> Self { + unconstrained fn __add(&self, other: &Self) -> Self { let params = Self::params(); - Self {limbs: $crate::internal::__add(params.modulus, self.get_limbs(), other.get_limbs())} + Self {limbs: $crate::internal::__add(¶ms.modulus, &self.get_limbs(), &other.get_limbs())} } - unconstrained fn __sub(self: Self, other: Self) -> Self { + unconstrained fn __sub(&self, other: &Self) -> Self { let params = Self::params(); - Self {limbs: $crate::internal::__sub(params.modulus, self.get_limbs(), other.get_limbs())} + Self {limbs: $crate::internal::__sub(¶ms.modulus, &self.get_limbs(), &other.get_limbs())} } - unconstrained fn __mul(self: Self, other: Self) -> Self { + unconstrained fn __mul(&self, other: &Self) -> Self { let params = Self::params(); - Self {limbs: $crate::internal::__mul(params, self.get_limbs(), other.get_limbs())} + Self {limbs: $crate::internal::__mul(¶ms, &self.get_limbs(), &other.get_limbs())} } - unconstrained fn __sqr(self: Self) -> Self { + unconstrained fn __sqr(&self) -> Self { let params = Self::params(); - Self {limbs: $crate::internal::__sqr(params, self.get_limbs()) } + Self {limbs: $crate::internal::__sqr(¶ms, &self.get_limbs()) } } - unconstrained fn __div(self: Self, divisor: Self) -> Self { + unconstrained fn __div(&self, divisor: &Self) -> Self { let params = Self::params(); if $params.has_multiplicative_inverse { - Self { limbs: $crate::internal::__div(params, self.get_limbs(), divisor.get_limbs()) } + Self { limbs: $crate::internal::__div(¶ms, &self.get_limbs(), &divisor.get_limbs()) } } else { - Self { limbs: $crate::internal::__udiv_mod(self.get_limbs(), divisor.get_limbs()).0 } + Self { limbs: $crate::internal::__udiv_mod(&self.get_limbs(), &divisor.get_limbs()).0 } } } - unconstrained fn __udiv_mod(self: Self, divisor: Self) -> (Self, Self) { - let (q, r) = $crate::internal::__udiv_mod(self.get_limbs(), divisor.get_limbs()); + unconstrained fn __udiv_mod(&self, divisor: &Self) -> (Self, Self) { + let (q, r) = $crate::internal::__udiv_mod(&self.get_limbs(), &divisor.get_limbs()); (Self{limbs: q}, Self{limbs: r}) } - unconstrained fn __invmod(self: Self) -> Self { + unconstrained fn __invmod(&self) -> Self { let params = Self::params(); - Self {limbs: $crate::internal::__invmod(params, self.get_limbs())} + Self {limbs: $crate::internal::__invmod(¶ms, &self.get_limbs())} } - unconstrained fn __pow(self: Self, exponent: Self) -> Self { + unconstrained fn __pow(&self, exponent: &Self) -> Self { let params = Self::params(); - Self {limbs: $crate::internal::__pow(params, self.get_limbs(), exponent.get_limbs())} + Self {limbs: $crate::internal::__pow(¶ms, &self.get_limbs(), &exponent.get_limbs())} } #[deprecated("use __sqrt")] - unconstrained fn __tonelli_shanks_sqrt(self: Self) -> std::option::Option { + unconstrained fn __tonelli_shanks_sqrt(&self) -> std::option::Option { let params = Self::params(); - let maybe_limbs: Option<[u128; $N]> = $crate::internal::__sqrt(params, self.get_limbs()); + let maybe_limbs: Option<[u128; $N]> = $crate::internal::__sqrt(¶ms, &self.get_limbs()); maybe_limbs.map(|limbs| Self {limbs: limbs}) } - unconstrained fn __sqrt(self: Self) -> std::option::Option { + unconstrained fn __sqrt(&self) -> std::option::Option { let params = Self::params(); - let maybe_limbs: Option<[u128; $N]> = $crate::internal::__sqrt(params, self.get_limbs()); + let maybe_limbs: Option<[u128; $N]> = $crate::internal::__sqrt(¶ms, &self.get_limbs()); maybe_limbs.map(|limbs| Self {limbs: limbs }) } - fn assert_is_not_equal(self: Self, other: Self) { + fn assert_is_not_equal(&self, other: &Self) { let params = Self::params(); + let self_limbs = self.get_limbs(); + let other_limbs = other.get_limbs(); $crate::internal::assert_is_not_equal( - params, - self.get_limbs(), - other.get_limbs(), + ¶ms, + &self_limbs, + &other_limbs, ); } - fn validate_in_field(self: Self) { + fn validate_in_field(&self) { let params = Self::params(); - $crate::internal::validate_in_field::<$N, $MOD_BITS>(params, self.get_limbs()); + let limbs = self.get_limbs(); + $crate::internal::validate_in_field::<$N, $MOD_BITS>(¶ms, &limbs); } - fn validate_in_range(self: Self) { - $crate::internal::validate_in_range::(self.get_limbs()); + fn validate_in_range(&self) { + let limbs = self.get_limbs(); + $crate::internal::validate_in_range::(&limbs); } - fn sqr(self: Self) -> Self { + fn sqr(&self) -> Self { let params = Self::params(); Self { limbs: $crate::internal::sqr::<$N, $MOD_BITS>(params, self.get_limbs()) } } - fn udiv_mod(self: Self, divisor: Self) -> (Self, Self) { - let (q, r) = $crate::internal::udiv_mod::<$N, $MOD_BITS>(self.get_limbs(), divisor.get_limbs()); + fn udiv_mod(&self, divisor: &Self) -> (Self, Self) { + let self_limbs = self.get_limbs(); + let divisor_limbs = divisor.get_limbs(); + let (q, r) = $crate::internal::udiv_mod::<$N, $MOD_BITS>(&self_limbs, &divisor_limbs); (Self {limbs: q}, Self {limbs: r}) } - fn udiv(self: Self, divisor: Self) -> Self { - Self {limbs: $crate::internal::udiv::<$N, $MOD_BITS>(self.get_limbs(), divisor.get_limbs())} + fn udiv(&self, divisor: &Self) -> Self { + let self_limbs = self.get_limbs(); + let divisor_limbs = divisor.get_limbs(); + Self {limbs: $crate::internal::udiv::<$N, $MOD_BITS>(&self_limbs, &divisor_limbs)} } - fn umod(self: Self, divisor: Self) -> Self { - Self {limbs: $crate::internal::umod::<$N, $MOD_BITS>(self.get_limbs(), divisor.get_limbs())} + fn umod(&self, divisor: &Self) -> Self { + let self_limbs = self.get_limbs(); + let divisor_limbs = divisor.get_limbs(); + Self {limbs: $crate::internal::umod::<$N, $MOD_BITS>(&self_limbs, &divisor_limbs)} } - fn is_zero(self: Self) -> bool { + fn is_zero(&self) -> bool { let params = Self::params(); - $crate::internal::is_zero::<$N, $MOD_BITS>(params, self.get_limbs()) + let limbs = self.get_limbs(); + $crate::internal::is_zero::<$N, $MOD_BITS>(¶ms, &limbs) } - fn is_zero_integer(self: Self) -> bool { - $crate::internal::is_zero_integer::(self.get_limbs()) + fn is_zero_integer(&self) -> bool { + let limbs = self.get_limbs(); + $crate::internal::is_zero_integer::(&limbs) } - fn assert_is_not_zero(self: Self) { + fn assert_is_not_zero(&self) { let params = Self::params(); - $crate::internal::assert_is_not_zero::<$N, $MOD_BITS>(params, self.get_limbs()); + let limbs = self.get_limbs(); + $crate::internal::assert_is_not_zero::<$N, $MOD_BITS>(¶ms, &limbs); } - fn assert_is_not_zero_integer(self: Self) { - $crate::internal::assert_is_not_zero_integer(self.get_limbs()); + fn assert_is_not_zero_integer(&self) { + let limbs = self.get_limbs(); + $crate::internal::assert_is_not_zero_integer(&limbs); } } @@ -598,25 +612,25 @@ pub comptime fn derive_bignum( impl std::convert::From for $typ { fn from(input: Field) -> Self { - $typ { limbs: $crate::internal::from_field::<$N, $MOD_BITS>($params, input) } + $typ { limbs: $crate::internal::from_field::<$N, $MOD_BITS>(&$params, input) } } } impl std::ops::Neg for $typ { fn neg(self) -> Self { - $typ { limbs: $crate::internal::neg::<$N, $MOD_BITS>($params, self.limbs) } + $typ { limbs: $crate::internal::neg::<$N, $MOD_BITS>(&$params, &self.limbs) } } } impl std::ops::Add for $typ { fn add(self, other: Self) -> Self { - $typ { limbs: $crate::internal::add::<$N, $MOD_BITS>($params, self.limbs, other.limbs) } + $typ { limbs: $crate::internal::add::<$N, $MOD_BITS>(&$params, &self.limbs, &other.limbs) } } } impl std::ops::Sub for $typ { fn sub(self, other: Self) -> Self { - $typ { limbs: $crate::internal::sub::<$N, $MOD_BITS>($params, self.limbs, other.limbs) } + $typ { limbs: $crate::internal::sub::<$N, $MOD_BITS>(&$params, &self.limbs, &other.limbs) } } } @@ -631,20 +645,20 @@ pub comptime fn derive_bignum( if $params.has_multiplicative_inverse { $typ { limbs: $crate::internal::div::<$N, $MOD_BITS>($params, self.limbs, other.limbs) } } else { - $typ { limbs: $crate::internal::udiv::<$N, $MOD_BITS>(self.limbs, other.limbs) } + $typ { limbs: $crate::internal::udiv::<$N, $MOD_BITS>(&self.limbs, &other.limbs) } } } } impl std::cmp::Eq for $typ { fn eq(self, other: Self) -> bool { - $crate::internal::eq::<$N, $MOD_BITS>($params, self.limbs, other.limbs) + $crate::internal::eq::<$N, $MOD_BITS>(&$params, &self.limbs, &other.limbs) } } impl std::cmp::Ord for $typ { fn cmp(self, other: Self) -> std::cmp::Ordering { - $crate::internal::cmp::<$N, $MOD_BITS>(self.limbs, other.limbs) + $crate::internal::cmp::<$N, $MOD_BITS>(&self.limbs, &other.limbs) } } @@ -715,19 +729,19 @@ pub unconstrained fn compute_quadratic_expression (T, T) { let params = T::params(); let (q_limbs, r_limbs) = crate::fns::expressions::__compute_quadratic_expression( - params, - crate::utils::map::map( - lhs_terms, - |bns| crate::utils::map::map(bns, |bn: T| bn.get_limbs()), + ¶ms, + &crate::utils::map::map( + &lhs_terms, + |bns| crate::utils::map::map(bns, |bn: &T| bn.get_limbs()), ), - lhs_flags, - crate::utils::map::map( - rhs_terms, - |bns| crate::utils::map::map(bns, |bn: T| bn.get_limbs()), + &lhs_flags, + &crate::utils::map::map( + &rhs_terms, + |bns| crate::utils::map::map(bns, |bn: &T| bn.get_limbs()), ), - rhs_flags, - crate::utils::map::map(linear_terms, |bn: T| bn.get_limbs()), - linear_flags, + &rhs_flags, + &crate::utils::map::map(&linear_terms, |bn: &T| bn.get_limbs()), + &linear_flags, ); (T::from_limbs(q_limbs), T::from_limbs(r_limbs)) } @@ -799,19 +813,19 @@ pub fn evaluate_quadratic_expression(x: [T; M]) -> [T; M] { let params = T::params(); assert(params.has_multiplicative_inverse); - crate::fns::unconstrained_ops::batch_invert(params, x.map(|bn: T| bn.get_limbs())).map(|limbs| { + let limb_arr = x.map(|bn: T| bn.get_limbs()); + crate::fns::unconstrained_ops::batch_invert(¶ms, &limb_arr).map(|limbs| { T::from_limbs(limbs) }) } @@ -865,7 +880,7 @@ pub unconstrained fn batch_invert(x: [T; M]) -> [T; M] { pub unconstrained fn batch_invert_slice(x: [T]) -> [T] { let params = T::params(); assert(params.has_multiplicative_inverse); - crate::fns::unconstrained_ops::batch_invert_slice(params, x.map(|bn: T| bn.get_limbs())) + crate::fns::unconstrained_ops::batch_invert_slice(¶ms, x.map(|bn: T| bn.get_limbs())) .map(|limbs| T::from_limbs(limbs)) } @@ -885,5 +900,6 @@ pub unconstrained fn batch_invert_slice(x: [T]) -> [T] { /// field modulus before calling this. pub fn to_field(bn: T) -> Field { let params = T::params(); - limbs_to_field(params, bn.get_limbs()) + let limbs = bn.get_limbs(); + limbs_to_field(¶ms, &limbs) } diff --git a/src/fns/constrained_ops.nr b/src/fns/constrained_ops.nr index 0dad3a5f..e7701939 100644 --- a/src/fns/constrained_ops.nr +++ b/src/fns/constrained_ops.nr @@ -18,8 +18,8 @@ use std::cmp::Ordering; /// - check that it's a proper `BigNum` value /// - validate the limbs sum up to a `Field` value pub(crate) fn limbs_to_field( - _params: BigNumParams, - limbs: [u128; N], + _params: &BigNumParams, + limbs: &[u128; N], ) -> Field { validate_in_range::(limbs); if N > 2 { @@ -28,7 +28,7 @@ pub(crate) fn limbs_to_field( grumpkin_modulus[0] = GRUMPKIN_MODULUS[0]; grumpkin_modulus[1] = GRUMPKIN_MODULUS[1]; grumpkin_modulus[2] = GRUMPKIN_MODULUS[2]; - validate_gt::(grumpkin_modulus, limbs); + validate_gt::(&grumpkin_modulus, limbs); } if N < 2 { @@ -52,7 +52,7 @@ pub(crate) fn limbs_to_field( /// Next we verify that all the limbs are properly ranged /// and that the accumulated limbs are equal to the input `Field` value pub fn from_field( - _params: BigNumParams, + _params: &BigNumParams, val: Field, ) -> [u128; N] { // Safety: we check that the resulting limbs represent the intended field element @@ -63,7 +63,7 @@ pub fn from_field( // validate the limbs are in range and the value in total is less than 2^254 if MOD_BITS < 253 { // this means that the field modulus is smaller than grumpkin modulus so we have to check if the element fields in the field size - validate_in_field(_params, result); + validate_in_field(_params, &result); } else { let mut grumpkin_modulus: [u128; N] = [0; N]; grumpkin_modulus[0] = GRUMPKIN_MODULUS[0]; @@ -73,7 +73,7 @@ pub fn from_field( if MOD_BITS > 253 { // this means that the field modulus is larger than grumpkin modulus so we have to check if the element fields in the field size are less than the grumpkin modulus. // also for correct params N is always larger than 3 here - validate_gt::(grumpkin_modulus, result); + validate_gt::(&grumpkin_modulus, &result); } else { // this is the tricky part, when MOD_BITS = 253, we have to compare the limbs of the modulus to the grumpkin modulus limbs // any `BigNum` with 253 bits will have 3 limbs @@ -90,10 +90,10 @@ pub fn from_field( } else { grumpkin_modulus }; - validate_gt::(min_modulus, result); + validate_gt::(&min_modulus, &result); } } - validate_in_range::(result); + validate_in_range::(&result); // validate the limbs sum up to the field value let field_val: Field = if N < 2 { @@ -131,7 +131,7 @@ pub fn from_field( /// If `MOD = 2^{120 * (N - 1)}` /// It will use only `MOD_BITS - 1` bits of entropy pub fn derive_from_seed( - params: BigNumParams, + params: &BigNumParams, seed: [u8; SeedBytes], ) -> [u128; N] { // Pack seed bytes into Fields. @@ -246,8 +246,8 @@ pub fn derive_from_seed( let mut result: [u128; N] = bigfield_chunks[0]; for i in 1..num_bigfield_chunks { - result = mul(params, result, bigfield_rhs_limbs); - result = add(params, result, bigfield_chunks[i]); + result = mul(*params, result, bigfield_rhs_limbs); + result = add(params, &result, &bigfield_chunks[i]); } result @@ -290,9 +290,9 @@ pub fn derive_from_seed( /// /// In case `MOD` < `p` this function becomes *complete* pub fn assert_is_not_equal( - params: BigNumParams, - lhs: [u128; N], - rhs: [u128; N], + params: &BigNumParams, + lhs: &[u128; N], + rhs: &[u128; N], ) { let mut l: Field = 0; let mut r: Field = 0; @@ -327,12 +327,12 @@ pub fn assert_is_not_equal( /// ## TODO /// can do this more efficiently via witngen in unconstrained functions? pub fn eq( - params: BigNumParams, - lhs: [u128; N], - rhs: [u128; N], + params: &BigNumParams, + lhs: &[u128; N], + rhs: &[u128; N], ) -> bool { let diff: [u128; N] = sub::(params, lhs, rhs); - is_zero::(params, diff) + is_zero::(params, &diff) } /// Validate that `val` is not equal to zero when interpreted as an integer. @@ -360,7 +360,7 @@ pub fn eq( /// ## Note /// This is slightly cheaper than doing `val != [0; N]`, as we avoid /// creating per-limb boolean equalities and chaining them with `and`s. -pub fn assert_is_not_zero_integer(val: [T; N]) +pub fn assert_is_not_zero_integer(val: &[T; N]) where T: Into, { @@ -377,7 +377,7 @@ where /// ## Note /// This is slightly cheaper than testing `val == [0; N]`, as we avoid /// creating per-limb boolean equalities and chaining them with `and`s. -pub fn is_zero_integer(val: [T; N]) -> bool +pub fn is_zero_integer(val: &[T; N]) -> bool where T: Into, { @@ -392,10 +392,11 @@ where /// /// Convenience wrapper around `assert_is_not_equal`. pub fn assert_is_not_zero( - params: BigNumParams, - val: [u128; N], + params: &BigNumParams, + val: &[u128; N], ) { - assert_is_not_equal::(params, val, [0; N]); + let zero: [u128; N] = [0; N]; + assert_is_not_equal::(params, val, &zero); } /// Check whether a `BigNum` value is zero modulo `MOD`. @@ -408,10 +409,10 @@ pub fn assert_is_not_zero( /// ## Note /// This is cheaper than calling `eq(val, [0; N])` pub fn is_zero( - params: BigNumParams, - val: [u128; N], + params: &BigNumParams, + val: &[u128; N], ) -> bool { - is_zero_integer(val) | (val == params.modulus) + is_zero_integer(val) | (*val == params.modulus) } /// Validate a `BigNum` instance is correctly range constrained to contain no more than `MOD_BITS` bits @@ -431,7 +432,7 @@ pub fn is_zero( /// It is nicely decomposed into /// 14-bit range check and 3-bit range check - much much cheaper, since 14-bit range checks /// are already pretty common for 120-bit range checks -pub fn validate_in_range(limbs: [T; N]) +pub fn validate_in_range(limbs: &[T; N]) where T: Into, { @@ -451,7 +452,7 @@ where /// 2. in `evaluate_quadratic_expression`, we require that for `expression - quotient * modulus`, /// limbs cannot exceed `246` bits (246 magic number due to a higher number adding extra range check gates) /// because of factor 2 and the fact that modulus limbs are 120 bits, quotient limbs cannot be > 126 bits -pub(crate) fn validate_quotient_in_range(limbs: [u128; N]) { +pub(crate) fn validate_quotient_in_range(limbs: &[u128; N]) { for i in 0..(N - 1) { (limbs[i] as Field).assert_max_bit_size::<120>(); } @@ -482,17 +483,17 @@ pub(crate) fn validate_quotient_in_range(limbs: [ /// `assert_is_not_zero_integer(result)` is crucial. Without it, we could always provide /// two identical inputs `x`, `x` and set `borrow_flags = [false; N]`, /// which would satisfy the limb constraints. -pub(crate) fn validate_gt(lhs: [u128; N], rhs: [u128; N]) { +pub(crate) fn validate_gt(lhs: &[u128; N], rhs: &[u128; N]) { // Safety: borrow_flags are constrained by `assert_sub_no_overflow`: // - incorrect flags cause the computed result to fail 128-bit range checks // - if lhs < rhs, the subtraction overflows Field and fails range checks let borrow_flags: [bool; N - 1] = unsafe { compute_borrow_flags(lhs, rhs) }; // Compute and validate the result in constrained code - let result: [u128; N] = assert_sub_no_overflow::(lhs, rhs, borrow_flags); + let result: [u128; N] = assert_sub_no_overflow::(lhs, rhs, &borrow_flags); // Constrain it to be strict inequality - assert_is_not_zero_integer(result); + assert_is_not_zero_integer(&result); } /// Compute the result of lhs - rhs given borrow flags and validate it @@ -506,9 +507,9 @@ pub(crate) fn validate_gt(lhs: [u128; N], rhs: [u /// Also validates that the result is a valid BigNum value (each limb in range). /// If lhs < rhs, the result will have a negative value (wrapped in Field) and fail validation. fn assert_sub_no_overflow( - lhs: [u128; N], - rhs: [u128; N], - borrow_flags: [bool; N - 1], + lhs: &[u128; N], + rhs: &[u128; N], + borrow_flags: &[bool; N - 1], ) -> [u128; N] { let mut result: [u128; N] = [0; N]; @@ -534,7 +535,7 @@ fn assert_sub_no_overflow( result[N - 1] = limb_last as u128; // Validate that the result is a valid BigNum value (120-bit limbs, TOP_LIMB_BITS for last) - validate_in_range::(result); + validate_in_range::(&result); result } @@ -549,14 +550,14 @@ fn assert_sub_no_overflow( /// /// Also validates that the result is a valid BigNum value (each limb in range). fn compute_sub_result( - modulus: [u128; N], - lhs: [u128; N], - rhs: [u128; N], - carry_flags: [bool; N - 1], - borrow_flags: [bool; N - 1], + modulus: &[u128; N], + lhs: &[u128; N], + rhs: &[u128; N], + carry_flags: &[bool; N - 1], + borrow_flags: &[bool; N - 1], underflow: bool, ) -> [u128; N] { - let addend: [u128; N] = if underflow { modulus } else { [0; N] }; + let addend: [u128; N] = if underflow { *modulus } else { [0; N] }; let mut result: [u128; N] = [0; N]; @@ -588,7 +589,7 @@ fn compute_sub_result( result[N - 1] = limb_last as u128; // Validate that the result is a valid BigNum value (120-bit limbs, TOP_LIMB_BITS for last) - validate_in_range::(result); + validate_in_range::(&result); result } @@ -603,14 +604,14 @@ fn compute_sub_result( /// /// Also validates that the result is a valid BigNum value (each limb in range). fn compute_add_result( - modulus: [u128; N], - lhs: [u128; N], - rhs: [u128; N], - carry_flags: [bool; N - 1], - borrow_flags: [bool; N - 1], + modulus: &[u128; N], + lhs: &[u128; N], + rhs: &[u128; N], + carry_flags: &[bool; N - 1], + borrow_flags: &[bool; N - 1], overflow: bool, ) -> [u128; N] { - let subtrahend: [u128; N] = if overflow { modulus } else { [0; N] }; + let subtrahend: [u128; N] = if overflow { *modulus } else { [0; N] }; let mut result: [u128; N] = [0; N]; @@ -642,7 +643,7 @@ fn compute_add_result( result[N - 1] = limb_last as u128; // Validate that the result is a valid BigNum value (120-bit limbs, TOP_LIMB_BITS for last) - validate_in_range::(result); + validate_in_range::(&result); result } @@ -659,15 +660,15 @@ fn compute_add_result( /// In contrast to `validate_gt`, we allow the value to be `MOD` /// Since it is consistent with the rest of the library pub fn validate_in_field( - params: BigNumParams, - val: [u128; N], + params: &BigNumParams, + val: &[u128; N], ) { let modulus: [u128; N] = params.modulus; // Safety: borrow_flags are constrained by the `validate_in_range` check on `p_minus_self`: // - incorrect flags cause `p_minus_self` limbs to overflow Field, failing range checks // - if val > modulus, the subtraction overflows and fails the range check - let borrow_flags: [bool; (N - 1)] = unsafe { compute_borrow_flags(modulus, val) }; + let borrow_flags: [bool; (N - 1)] = unsafe { compute_borrow_flags(&modulus, val) }; let mut p_minus_self: [Field; N] = [0; N]; p_minus_self[0] = (modulus[0] as Field) - (val[0] as Field) @@ -679,7 +680,7 @@ pub fn validate_in_field( } p_minus_self[N - 1] = (modulus[N - 1] as Field) - (val[N - 1] as Field) - (borrow_flags[N - 2] as Field); - validate_in_range::(p_minus_self); + validate_in_range::(&p_minus_self); } /// Compare two `BigNum` values @@ -689,7 +690,7 @@ pub fn validate_in_field( /// ## Note /// This is a strict value comparison over the integers, /// the values do not have to be reduced modulo `MOD`. -pub fn cmp(lhs: [u128; N], rhs: [u128; N]) -> Ordering { +pub fn cmp(lhs: &[u128; N], rhs: &[u128; N]) -> Ordering { // Safety: underflow and borrow_flags are constrained by `assert_sub_no_overflow`: // - we swap (a, b) based on underflow, then compute a - b // - if underflow is wrong, a < b and the subtraction overflows, failing range checks @@ -697,12 +698,16 @@ pub fn cmp(lhs: [u128; N], rhs: [u128; N]) -> Ord let (underflow, borrow_flags): (bool, [bool; N - 1]) = unsafe { compute_gte_flags(lhs, rhs) }; // if underflow is true, swap lhs and rhs so we compute larger - smaller - let (a, b): ([u128; N], [u128; N]) = if underflow { (rhs, lhs) } else { (lhs, rhs) }; + let (a, b): ([u128; N], [u128; N]) = if underflow { + (*rhs, *lhs) + } else { + (*lhs, *rhs) + }; // Enforce correctness of `underflow` by asserting that the subtraction does not overflow. - let _: [u128; N] = assert_sub_no_overflow::(a, b, borrow_flags); + let _: [u128; N] = assert_sub_no_overflow::(&a, &b, &borrow_flags); - if lhs == rhs { + if *lhs == *rhs { Ordering::equal() } else if underflow { Ordering::less() @@ -746,23 +751,24 @@ pub fn cmp(lhs: [u128; N], rhs: [u128; N]) -> Ord /// ## Note /// This function returns `MOD` when `val` is zero. pub fn neg( - params: BigNumParams, - val: [u128; N], + params: &BigNumParams, + val: &[u128; N], ) -> [u128; N] { if std::runtime::is_unconstrained() { // Safety: unconstrained runtime requires no constraints unsafe { - __neg(params.modulus, val) + __neg(¶ms.modulus, val) } } else { + let modulus = params.modulus; // Safety: borrow_flags are constrained by `assert_sub_no_overflow`: // - incorrect flags cause computed limbs to fail 128-bit range checks // - if val > modulus, the subtraction overflows and fails range checks // (but val > modulus violates function preconditions) - let borrow_flags: [bool; N - 1] = unsafe { compute_borrow_flags(params.modulus, val) }; + let borrow_flags: [bool; N - 1] = unsafe { compute_borrow_flags(&modulus, val) }; // Subtract `val` from the modulus to negate. - assert_sub_no_overflow::(params.modulus, val, borrow_flags) + assert_sub_no_overflow::(&modulus, val, &borrow_flags) } } @@ -865,31 +871,25 @@ pub fn neg( /// no longer corresponds to a unique, well-defined addition in the field /// `Z / MOD Z`. Such uses are outside the intended semantics of this function. pub fn add( - params: BigNumParams, - lhs: [u128; N], - rhs: [u128; N], + params: &BigNumParams, + lhs: &[u128; N], + rhs: &[u128; N], ) -> [u128; N] { if std::runtime::is_unconstrained() { // Safety: unconstrained runtime requires no constraints unsafe { - __add(params.modulus, lhs, rhs) + __add(¶ms.modulus, lhs, rhs) } } else { + let modulus = params.modulus; // Safety: flags are constrained by `compute_add_result`: // - incorrect flags cause computed limbs to fail 128-bit range checks // - wrong overflow causes result to be off by modulus, failing `validate_in_range` let (carry_flags, borrow_flags, overflow): ([bool; N - 1], [bool; N - 1], bool) = - unsafe { compute_add_flags(params.modulus, lhs, rhs) }; + unsafe { compute_add_flags(&modulus, lhs, rhs) }; // Compute and validate the result in constrained code - compute_add_result::( - params.modulus, - lhs, - rhs, - carry_flags, - borrow_flags, - overflow, - ) + compute_add_result::(&modulus, lhs, rhs, &carry_flags, &borrow_flags, overflow) } } @@ -994,29 +994,30 @@ pub fn add( /// no longer corresponds to a unique, well-defined subtraction in the field /// `Z / MOD Z`. Such uses are outside the intended semantics of this function. pub fn sub( - params: BigNumParams, - lhs: [u128; N], - rhs: [u128; N], + params: &BigNumParams, + lhs: &[u128; N], + rhs: &[u128; N], ) -> [u128; N] { if std::runtime::is_unconstrained() { // Safety: unconstrained runtime requires no constraints unsafe { - __sub(params.modulus, lhs, rhs) + __sub(¶ms.modulus, lhs, rhs) } } else { + let modulus = params.modulus; // Safety: flags are constrained by `compute_sub_result`: // - incorrect flags cause computed limbs to fail 128-bit range checks // - wrong underflow causes result to be off by modulus, failing `validate_in_range` let (carry_flags, borrow_flags, underflow): ([bool; N - 1], [bool; N - 1], bool) = - unsafe { compute_sub_flags(params.modulus, lhs, rhs) }; + unsafe { compute_sub_flags(&modulus, lhs, rhs) }; // Compute and validate the result in constrained code compute_sub_result::( - params.modulus, + &modulus, lhs, rhs, - carry_flags, - borrow_flags, + &carry_flags, + &borrow_flags, underflow, ) } @@ -1042,17 +1043,23 @@ pub fn mul( rhs: [u128; N], ) -> [u128; N] { // Safety: we constrain the multiplication result immediately after - let result: [u128; N] = unsafe { __mul::(params, lhs, rhs) }; + let result: [u128; N] = unsafe { __mul::(¶ms, &lhs, &rhs) }; if !std::runtime::is_unconstrained() { // lhs * rhs - result = 0 - evaluate_quadratic_expression( - params, - [[lhs]], - [[false]], - [[rhs]], - [[false]], - [result], - [true], + let lhs_terms: [[[u128; N]; 1]; 1] = [[lhs]]; + let lhs_flags: [[bool; 1]; 1] = [[false]]; + let rhs_terms: [[[u128; N]; 1]; 1] = [[rhs]]; + let rhs_flags: [[bool; 1]; 1] = [[false]]; + let linear_terms: [[u128; N]; 1] = [result]; + let linear_flags: [bool; 1] = [true]; + evaluate_quadratic_expression::( + ¶ms, + &lhs_terms, + &lhs_flags, + &rhs_terms, + &rhs_flags, + &linear_terms, + &linear_flags, ); } result @@ -1077,17 +1084,23 @@ pub fn sqr( val: [u128; N], ) -> [u128; N] { // Safety: we constrain the multiplication result immediately after - let result: [u128; N] = unsafe { __sqr::<_, MOD_BITS>(params, val) }; + let result: [u128; N] = unsafe { __sqr::<_, MOD_BITS>(¶ms, &val) }; if !std::runtime::is_unconstrained() { // val * val - result = 0 - evaluate_quadratic_expression( - params, - [[val]], - [[false]], - [[val]], - [[false]], - [result], - [true], + let lhs_terms: [[[u128; N]; 1]; 1] = [[val]]; + let lhs_flags: [[bool; 1]; 1] = [[false]]; + let rhs_terms: [[[u128; N]; 1]; 1] = [[val]]; + let rhs_flags: [[bool; 1]; 1] = [[false]]; + let linear_terms: [[u128; N]; 1] = [result]; + let linear_flags: [bool; 1] = [true]; + evaluate_quadratic_expression::( + ¶ms, + &lhs_terms, + &lhs_flags, + &rhs_terms, + &rhs_flags, + &linear_terms, + &linear_flags, ); } result @@ -1123,19 +1136,25 @@ pub fn div( "BigNum has no multiplicative inverse. Use udiv for unsigned integer division", ); // Safety: We constrain the result of division immediately after - let result: [u128; N] = unsafe { __div::<_, MOD_BITS>(params, lhs, rhs) }; + let result: [u128; N] = unsafe { __div::<_, MOD_BITS>(¶ms, &lhs, &rhs) }; if !std::runtime::is_unconstrained() { // result * rhs - lhs = 0 - evaluate_quadratic_expression( - params, - [[result]], - [[false]], - [[rhs]], - [[false]], - [lhs], - [true], + let lhs_terms: [[[u128; N]; 1]; 1] = [[result]]; + let lhs_flags: [[bool; 1]; 1] = [[false]]; + let rhs_terms: [[[u128; N]; 1]; 1] = [[rhs]]; + let rhs_flags: [[bool; 1]; 1] = [[false]]; + let linear_terms: [[u128; N]; 1] = [lhs]; + let linear_flags: [bool; 1] = [true]; + evaluate_quadratic_expression::( + ¶ms, + &lhs_terms, + &lhs_flags, + &rhs_terms, + &rhs_flags, + &linear_terms, + &linear_flags, ); - assert_is_not_zero(params, rhs); + assert_is_not_zero(¶ms, &rhs); } result } @@ -1163,16 +1182,16 @@ pub fn div( /// Enforcing `divisor != 0` is not necessary. `remainder < divisor` /// Already enforces this. pub fn udiv_mod( - numerator: [u128; N], - divisor: [u128; N], + numerator: &[u128; N], + divisor: &[u128; N], ) -> ([u128; N], [u128; N]) { // Safety: We constrain the result of __udiv_mod immediately after let (quotient, remainder): ([u128; N], [u128; N]) = unsafe { __udiv_mod(numerator, divisor) }; if !std::runtime::is_unconstrained() { // quotient * divisor + remainder - numerator = 0 - validate_udiv_mod_expression::(numerator, divisor, quotient, remainder); + validate_udiv_mod_expression::(numerator, divisor, "ient, &remainder); // remainder < divisor - validate_gt::(divisor, remainder); + validate_gt::(divisor, &remainder); } (quotient, remainder) } @@ -1181,7 +1200,10 @@ pub fn udiv_mod( /// /// Returns `floor(numerator / divisor)`. /// All constraints and soundness details are handled inside `udiv_mod`. -pub fn udiv(numerator: [u128; N], divisor: [u128; N]) -> [u128; N] { +pub fn udiv( + numerator: &[u128; N], + divisor: &[u128; N], +) -> [u128; N] { udiv_mod::(numerator, divisor).0 } @@ -1189,6 +1211,9 @@ pub fn udiv(numerator: [u128; N], divisor: [u128; /// /// Returns `numerator % divisor`. /// All constraints and soundness details are handled inside `udiv_mod`. -pub fn umod(numerator: [u128; N], divisor: [u128; N]) -> [u128; N] { +pub fn umod( + numerator: &[u128; N], + divisor: &[u128; N], +) -> [u128; N] { udiv_mod::(numerator, divisor).1 } diff --git a/src/fns/expressions.nr b/src/fns/expressions.nr index be1343ee..67d215e9 100644 --- a/src/fns/expressions.nr +++ b/src/fns/expressions.nr @@ -21,9 +21,9 @@ use crate::params::BigNumParams; /// /// 2. Returns the `Field` values that are not normalized to be 120-bit unconstrained fn __add_linear_expression( - params: BigNumParams, - vals: [[u128; N]; M], - flags: [bool; M], + params: &BigNumParams, + vals: &[[u128; N]; M], + flags: &[bool; M], ) -> ([Field; N]) { let mut sum: [Field; N] = [0; N]; let modulus2: [u128; N] = params.double_modulus; @@ -48,20 +48,20 @@ unconstrained fn __add_linear_expression( - params: BigNumParams, - lhs_terms: [[[u128; N]; LHS_N]; NUM_PRODUCTS], - lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], - rhs_terms: [[[u128; N]; RHS_N]; NUM_PRODUCTS], - rhs_flags: [[bool; RHS_N]; NUM_PRODUCTS], - linear_terms: [[u128; N]; ADD_N], - linear_flags: [bool; ADD_N], + params: &BigNumParams, + lhs_terms: &[[[u128; N]; LHS_N]; NUM_PRODUCTS], + lhs_flags: &[[bool; LHS_N]; NUM_PRODUCTS], + rhs_terms: &[[[u128; N]; RHS_N]; NUM_PRODUCTS], + rhs_flags: &[[bool; RHS_N]; NUM_PRODUCTS], + linear_terms: &[[u128; N]; ADD_N], + linear_flags: &[bool; ADD_N], ) -> [Field; 2 * N] { let mut lhs: [[Field; N]; NUM_PRODUCTS] = [[0; N]; NUM_PRODUCTS]; let mut rhs: [[Field; N]; NUM_PRODUCTS] = [[0; N]; NUM_PRODUCTS]; for i in 0..NUM_PRODUCTS { - lhs[i] = __add_linear_expression(params, lhs_terms[i], lhs_flags[i]); - rhs[i] = __add_linear_expression(params, rhs_terms[i], rhs_flags[i]); + lhs[i] = __add_linear_expression(params, &lhs_terms[i], &lhs_flags[i]); + rhs[i] = __add_linear_expression(params, &rhs_terms[i], &rhs_flags[i]); } let add: [Field; N] = __add_linear_expression(params, linear_terms, linear_flags); @@ -161,13 +161,13 @@ unconstrained fn __compute_borrow_flags( /// in-circuit by `compute_quadratic_expression_with_modulus` and /// `evaluate_quadratic_expression`. unconstrained fn __compute_quadratic_expression_with_borrow_flags( - params: BigNumParams, - lhs_terms: [[[u128; N]; LHS_N]; NUM_PRODUCTS], - lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], - rhs_terms: [[[u128; N]; RHS_N]; NUM_PRODUCTS], - rhs_flags: [[bool; RHS_N]; NUM_PRODUCTS], - linear_terms: [[u128; N]; ADD_N], - linear_flags: [bool; ADD_N], + params: &BigNumParams, + lhs_terms: &[[[u128; N]; LHS_N]; NUM_PRODUCTS], + lhs_flags: &[[bool; LHS_N]; NUM_PRODUCTS], + rhs_terms: &[[[u128; N]; RHS_N]; NUM_PRODUCTS], + rhs_flags: &[[bool; RHS_N]; NUM_PRODUCTS], + linear_terms: &[[u128; N]; ADD_N], + linear_flags: &[bool; ADD_N], ) -> ([u128; N], [bool; 2 * N - 2]) { let mulout_p: [Field; 2 * N] = __compute_quadratic_expression_product( params, @@ -181,11 +181,15 @@ unconstrained fn __compute_quadratic_expression_with_borrow_flags( - params: BigNumParams, - lhs_terms: [[[u128; N]; LHS_N]; NUM_PRODUCTS], - lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], - rhs_terms: [[[u128; N]; RHS_N]; NUM_PRODUCTS], - rhs_flags: [[bool; RHS_N]; NUM_PRODUCTS], - linear_terms: [[u128; N]; ADD_N], - linear_flags: [bool; ADD_N], + params: &BigNumParams, + lhs_terms: &[[[u128; N]; LHS_N]; NUM_PRODUCTS], + lhs_flags: &[[bool; LHS_N]; NUM_PRODUCTS], + rhs_terms: &[[[u128; N]; RHS_N]; NUM_PRODUCTS], + rhs_flags: &[[bool; RHS_N]; NUM_PRODUCTS], + linear_terms: &[[u128; N]; ADD_N], + linear_flags: &[bool; ADD_N], ) -> ([u128; N], [u128; N]) { let mulout: [Field; 2 * N] = __compute_quadratic_expression_product( params, @@ -232,8 +236,12 @@ pub(crate) unconstrained fn __compute_quadratic_expression( - params: BigNumParams, - lhs_terms: [[[u128; N]; LHS_N]; NUM_PRODUCTS], - lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], - rhs_terms: [[[u128; N]; RHS_N]; NUM_PRODUCTS], - rhs_flags: [[bool; RHS_N]; NUM_PRODUCTS], - linear_terms: [[u128; N]; ADD_N], - linear_flags: [bool; ADD_N], + params: &BigNumParams, + lhs_terms: &[[[u128; N]; LHS_N]; NUM_PRODUCTS], + lhs_flags: &[[bool; LHS_N]; NUM_PRODUCTS], + rhs_terms: &[[[u128; N]; RHS_N]; NUM_PRODUCTS], + rhs_flags: &[[bool; RHS_N]; NUM_PRODUCTS], + linear_terms: &[[u128; N]; ADD_N], + linear_flags: &[bool; ADD_N], ) -> ([[Field; N]; NUM_PRODUCTS], [[Field; N]; NUM_PRODUCTS], [Field; N]) { // lhs linear terms let mut lhs_linear: [[Field; N]; NUM_PRODUCTS] = [[0; N]; NUM_PRODUCTS]; @@ -337,13 +345,13 @@ fn compute_linear_expressions( - params: BigNumParams, - lhs_terms: [[[u128; N]; LHS_N]; NUM_PRODUCTS], - lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], - rhs_terms: [[[u128; N]; RHS_N]; NUM_PRODUCTS], - rhs_flags: [[bool; RHS_N]; NUM_PRODUCTS], - linear_terms: [[u128; N]; ADD_N], - linear_flags: [bool; ADD_N], + params: &BigNumParams, + lhs_terms: &[[[u128; N]; LHS_N]; NUM_PRODUCTS], + lhs_flags: &[[bool; LHS_N]; NUM_PRODUCTS], + rhs_terms: &[[[u128; N]; RHS_N]; NUM_PRODUCTS], + rhs_flags: &[[bool; RHS_N]; NUM_PRODUCTS], + linear_terms: &[[u128; N]; ADD_N], + linear_flags: &[bool; ADD_N], ) -> [Field; 2 * N - 1] { // Safety: use an unconstrained function to compute the value of the quotient and borrow_flags out-of-circuit let (quotient, borrow_flags): ([u128; N], [bool; 2 * N - 2]) = unsafe { @@ -360,7 +368,7 @@ fn compute_quadratic_expression_with_modulus(quotient); + validate_quotient_in_range::("ient); // Compute the linear sums that represent L_i, R_i, A let (lhs_linear, rhs_linear, lin_expr): ([[Field; N]; NUM_PRODUCTS], [[Field; N]; NUM_PRODUCTS], [Field; N]) = compute_linear_expressions::( @@ -631,24 +639,24 @@ fn validate_expression_is_zero(mut limbs: [Field; N]) { /// requires canonical outputs in `[0, MOD)` must additionally enforce a /// range check (for example via `validate_in_field`) on the relevant terms. pub(crate) fn evaluate_quadratic_expression( - params: BigNumParams, - lhs_terms: [[[u128; N]; LHS_N]; NUM_PRODUCTS], - lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], - rhs_terms: [[[u128; N]; RHS_N]; NUM_PRODUCTS], - rhs_flags: [[bool; RHS_N]; NUM_PRODUCTS], - linear_terms: [[u128; N]; ADD_N], - linear_flags: [bool; ADD_N], + params: &BigNumParams, + lhs_terms: &[[[u128; N]; LHS_N]; NUM_PRODUCTS], + lhs_flags: &[[bool; LHS_N]; NUM_PRODUCTS], + rhs_terms: &[[[u128; N]; RHS_N]; NUM_PRODUCTS], + rhs_flags: &[[bool; RHS_N]; NUM_PRODUCTS], + linear_terms: &[[u128; N]; ADD_N], + linear_flags: &[bool; ADD_N], ) { assert(NUM_PRODUCTS < 64, f"evaluate_quadratic_expression overflow in operands count"); // NUM_PRODUCTS < 64 is a light bound that tries to ensure each limb sum < 2^{246} so that the 126-bit bound is valid. lhs_terms.for_each(|lhs_limbs: [[u128; N]; LHS_N]| { - lhs_limbs.for_each(|term: [u128; N]| validate_in_range::(term)) + lhs_limbs.for_each(|term: [u128; N]| validate_in_range::(&term)) }); rhs_terms.for_each(|rhs_limbs: [[u128; N]; RHS_N]| { - rhs_limbs.for_each(|term: [u128; N]| validate_in_range::(term)) + rhs_limbs.for_each(|term: [u128; N]| validate_in_range::(&term)) }); - linear_terms.for_each(|term: [u128; N]| validate_in_range::(term)); + linear_terms.for_each(|term: [u128; N]| validate_in_range::(&term)); let expression_limbs: [Field; 2 * N - 1] = compute_quadratic_expression_with_modulus::( params, @@ -687,10 +695,10 @@ pub(crate) fn evaluate_quadratic_expression= 64 (middle limb will have 64 additions). And it is a pure completeness issue /// But the rest of the library will probably not work with that massive number anyway unconstrained fn __compute_udiv_mod_expression_with_borrow_flags( - numerator: [u128; N], - divisor: [u128; N], - quotient: [u128; N], - remainder: [u128; N], + numerator: &[u128; N], + divisor: &[u128; N], + quotient: &[u128; N], + remainder: &[u128; N], ) -> [bool; N - 1] { let mut product_limbs: [Field; N] = [0; N]; let mut numerator_field: [Field; N] = [0; N]; @@ -731,10 +739,10 @@ unconstrained fn __compute_udiv_mod_expression_with_borrow_flags( - numerator: [u128; N], - divisor: [u128; N], - quotient: [u128; N], - remainder: [u128; N], + numerator: &[u128; N], + divisor: &[u128; N], + quotient: &[u128; N], + remainder: &[u128; N], ) -> [Field; N] { // Safety: use an unconstrained function to compute the value of the quotient and borrow_flags out-of-circuit let borrow_flags: [bool; N - 1] = unsafe { @@ -793,10 +801,10 @@ fn compute_udiv_mod_expression( /// /// then we can set quotient' = quotient - 1, remainder' = remainder + divisor pub(crate) fn validate_udiv_mod_expression( - numerator: [u128; N], - divisor: [u128; N], - quotient: [u128; N], - remainder: [u128; N], + numerator: &[u128; N], + divisor: &[u128; N], + quotient: &[u128; N], + remainder: &[u128; N], ) { validate_in_range::(numerator); validate_in_range::(divisor); diff --git a/src/fns/serialization.nr b/src/fns/serialization.nr index c1a85b1e..d13f842b 100644 --- a/src/fns/serialization.nr +++ b/src/fns/serialization.nr @@ -20,7 +20,7 @@ use crate::utils::map::invert_array; /// In principle, accumulating `u8` values already bounds the integer, /// but relying on Noir to infer a `u128` from a large linear combination /// would trigger a very general (and expensive) range checks -pub fn from_be_bytes(x: [u8; (MOD_BITS + 7) / 8]) -> [u128; N] { +pub fn from_be_bytes(x: &[u8; (MOD_BITS + 7) / 8]) -> [u128; N] { let num_bytes: u32 = (MOD_BITS + 7) / 8; let mut result: [u128; N] = [0; N]; @@ -65,7 +65,7 @@ pub fn from_be_bytes(x: [u8; (MOD_BITS + 7) / 8]) /// Consistency between `N` and `MOD_BITS` is expected: /// - the most significant limb contributes `((MOD_BITS + 7) / 8) - (N - 1) * 15` bytes; /// - all other limbs are serialized as full 15-byte chunks. -pub fn to_be_bytes(val: [u128; N]) -> [u8; (MOD_BITS + 7) / 8] { +pub fn to_be_bytes(val: &[u128; N]) -> [u8; (MOD_BITS + 7) / 8] { let mut result: [u8; (MOD_BITS + 7) / 8] = [0; (MOD_BITS + 7) / 8]; let last_limb_num_bytes: u32 = (MOD_BITS + 7) / 8 - (N - 1) * 15; @@ -94,9 +94,9 @@ pub fn to_be_bytes(val: [u128; N]) -> [u8; (MOD_B /// Reverse an array and apply `from_be_bytes` /// /// See `from_be_bytes` for details -pub fn from_le_bytes(x: [u8; (MOD_BITS + 7) / 8]) -> [u128; N] { +pub fn from_le_bytes(x: &[u8; (MOD_BITS + 7) / 8]) -> [u128; N] { let be_x: [u8; (MOD_BITS + 7) / 8] = invert_array(x); - from_be_bytes(be_x) + from_be_bytes(&be_x) } /// Construct a little-endian byte array from a `BigNum` value @@ -104,7 +104,7 @@ pub fn from_le_bytes(x: [u8; (MOD_BITS + 7) / 8]) /// Apply `to_be_bytes` and reverse an array /// /// See `to_be_bytes` for details -pub fn to_le_bytes(val: [u128; N]) -> [u8; (MOD_BITS + 7) / 8] { +pub fn to_le_bytes(val: &[u128; N]) -> [u8; (MOD_BITS + 7) / 8] { let result_be: [u8; (MOD_BITS + 7) / 8] = to_be_bytes(val); - invert_array(result_be) + invert_array(&result_be) } diff --git a/src/fns/unconstrained_helpers.nr b/src/fns/unconstrained_helpers.nr index 7961797b..68515d31 100644 --- a/src/fns/unconstrained_helpers.nr +++ b/src/fns/unconstrained_helpers.nr @@ -62,16 +62,16 @@ pub(crate) unconstrained fn __from_field(val: Field) -> [u128; N] { /// - borrow_flags: borrows from subtracting modulus when overflow occurs /// - overflow: true if lhs + rhs >= modulus (so we need to subtract modulus) pub(crate) unconstrained fn compute_add_flags( - modulus: [u128; N], - lhs: [u128; N], - rhs: [u128; N], + modulus: &[u128; N], + lhs: &[u128; N], + rhs: &[u128; N], ) -> ([bool; N - 1], [bool; N - 1], bool) { let mask: u128 = TWO_POW_120 - 1; let add_res: [u128; N] = __helper_add(lhs, rhs); - let overflow: bool = __gte(add_res, modulus); + let overflow: bool = __gte(&add_res, modulus); - let subtrahend: [u128; N] = if overflow { modulus } else { [0; N] }; + let subtrahend: [u128; N] = if overflow { *modulus } else { [0; N] }; let mut borrow_flags: [bool; N - 1] = [false; N - 1]; let mut carry_flags: [bool; N - 1] = [false; N - 1]; @@ -107,15 +107,15 @@ pub(crate) unconstrained fn compute_add_flags( /// - borrow_flags: borrows from subtracting rhs /// - underflow: true if lhs < rhs (so we need to add modulus) pub(crate) unconstrained fn compute_sub_flags( - modulus: [u128; N], - lhs: [u128; N], - rhs: [u128; N], + modulus: &[u128; N], + lhs: &[u128; N], + rhs: &[u128; N], ) -> ([bool; N - 1], [bool; N - 1], bool) { let mask: u128 = TWO_POW_120 - 1; let underflow: bool = !__gte(lhs, rhs); - let addend: [u128; N] = if underflow { modulus } else { [0; N] }; + let addend: [u128; N] = if underflow { *modulus } else { [0; N] }; let mut borrow_flags: [bool; N - 1] = [false; N - 1]; let mut carry_flags: [bool; N - 1] = [false; N - 1]; @@ -144,8 +144,8 @@ pub(crate) unconstrained fn compute_sub_flags( /// /// The result is computed in constrained code using `compute_gte_result` pub(crate) unconstrained fn compute_borrow_flags( - lhs: [u128; N], - rhs: [u128; N], + lhs: &[u128; N], + rhs: &[u128; N], ) -> [bool; N - 1] { let mut borrow_flags: [bool; N - 1] = [false; N - 1]; borrow_flags[0] = lhs[0] < rhs[0]; @@ -161,14 +161,18 @@ pub(crate) unconstrained fn compute_borrow_flags( /// - underflow is true if lhs < rhs /// - borrow_flags correspond to max(lhs, rhs) - min(lhs, rhs) pub(crate) unconstrained fn compute_gte_flags( - lhs: [u128; N], - rhs: [u128; N], + lhs: &[u128; N], + rhs: &[u128; N], ) -> (bool, [bool; N - 1]) { let underflow: bool = !__gte(lhs, rhs); // swap if underflow so we're computing borrow flags for larger - smaller - let (a, b): ([u128; N], [u128; N]) = if underflow { (rhs, lhs) } else { (lhs, rhs) }; + let (a, b): ([u128; N], [u128; N]) = if underflow { + (*rhs, *lhs) + } else { + (*lhs, *rhs) + }; - let borrow_flags = compute_borrow_flags(a, b); + let borrow_flags = compute_borrow_flags(&a, &b); (underflow, borrow_flags) } @@ -235,10 +239,10 @@ global BARRETT_REDUCTION_OVERFLOW_BITS: u32 = 6; /// (for example (a1 + b1) * (c1 + d1) + ... 64 times) /// This is highly unlikely though, but there should be more reductions in that case. pub(crate) unconstrained fn __barrett_reduction( - x: [u128; 2 * N], - redc_param: [u128; N], + x: &[u128; 2 * N], + redc_param: &[u128; N], k: u32, - modulus: [u128; N], + modulus: &[u128; N], ) -> ([u128; N], [u128; N]) { // TODO: switch to __helper_mul, once the compiler is smart enough to handle this let mut mulout_field: [Field; 3 * N] = [0; 3 * N]; @@ -249,7 +253,7 @@ pub(crate) unconstrained fn __barrett_reduction( } let mulout: [u128; 3 * N] = __normalize_limbs(mulout_field); - let quotient: [u128; 3 * N] = __shr(mulout, (k + k + BARRETT_REDUCTION_OVERFLOW_BITS)); + let quotient: [u128; 3 * N] = __shr(&mulout, (k + k + BARRETT_REDUCTION_OVERFLOW_BITS)); // Remove a bunch of zeros from the end let mut smaller_quotient: [u128; N] = [0; N]; @@ -258,8 +262,8 @@ pub(crate) unconstrained fn __barrett_reduction( } // long_quotient_mul_modulus can never exceed input value `x` so can fit into size-2 array - let long_quotient_mul_modulus: [u128; 2 * N] = __helper_mul(smaller_quotient, modulus); - let long_remainder: [u128; 2 * N] = __helper_sub(x, long_quotient_mul_modulus); + let long_quotient_mul_modulus: [u128; 2 * N] = __helper_mul(&smaller_quotient, modulus); + let long_remainder: [u128; 2 * N] = __helper_sub(x, &long_quotient_mul_modulus); // Remove a bunch of zeros from the end let mut remainder: [u128; N] = [0; N]; @@ -267,9 +271,9 @@ pub(crate) unconstrained fn __barrett_reduction( remainder[i] = long_remainder[i]; } - if (__gte(remainder, modulus)) { - remainder = __helper_sub(remainder, modulus); - smaller_quotient = __increment(smaller_quotient); + if (__gte(&remainder, modulus)) { + remainder = __helper_sub(&remainder, modulus); + smaller_quotient = __increment(&smaller_quotient); } (smaller_quotient, remainder) @@ -283,7 +287,7 @@ pub(crate) unconstrained fn __barrett_reduction( /// ## Note /// The `carry` must be `0` at the end of the loop. /// No explicit assertion is made, as this condition is validated during evaluation. -pub(crate) unconstrained fn __increment(val: [u128; N]) -> [u128; N] { +pub(crate) unconstrained fn __increment(val: &[u128; N]) -> [u128; N] { let mask: u128 = TWO_POW_120 - 1; let mut result: [u128; N] = [0; N]; @@ -301,7 +305,10 @@ pub(crate) unconstrained fn __increment(val: [u128; N]) -> [u128; N] /// ## Note /// The `carry` must be `0` at the end of the loop. /// No explicit assertion is made, as this condition is validated during evaluation. -pub(crate) unconstrained fn __helper_add(lhs: [u128; N], rhs: [u128; N]) -> [u128; N] { +pub(crate) unconstrained fn __helper_add( + lhs: &[u128; N], + rhs: &[u128; N], +) -> [u128; N] { let mut result: [u128; N] = [0; N]; let mut carry: u128 = 0; @@ -320,7 +327,10 @@ pub(crate) unconstrained fn __helper_add(lhs: [u128; N], rhs: [u128; /// ## Note /// The `borrow` must be `0` at the end of the loop. /// No explicit assertion is made, as this condition is validated during evaluation. -pub(crate) unconstrained fn __helper_sub(lhs: [u128; N], rhs: [u128; N]) -> [u128; N] { +pub(crate) unconstrained fn __helper_sub( + lhs: &[u128; N], + rhs: &[u128; N], +) -> [u128; N] { let mut result: [u128; N] = [0; N]; let mut borrow: u128 = 0; @@ -341,8 +351,8 @@ pub(crate) unconstrained fn __helper_sub(lhs: [u128; N], rhs: [u128; /// limbs intentionally as the extra high limb safely absorbs a possible single limb overflow /// for moduli close to `120 * N` bits. pub(crate) unconstrained fn __helper_mul( - lhs: [u128; N], - rhs: [u128; N], + lhs: &[u128; N], + rhs: &[u128; N], ) -> [u128; 2 * N] { let mut result: [Field; 2 * N] = [0; 2 * N]; for i in 0..N { @@ -362,11 +372,11 @@ pub(crate) unconstrained fn __helper_mul( /// ## Note /// - `MOD` must be odd. pub(crate) unconstrained fn __half_mod_odd( - modulus: [u128; N], - x: [u128; N], + modulus: &[u128; N], + x: &[u128; N], ) -> [u128; N] { let temp = if __is_even::(x) { - x + *x } else { __helper_add(x, modulus) }; @@ -386,7 +396,7 @@ pub(crate) unconstrained fn __half_mod_odd( /// No bounds check is performed on `num_shifted_limbs`. /// However, we use it only in `__udiv_mod`, where it is not possible to reach /// `num_shifted_limbs` > `N` -pub(crate) unconstrained fn __shl(input: [u128; N], shift: u32) -> [u128; N] { +pub(crate) unconstrained fn __shl(input: &[u128; N], shift: u32) -> [u128; N] { let mut result: [u128; N] = [0; N]; let num_shifted_limbs: u32 = shift / 120; @@ -415,7 +425,7 @@ pub(crate) unconstrained fn __shl(input: [u128; N], shift: u32) -> [ /// No bounds check is performed on `num_shifted_limbs`. /// However, we use it only in `__tonelli_shanks_sqrt`, where it is not possible to reach /// `num_shifted_limbs` > `N` -pub(crate) unconstrained fn __shr(input: [u128; N], shift: u32) -> [u128; N] { +pub(crate) unconstrained fn __shr(input: &[u128; N], shift: u32) -> [u128; N] { let mut result: [u128; N] = [0; N]; let num_shifted_limbs: u32 = shift / 120; @@ -455,7 +465,7 @@ pub(crate) unconstrained fn __shr1(mut input: [u128; N]) -> [u128; N } /// Returns the index of the most significant set bit in a `BigNum` value (unconstrained). -pub(crate) unconstrained fn __get_msb(val: [u128; N]) -> u32 { +pub(crate) unconstrained fn __get_msb(val: &[u128; N]) -> u32 { let mut count: u32 = 0; for i in 0..N { let idx: u32 = N - 1 - i; @@ -472,7 +482,7 @@ pub(crate) unconstrained fn __get_msb(val: [u128; N]) -> u32 { /// /// ## Note /// No bounds check is performed on `bit` -pub(crate) fn __get_bit(input: [u128; N], bit: u32) -> bool { +pub(crate) fn __get_bit(input: &[u128; N], bit: u32) -> bool { let segment_index: u32 = bit / 120; let uint_index: u128 = (bit % 120) as u128; @@ -482,7 +492,7 @@ pub(crate) fn __get_bit(input: [u128; N], bit: u32) -> bool { } /// Returns `true` if the `BigNum` value is even (unconstrained) -pub(crate) unconstrained fn __is_even(x: [u128; N]) -> bool { +pub(crate) unconstrained fn __is_even(x: &[u128; N]) -> bool { (x[0] & 1) == 0 } @@ -493,10 +503,10 @@ pub(crate) unconstrained fn __is_even(x: [u128; N]) -> bool { /// /// Find the maximum value s such that `MOD = 2^s * q + 1`, where `q` is odd /// This is needed for our Tonelli-Shanks sqrt algorithm -pub(crate) unconstrained fn __primitive_root_log_size(modulus: [u128; N]) -> u32 { - let target: [u128; N] = __helper_sub(modulus, __one()); +pub(crate) unconstrained fn __primitive_root_log_size(modulus: &[u128; N]) -> u32 { + let target: [u128; N] = __helper_sub(modulus, &__one()); let mut result: u32 = 0; - while !__get_bit(target, result) { + while !__get_bit(&target, result) { result += 1; } result @@ -509,21 +519,21 @@ pub(crate) unconstrained fn __primitive_root_log_size(modulus: [u128 /// ## Note /// WARNING If the field is not prime, this function will enter an infinite loop! pub(crate) unconstrained fn __quadratic_non_residue( - params: BigNumParams, + params: &BigNumParams, ) -> [u128; N] { let one: [u128; N] = __one(); - let neg_one: [u128; N] = __neg(params.modulus, one); + let neg_one: [u128; N] = __neg(¶ms.modulus, &one); - let p_minus_one_over_two: [u128; N] = __shr1(__helper_sub(params.modulus, __one())); + let p_minus_one_over_two: [u128; N] = __shr1(__helper_sub(¶ms.modulus, &__one())); // We start with 2 let mut target: [u128; N] = [0; N]; target[0] = 2; - let mut expd: [u128; N] = __pow(params, target, p_minus_one_over_two); - while !__eq(expd, neg_one) { - target = __increment(target); - expd = __pow(params, target, p_minus_one_over_two); + let mut expd: [u128; N] = __pow(params, &target, &p_minus_one_over_two); + while !__eq(&expd, &neg_one) { + target = __increment(&target); + expd = __pow(params, &target, &p_minus_one_over_two); } target } @@ -533,16 +543,16 @@ pub(crate) unconstrained fn __quadratic_non_residue( - params: BigNumParams, - t: [u128; N], + params: &BigNumParams, + t: &[u128; N], ) -> u32 { let one: [u128; N] = __one(); - let mut c: [u128; N] = t; + let mut c: [u128; N] = *t; let mut i: u32 = 0; // Compute t^{2^k} until it hits 1 for the first time - while !__eq(c, one) { - c = __sqr::(params, c); + while !__eq(&c, &one) { + c = __sqr::(params, &c); i += 1; } i diff --git a/src/fns/unconstrained_ops.nr b/src/fns/unconstrained_ops.nr index e73e5946..a4075640 100644 --- a/src/fns/unconstrained_ops.nr +++ b/src/fns/unconstrained_ops.nr @@ -20,7 +20,7 @@ use crate::params::BigNumParams; /// /// See more information in `constrained_ops.nr`: `derive_from_seed` pub unconstrained fn __derive_from_seed( - params: BigNumParams, + params: &BigNumParams, seed: [u8; SeedBytes], ) -> [u128; N] { derive_from_seed::(params, seed) @@ -29,20 +29,20 @@ pub unconstrained fn __derive_from_seed(lhs: [u128; N], rhs: [u128; N]) -> bool { - lhs == rhs +pub unconstrained fn __eq(lhs: &[u128; N], rhs: &[u128; N]) -> bool { + *lhs == *rhs } /// Compare a limb array to a zero array (unconstrained) -pub unconstrained fn __is_zero(limbs: [u128; N]) -> bool { - limbs == [0; N] +pub unconstrained fn __is_zero(limbs: &[u128; N]) -> bool { + *limbs == [0; N] } /// Compare two little-endian limb arrays for `lhs >= rhs` over integers (unconstrained) /// /// Starts from the most significant limb (`N - 1`) and returns true /// if `lhs` is greater or equal to `rhs` -pub(crate) unconstrained fn __gte(lhs: [u128; N], rhs: [u128; N]) -> bool { +pub(crate) unconstrained fn __gte(lhs: &[u128; N], rhs: &[u128; N]) -> bool { let mut result: bool = true; for i in 0..N { let idx: u32 = N - 1 - i; @@ -60,7 +60,7 @@ pub(crate) unconstrained fn __gte(lhs: [u128; N], rhs: [u128; N]) -> /// /// ## Note /// The input is assumed to be less than modulus -pub unconstrained fn __neg(modulus: [u128; N], limbs: [u128; N]) -> [u128; N] { +pub unconstrained fn __neg(modulus: &[u128; N], limbs: &[u128; N]) -> [u128; N] { __helper_sub(modulus, limbs) } @@ -73,9 +73,9 @@ pub unconstrained fn __neg(modulus: [u128; N], limbs: [u128; N]) -> /// The `carry` must be `0` at the end of the loop. /// No explicit assertion is made, as this condition is validated during evaluation. pub unconstrained fn __add( - modulus: [u128; N], - lhs: [u128; N], - rhs: [u128; N], + modulus: &[u128; N], + lhs: &[u128; N], + rhs: &[u128; N], ) -> [u128; N] { let mut result: [u128; N] = [0; N]; let mut carry: u128 = 0; @@ -88,8 +88,8 @@ pub unconstrained fn __add( } // check if the result is greater than the modulus - if __gte(result, modulus) { - __helper_sub(result, modulus) + if __gte(&result, modulus) { + __helper_sub(&result, modulus) } else { result } @@ -99,11 +99,11 @@ pub unconstrained fn __add( /// /// Computes `x + (m - y)` (mod m) pub unconstrained fn __sub( - modulus: [u128; N], - lhs: [u128; N], - rhs: [u128; N], + modulus: &[u128; N], + lhs: &[u128; N], + rhs: &[u128; N], ) -> [u128; N] { - __add(modulus, lhs, __neg(modulus, rhs)) + __add(modulus, lhs, &__neg(modulus, rhs)) } /// Multiply `x` and `y` and reduce via Barrett, returning (Q, R) (unconstrained). @@ -112,29 +112,29 @@ pub unconstrained fn __sub( /// x * y = R + Q * m, 0 <= R < m /// See `__barrett_reduction` for details. pub(crate) unconstrained fn __mul_with_quotient( - params: BigNumParams, - lhs: [u128; N], - rhs: [u128; N], + params: &BigNumParams, + lhs: &[u128; N], + rhs: &[u128; N], ) -> ([u128; N], [u128; N]) { let to_reduce: [u128; N * 2] = __helper_mul(lhs, rhs); let (q, r): ([u128; N], [u128; N]) = - __barrett_reduction(to_reduce, params.redc_param, MOD_BITS, params.modulus); + __barrett_reduction(&to_reduce, ¶ms.redc_param, MOD_BITS, ¶ms.modulus); (q, r) } /// Multiplies two `BigNum` values with modular reduction (unconstrained). pub unconstrained fn __mul( - params: BigNumParams, - lhs: [u128; N], - rhs: [u128; N], + params: &BigNumParams, + lhs: &[u128; N], + rhs: &[u128; N], ) -> [u128; N] { __mul_with_quotient::(params, lhs, rhs).1 } /// Squares a `BigNum` value with modular reduction (unconstrained). pub unconstrained fn __sqr( - params: BigNumParams, - val: [u128; N], + params: &BigNumParams, + val: &[u128; N], ) -> [u128; N] { __mul_with_quotient::(params, val, val).1 } @@ -147,19 +147,19 @@ pub unconstrained fn __sqr( /// For the loop, we are using `MOD_BITS` instead of `__get_msb` /// because it is much much cheaper pub unconstrained fn __pow( - params: BigNumParams, - val: [u128; N], - exponent: [u128; N], + params: &BigNumParams, + val: &[u128; N], + exponent: &[u128; N], ) -> [u128; N] { let mut accumulator: [u128; N] = __one(); - let mut x: [u128; N] = val; + let mut x: [u128; N] = *val; let num_bits: u32 = MOD_BITS + 1; for i in 0..num_bits { if __get_bit(exponent, i) { - accumulator = __mul::(params, accumulator, x); + accumulator = __mul::(params, &accumulator, &x); } - x = __sqr::(params, x); + x = __sqr::(params, &x); } accumulator } @@ -174,14 +174,14 @@ pub unconstrained fn __pow( /// - `modulus` must be odd (required by `__half_mod_odd`) /// - `gcd(val, modulus) == 1` (i.e. `val` is invertible modulo `modulus`) pub unconstrained fn __invmod( - params: BigNumParams, - val: [u128; N], + params: &BigNumParams, + val: &[u128; N], ) -> [u128; N] { assert(params.has_multiplicative_inverse); - if (__is_zero(val) | __eq(val, params.modulus)) { + if (__is_zero(val) | __eq(val, ¶ms.modulus)) { [0; N] } else { - let mut u: [u128; N] = val; + let mut u: [u128; N] = *val; let mut v: [u128; N] = params.modulus; let one: [u128; N] = __one::(); @@ -189,30 +189,30 @@ pub unconstrained fn __invmod( let mut x1: [u128; N] = one; let mut x2: [u128; N] = [0; N]; - while (!__eq(u, one) & !__eq(v, one)) { + while (!__eq(&u, &one) & !__eq(&v, &one)) { // Get rid of the even part of u - while __is_even(u) { + while __is_even(&u) { u = __shr1(u); - x1 = __half_mod_odd(params.modulus, x1); + x1 = __half_mod_odd(¶ms.modulus, &x1); } // Get rid of the even part of v - while __is_even(v) { + while __is_even(&v) { v = __shr1(v); - x2 = __half_mod_odd(params.modulus, x2); + x2 = __half_mod_odd(¶ms.modulus, &x2); } // Update the intermediate values after both (u, v) are odd - if __gte(u, v) { - u = __helper_sub(u, v); - x1 = __sub(params.modulus, x1, x2); + if __gte(&u, &v) { + u = __helper_sub(&u, &v); + x1 = __sub(¶ms.modulus, &x1, &x2); } else { - v = __helper_sub(v, u); - x2 = __sub(params.modulus, x2, x1); + v = __helper_sub(&v, &u); + x2 = __sub(¶ms.modulus, &x2, &x1); } } - if __eq(u, one) { + if __eq(&u, &one) { x1 } else { x2 @@ -228,12 +228,12 @@ pub unconstrained fn __invmod( /// The divisor must be nonzero /// No explicit assertion is made, as this condition is validated during evaluation pub unconstrained fn __div( - params: BigNumParams, - numerator: [u128; N], - divisor: [u128; N], + params: &BigNumParams, + numerator: &[u128; N], + divisor: &[u128; N], ) -> [u128; N] { let inv_divisor: [u128; N] = __invmod::(params, divisor); - __mul::(params, numerator, inv_divisor) + __mul::(params, numerator, &inv_divisor) } /// Given the `BigNum` inputs `x, y`, compute integer division x / y (unconstrained) @@ -247,39 +247,39 @@ pub unconstrained fn __div( /// The divisor must be nonzero /// No explicit assertion is made, as this condition is validated during evaluation pub unconstrained fn __udiv_mod( - numerator: [u128; N], - divisor: [u128; N], + numerator: &[u128; N], + divisor: &[u128; N], ) -> ([u128; N], [u128; N]) { let mut quotient: [u128; N] = [0; N]; - let mut remainder: [u128; N] = numerator; - let b: [u128; N] = divisor; + let mut remainder: [u128; N] = *numerator; + let b: [u128; N] = *divisor; let numerator_msb: u32 = __get_msb(numerator); let divisor_msb: u32 = __get_msb(divisor); if divisor_msb > numerator_msb { - ([0; N], numerator) + ([0; N], *numerator) } else { - let bit_difference: u32 = __get_msb(remainder) - __get_msb(divisor); + let bit_difference: u32 = __get_msb(&remainder) - __get_msb(divisor); let mut divisor: [u128; N] = __shl(divisor, bit_difference); - let mut accumulator: [u128; N] = __shl(__one(), bit_difference); + let mut accumulator: [u128; N] = __shl(&__one(), bit_difference); // The same as divisor > remainder - if (__gte(divisor, __increment(remainder))) { + if (__gte(&divisor, &__increment(&remainder))) { divisor = __shr1(divisor); accumulator = __shr1(accumulator); } for _ in 0..(N * 120) { - if (__gte(remainder, b) == false) { + if (__gte(&remainder, &b) == false) { break; } // we've shunted 'divisor' up to have the same bit length as our remainder. // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b - if (__gte(remainder, divisor)) { - remainder = __helper_sub(remainder, divisor); + if (__gte(&remainder, &divisor)) { + remainder = __helper_sub(&remainder, &divisor); // we can use OR here instead of +, as // accumulator is always a nice power of two - quotient = __helper_add(quotient, accumulator); + quotient = __helper_add("ient, &accumulator); } divisor = __shr1(divisor); accumulator = __shr1(accumulator); @@ -322,26 +322,26 @@ pub unconstrained fn __udiv_mod( /// This edge case should be rare, but it's worth keeping in mind when /// composing operations or debugging unexpected behavior pub(crate) unconstrained fn batch_invert( - params: BigNumParams, - vals: [[u128; N]; M], + params: &BigNumParams, + vals: &[[u128; N]; M], ) -> [[u128; N]; M] { let mut accumulator: [u128; N] = __one(); let mut temporaries: [[u128; N]; M] = [[0; N]; M]; for i in 0..M { temporaries[i] = accumulator; - if (!__is_zero(vals[i])) { - accumulator = __mul::(params, accumulator, vals[i]); + if (!__is_zero(&vals[i])) { + accumulator = __mul::(params, &accumulator, &vals[i]); } } let mut result: [[u128; N]; M] = [[0; N]; M]; - accumulator = __invmod::(params, accumulator); + accumulator = __invmod::(params, &accumulator); for i in 0..M { let idx: u32 = M - 1 - i; - if (!__is_zero(vals[idx])) { - let T0: [u128; N] = __mul::(params, accumulator, temporaries[idx]); - accumulator = __mul::(params, accumulator, vals[idx]); + if (!__is_zero(&vals[idx])) { + let T0: [u128; N] = __mul::(params, &accumulator, &temporaries[idx]); + accumulator = __mul::(params, &accumulator, &vals[idx]); result[idx] = T0; } } @@ -352,7 +352,7 @@ pub(crate) unconstrained fn batch_invert( - params: BigNumParams, + params: &BigNumParams, vals: [[u128; N]], ) -> [[u128; N]] { let mut accumulator: [u128; N] = __one(); @@ -362,18 +362,18 @@ pub(crate) unconstrained fn batch_invert_slice( for i in 0..M { temporaries = temporaries.push_back(accumulator); - if (!__is_zero(vals[i])) { - accumulator = __mul::(params, accumulator, vals[i]); + if (!__is_zero(&vals[i])) { + accumulator = __mul::(params, &accumulator, &vals[i]); } } let mut result: [[u128; N]] = []; - accumulator = __invmod::(params, accumulator); + accumulator = __invmod::(params, &accumulator); for i in 0..M { let idx: u32 = M - 1 - i; - if (!__is_zero(vals[idx])) { - let T0: [u128; N] = __mul::(params, accumulator, temporaries[idx]); - accumulator = __mul::(params, accumulator, vals[idx]); + if (!__is_zero(&vals[idx])) { + let T0: [u128; N] = __mul::(params, &accumulator, &temporaries[idx]); + accumulator = __mul::(params, &accumulator, &vals[idx]); result = result.push_front(T0); } else { result = result.push_front([0; N]); @@ -385,8 +385,8 @@ pub(crate) unconstrained fn batch_invert_slice( /// Compute a modular square root in a prime field (unconstrained) pub unconstrained fn __sqrt( - params: BigNumParams, - input: [u128; N], + params: &BigNumParams, + input: &[u128; N], ) -> std::option::Option<[u128; N]> { assert( params.has_multiplicative_inverse, @@ -394,7 +394,7 @@ pub unconstrained fn __sqrt( ); if (__is_zero(input)) { - Option::some(input) + Option::some(*input) } else if (params.modulus[0] % 4 == 3) { __easy_sqrt(params, input) } else { @@ -446,15 +446,15 @@ pub unconstrained fn __sqrt( /// /// The input is assumed to be a nonzero value pub(crate) unconstrained fn __tonelli_shanks_sqrt( - params: BigNumParams, - input: [u128; N], + params: &BigNumParams, + input: &[u128; N], ) -> std::option::Option<[u128; N]> { let mut result: Option<[u128; N]> = Option::none(); let one: [u128; N] = __one(); - let s: u32 = __primitive_root_log_size(params.modulus); // p - 1 = 2^s * Q, where Q is odd - let Q: [u128; N] = __shr(__helper_sub(params.modulus, one), s); - let Q_minus_one_over_two: [u128; N] = __shr1(__helper_sub(Q, one)); // (Q - 1) / 2 + let s: u32 = __primitive_root_log_size(¶ms.modulus); // p - 1 = 2^s * Q, where Q is odd + let Q: [u128; N] = __shr(&__helper_sub(¶ms.modulus, &one), s); + let Q_minus_one_over_two: [u128; N] = __shr1(__helper_sub(&Q, &one)); // (Q - 1) / 2 let z: [u128; N] = __quadratic_non_residue::(params); @@ -462,18 +462,18 @@ pub(crate) unconstrained fn __tonelli_shanks_sqrt // b = a^{(Q - 1)/2} // R = a * b = a^{(Q + 1) / 2} => R^2 = a * a^Q // t = R * b = a^Q - let mut b: [u128; N] = __pow::(params, input, Q_minus_one_over_two); - let mut r: [u128; N] = __mul::(params, input, b); - let mut t: [u128; N] = __mul::(params, r, b); + let mut b: [u128; N] = __pow::(params, input, &Q_minus_one_over_two); + let mut r: [u128; N] = __mul::(params, input, &b); + let mut t: [u128; N] = __mul::(params, &r, &b); let mut check: [u128; N] = t; // Assure t^{2^{s - 1}} = a^{(p -1)/2} = 1, otherwise we have met a non-residue for _ in 0..s - 1 { - check = __sqr::(params, check); + check = __sqr::(params, &check); } - if (__eq(check, one)) { + if (__eq(&check, &one)) { let mut m: u32 = s; - let mut c: [u128; N] = __pow::(params, z, Q); // z^Q - proper 2^{M}'th root of unity + let mut c: [u128; N] = __pow::(params, &z, &Q); // z^Q - proper 2^{M}'th root of unity // Tonelli-Shanks main loop @@ -499,21 +499,21 @@ pub(crate) unconstrained fn __tonelli_shanks_sqrt // // The loop runs at most s times because M strictly decreases for _ in 0..s { - if (__eq(t, one)) { + if (__eq(&t, &one)) { result = Option::some(r); break; } - let i: u32 = __tonelli_shanks_sqrt_find_i::(params, t); + let i: u32 = __tonelli_shanks_sqrt_find_i::(params, &t); let j: u32 = m - i - 1; b = c; for _ in 0..j { - b = __sqr(params, b); + b = __sqr(params, &b); } - let b2: [u128; N] = __sqr::(params, b); + let b2: [u128; N] = __sqr::(params, &b); c = b2; - t = __mul::(params, t, b2); - r = __mul::(params, r, b); + t = __mul::(params, &t, &b2); + r = __mul::(params, &r, &b); m = i; } } @@ -532,19 +532,19 @@ pub(crate) unconstrained fn __tonelli_shanks_sqrt /// /// This is much cheaper than `__tonelli_shanks_sqrt` pub(crate) unconstrained fn __easy_sqrt( - params: BigNumParams, - input: [u128; N], + params: &BigNumParams, + input: &[u128; N], ) -> std::option::Option<[u128; N]> { let mut result: Option<[u128; N]> = Option::none(); let one: [u128; N] = __one(); - let p_minus_one_over_two: [u128; N] = __shr1(__helper_sub(params.modulus, one)); - let check: [u128; N] = __pow(params, input, p_minus_one_over_two); - if (__eq(check, one)) { + let p_minus_one_over_two: [u128; N] = __shr1(__helper_sub(¶ms.modulus, &one)); + let check: [u128; N] = __pow(params, input, &p_minus_one_over_two); + if (__eq(&check, &one)) { // a = (MOD - 1) / 2 // b = (a + 1) / 2 = ((MOD - 1) / 2 + 1) / 2 = (MOD + 1) / 4 - let p_plus_one_over_four: [u128; N] = __shr1(__increment(p_minus_one_over_two)); - result = Option::some(__pow(params, input, p_plus_one_over_four)); + let p_plus_one_over_four: [u128; N] = __shr1(__increment(&p_minus_one_over_two)); + result = Option::some(__pow(params, input, &p_plus_one_over_four)); } result } @@ -609,10 +609,10 @@ mod test_invmod { let b: BN = BN::derive_from_seed(y); // a / b should equal a * invmod(b) - let div_result: BN = a.__div(b); + let div_result: BN = a.__div(&b); let inv_b: BN = b.__invmod(); - let mul_result: BN = a.__mul(inv_b); - assert(b.__mul(inv_b) == BN::one(), "b * invmod(b) should equal 1"); + let mul_result: BN = a.__mul(&inv_b); + assert(b.__mul(&inv_b) == BN::one(), "b * invmod(b) should equal 1"); assert(div_result == mul_result, "division should equal multiply by inverse"); } diff --git a/src/params.nr b/src/params.nr index 4b5634b7..939f085f 100644 --- a/src/params.nr +++ b/src/params.nr @@ -24,7 +24,7 @@ impl BigNumParams { Self { has_multiplicative_inverse, modulus, - double_modulus: get_double_modulus(modulus), + double_modulus: get_double_modulus(&modulus), redc_param, } } @@ -39,7 +39,7 @@ impl std::cmp::Eq for BigNumParams { } } -fn get_double_modulus(modulus: [u128; N]) -> [u128; N] { +fn get_double_modulus(modulus: &[u128; N]) -> [u128; N] { let mut result: [u128; N] = [0; N]; let mut carry: u128 = 0; for i in 0..N { diff --git a/src/runtime_bignum.nr b/src/runtime_bignum.nr index 16c417a5..bf155282 100644 --- a/src/runtime_bignum.nr +++ b/src/runtime_bignum.nr @@ -40,7 +40,7 @@ impl RuntimeBigNum { params: BigNumParams, seed: [u8; SeedBytes], ) -> Self { - let limbs: [u128; N] = derive_from_seed::<_, MOD_BITS, _>(params, seed); + let limbs: [u128; N] = derive_from_seed::<_, MOD_BITS, _>(¶ms, seed); Self { limbs, params } } @@ -49,7 +49,7 @@ impl RuntimeBigNum { params: BigNumParams, seed: [u8; SeedBytes], ) -> Self { - let limbs: [u128; N] = __derive_from_seed::<_, MOD_BITS, _>(params, seed); + let limbs: [u128; N] = __derive_from_seed::<_, MOD_BITS, _>(¶ms, seed); Self { limbs, params } } @@ -62,22 +62,22 @@ impl RuntimeBigNum { } pub fn from_be_bytes(params: BigNumParams, x: [u8; (MOD_BITS + 7) / 8]) -> Self { - Self { limbs: from_be_bytes::(x), params } + Self { limbs: from_be_bytes::(&x), params } } pub fn from_le_bytes(params: BigNumParams, x: [u8; (MOD_BITS + 7) / 8]) -> Self { - Self { limbs: from_le_bytes::(x), params } + Self { limbs: from_le_bytes::(&x), params } } - pub fn to_be_bytes(self) -> [u8; (MOD_BITS + 7) / 8] { - to_be_bytes::(self.limbs) + pub fn to_be_bytes(&self) -> [u8; (MOD_BITS + 7) / 8] { + to_be_bytes::(&self.limbs) } - pub fn to_le_bytes(self) -> [u8; (MOD_BITS + 7) / 8] { - to_le_bytes::(self.limbs) + pub fn to_le_bytes(&self) -> [u8; (MOD_BITS + 7) / 8] { + to_le_bytes::(&self.limbs) } - pub fn modulus(self) -> Self { + pub fn modulus(&self) -> Self { let params: BigNumParams = self.params; Self { limbs: params.modulus, params } } @@ -90,11 +90,11 @@ impl RuntimeBigNum { N } - pub fn get_limbs(self) -> [u128; N] { + pub fn get_limbs(&self) -> [u128; N] { self.limbs } - pub fn get_limb(self, idx: u32) -> u128 { + pub fn get_limb(&self, idx: u32) -> u128 { self.limbs[idx] } @@ -103,165 +103,167 @@ impl RuntimeBigNum { self.limbs[idx] = value; } - pub unconstrained fn __eq(self, other: Self) -> bool { + pub unconstrained fn __eq(&self, other: &Self) -> bool { assert(self.params == other.params); - __eq(self.limbs, other.limbs) + __eq(&self.limbs, &other.limbs) } - pub unconstrained fn __is_zero(self) -> bool { - __is_zero(self.limbs) + pub unconstrained fn __is_zero(&self) -> bool { + __is_zero(&self.limbs) } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __neg(self) -> Self { + pub fn __neg(&self) -> Self { let params = self.params; // Safety: Unconstrained function simulation - let limbs: [u128; N] = unsafe { __neg(params.modulus, self.limbs) }; + let limbs: [u128; N] = unsafe { __neg(¶ms.modulus, &self.limbs) }; Self { params, limbs } } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __add(self, other: Self) -> Self { + pub fn __add(&self, other: &Self) -> Self { let params: BigNumParams = self.params; assert(params == other.params); // Safety: Unconstrained function simulation - let limbs: [u128; N] = unsafe { __add(params.modulus, self.limbs, other.limbs) }; + let limbs: [u128; N] = unsafe { __add(¶ms.modulus, &self.limbs, &other.limbs) }; Self { params, limbs } } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __sub(self, other: Self) -> Self { + pub fn __sub(&self, other: &Self) -> Self { let params = self.params; assert(params == other.params); // Safety: Unconstrained function simulation - let limbs: [u128; N] = unsafe { __sub(params.modulus, self.limbs, other.limbs) }; + let limbs: [u128; N] = unsafe { __sub(¶ms.modulus, &self.limbs, &other.limbs) }; Self { params, limbs } } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __mul(self, other: Self) -> Self { + pub fn __mul(&self, other: &Self) -> Self { let params: BigNumParams = self.params; assert(params == other.params); // Safety: Unconstrained function simulation - let limbs: [u128; N] = unsafe { __mul::<_, MOD_BITS>(params, self.limbs, other.limbs) }; + let limbs: [u128; N] = unsafe { __mul::<_, MOD_BITS>(¶ms, &self.limbs, &other.limbs) }; Self { params, limbs } } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __sqr(self) -> Self { + pub fn __sqr(&self) -> Self { let params: BigNumParams = self.params; // Safety: Unconstrained function simulation - let limbs: [u128; N] = unsafe { __sqr::<_, MOD_BITS>(params, self.limbs) }; + let limbs: [u128; N] = unsafe { __sqr::<_, MOD_BITS>(¶ms, &self.limbs) }; Self { params: params, limbs: limbs } } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __div(self, divisor: Self) -> Self { + pub fn __div(&self, divisor: &Self) -> Self { let params: BigNumParams = self.params; assert(params == divisor.params); // Safety: Unconstrained function simulation - let limbs: [u128; N] = unsafe { __div::<_, MOD_BITS>(params, self.limbs, divisor.limbs) }; + let limbs: [u128; N] = + unsafe { __div::<_, MOD_BITS>(¶ms, &self.limbs, &divisor.limbs) }; Self { params, limbs } } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __udiv_mod(self, divisor: Self) -> (Self, Self) { + pub fn __udiv_mod(&self, divisor: &Self) -> (Self, Self) { let params: BigNumParams = self.params; assert(params == divisor.params); // Safety: Unconstrained function simulation - let (q, r): ([u128; N], [u128; N]) = unsafe { __udiv_mod(self.limbs, divisor.limbs) }; + let (q, r): ([u128; N], [u128; N]) = unsafe { __udiv_mod(&self.limbs, &divisor.limbs) }; (Self { limbs: q, params }, Self { limbs: r, params }) } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __invmod(self) -> Self { + pub fn __invmod(&self) -> Self { let params: BigNumParams = self.params; // Safety: Unconstrained function simulation - let limbs: [u128; N] = unsafe { __invmod::(params, self.limbs) }; + let limbs: [u128; N] = unsafe { __invmod::(¶ms, &self.limbs) }; Self { limbs, params } } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __pow(self, exponent: Self) -> Self { + pub fn __pow(&self, exponent: &Self) -> Self { let params: BigNumParams = self.params; assert(params == exponent.params); // Safety: Unconstrained function simulation - let limbs: [u128; N] = unsafe { __pow::<_, MOD_BITS>(params, self.limbs, exponent.limbs) }; + let limbs: [u128; N] = + unsafe { __pow::<_, MOD_BITS>(¶ms, &self.limbs, &exponent.limbs) }; Self { limbs, params } } // UNCONSTRAINED! (Hence `__` prefix). #[deprecated("use __sqrt")] - pub fn __tonelli_shanks_sqrt(self) -> std::option::Option { + pub fn __tonelli_shanks_sqrt(&self) -> std::option::Option { let params: BigNumParams = self.params; // Safety: out-of-circuit sqrt computation - let maybe_limbs: Option<[u128; N]> = unsafe { __sqrt(params, self.limbs) }; + let maybe_limbs: Option<[u128; N]> = unsafe { __sqrt(¶ms, &self.limbs) }; maybe_limbs.map(|limbs: [u128; N]| Self { limbs, params }) } // UNCONSTRAINED! (Hence `__` prefix). - pub fn __sqrt(self) -> std::option::Option { + pub fn __sqrt(&self) -> std::option::Option { let params = self.params; // Safety: out-of-circuit sqrt computation - let maybe_limbs: Option<[u128; N]> = unsafe { __sqrt(params, self.limbs) }; + let maybe_limbs: Option<[u128; N]> = unsafe { __sqrt(¶ms, &self.limbs) }; maybe_limbs.map(|limbs: [u128; N]| Self { limbs, params }) } - pub fn validate_in_field(self: Self) { + pub fn validate_in_field(&self) { let params = self.params; - validate_in_field::(params, self.limbs); + validate_in_field::(¶ms, &self.limbs); } - pub fn validate_in_range(self) { - validate_in_range::(self.limbs); + pub fn validate_in_range(&self) { + validate_in_range::(&self.limbs); } - pub fn assert_is_not_equal(self, other: Self) { + pub fn assert_is_not_equal(&self, other: &Self) { let params: BigNumParams = self.params; assert(params == other.params); - assert_is_not_equal(params, self.limbs, other.limbs); + assert_is_not_equal(¶ms, &self.limbs, &other.limbs); } - pub fn sqr(self) -> Self { + pub fn sqr(&self) -> Self { let params: BigNumParams = self.params; Self { limbs: sqr(params, self.limbs), params: params } } - pub fn udiv_mod(self, divisor: Self) -> (Self, Self) { + pub fn udiv_mod(&self, divisor: &Self) -> (Self, Self) { let params: BigNumParams = self.params; assert(params == divisor.params); - let (q, r) = udiv_mod::(self.limbs, divisor.limbs); + let (q, r) = udiv_mod::(&self.limbs, &divisor.limbs); (Self { limbs: q, params }, Self { limbs: r, params }) } - pub fn udiv(self, divisor: Self) -> Self { + pub fn udiv(&self, divisor: &Self) -> Self { let params: BigNumParams = self.params; assert(params == divisor.params); - Self { limbs: udiv::(self.limbs, divisor.limbs), params } + Self { limbs: udiv::(&self.limbs, &divisor.limbs), params } } - pub fn umod(self, divisor: Self) -> Self { + pub fn umod(&self, divisor: &Self) -> Self { let params: BigNumParams = self.params; assert(params == divisor.params); - Self { limbs: umod::(self.limbs, divisor.limbs), params } + Self { limbs: umod::(&self.limbs, &divisor.limbs), params } } - pub fn is_zero(self) -> bool { + pub fn is_zero(&self) -> bool { let params: BigNumParams = self.params; - is_zero(params, self.limbs) + is_zero(¶ms, &self.limbs) } - pub fn is_zero_integer(self: Self) -> bool { - is_zero_integer(self.limbs) + pub fn is_zero_integer(&self) -> bool { + is_zero_integer(&self.limbs) } - pub fn assert_is_not_zero(self: Self) { + pub fn assert_is_not_zero(&self) { let params: BigNumParams = self.params; - assert_is_not_zero::(params, self.limbs); + assert_is_not_zero::(¶ms, &self.limbs); } - pub fn assert_is_not_zero_integer(self: Self) { - assert_is_not_zero_integer(self.limbs); + pub fn assert_is_not_zero_integer(&self) { + assert_is_not_zero_integer(&self.limbs); } } @@ -274,7 +276,7 @@ impl std::ops::Add for RuntimeBigNum fn add(self, other: Self) -> Self { let params: BigNumParams = self.params; assert(params == other.params); - Self { limbs: add::(params, self.limbs, other.limbs), params } + Self { limbs: add::(¶ms, &self.limbs, &other.limbs), params } } } @@ -287,7 +289,7 @@ impl std::ops::Sub for RuntimeBigNum fn sub(self, other: Self) -> Self { let params: BigNumParams = self.params; assert(params == other.params); - Self { limbs: sub::(params, self.limbs, other.limbs), params } + Self { limbs: sub::(¶ms, &self.limbs, &other.limbs), params } } } @@ -322,7 +324,7 @@ impl Neg for RuntimeBigNum { /// will create much fewer constraints than calling `mul` directly fn neg(self) -> Self { let params: BigNumParams = self.params; - Self { limbs: neg::(params, self.limbs), params } + Self { limbs: neg::(¶ms, &self.limbs), params } } } @@ -330,14 +332,14 @@ impl std::cmp::Eq for RuntimeBigNum fn eq(self, other: Self) -> bool { let params = self.params; assert(params == other.params); - eq::(params, self.limbs, other.limbs) + eq::(¶ms, &self.limbs, &other.limbs) } } impl std::cmp::Ord for RuntimeBigNum { fn cmp(self, other: Self) -> Ordering { assert(self.params == other.params); - cmp::(self.limbs, other.limbs) + cmp::(&self.limbs, &other.limbs) } } @@ -354,13 +356,13 @@ pub fn __compute_quadratic_expression( - params, - map(lhs_terms, |bns| map(bns, |bn| RuntimeBigNum::get_limbs(bn))), - lhs_flags, - map(rhs_terms, |bns| map(bns, |bn| RuntimeBigNum::get_limbs(bn))), - rhs_flags, - map(linear_terms, |bn| RuntimeBigNum::get_limbs(bn)), - linear_flags, + ¶ms, + &map(&lhs_terms, |bns| map(bns, |bn| bn.get_limbs())), + &lhs_flags, + &map(&rhs_terms, |bns| map(bns, |bn| bn.get_limbs())), + &rhs_flags, + &map(&linear_terms, |bn| bn.get_limbs()), + &linear_flags, ) }; (RuntimeBigNum { limbs: q_limbs, params }, RuntimeBigNum { limbs: r_limbs, params }) @@ -376,13 +378,13 @@ pub fn evaluate_quadratic_expression( - params, - map(lhs_terms, |bns| map(bns, |bn| RuntimeBigNum::get_limbs(bn))), - lhs_flags, - map(rhs_terms, |bns| map(bns, |bn| RuntimeBigNum::get_limbs(bn))), - rhs_flags, - map(linear_terms, |bn| RuntimeBigNum::get_limbs(bn)), - linear_flags, + ¶ms, + &map(&lhs_terms, |bns| map(bns, |bn| bn.get_limbs())), + &lhs_flags, + &map(&rhs_terms, |bns| map(bns, |bn| bn.get_limbs())), + &rhs_flags, + &map(&linear_terms, |bn| bn.get_limbs()), + &linear_flags, ) } @@ -394,10 +396,8 @@ pub fn __batch_invert( assert(params.has_multiplicative_inverse); // Safety: Unconstrained function simulation let all_limbs: [[u128; N]; M] = unsafe { - crate::fns::unconstrained_ops::batch_invert::<_, MOD_BITS, _>( - params, - x.map(|bn| RuntimeBigNum::get_limbs(bn)), - ) + let limb_arr = x.map(|bn| bn.get_limbs()); + crate::fns::unconstrained_ops::batch_invert::<_, MOD_BITS, _>(¶ms, &limb_arr) }; all_limbs.map(|limbs: [u128; N]| RuntimeBigNum { limbs, params }) } @@ -408,8 +408,8 @@ pub unconstrained fn __batch_invert_slice = x[0].params; assert(params.has_multiplicative_inverse); let all_limbs: [[u128; N]] = crate::fns::unconstrained_ops::batch_invert_slice::( - params, - x.map(|bn| RuntimeBigNum::get_limbs(bn)), + ¶ms, + x.map(|bn| bn.get_limbs()), ); all_limbs.map(|limbs: [u128; N]| RuntimeBigNum { limbs, params }) diff --git a/src/tests/bignum_test.nr b/src/tests/bignum_test.nr index f4a9d57b..ca2e6fe6 100644 --- a/src/tests/bignum_test.nr +++ b/src/tests/bignum_test.nr @@ -142,7 +142,7 @@ fn test_SecP224r1_mul_regression() { let x = SecP224r1 { limbs: [0x03c1d356c21122343280d6115c1d21, 0xb70e0cbd6bb4bf7f321390b94a] }; let res = x * x; // Safety: test code - let expected = unsafe { SecP224r1::__mul(x, x) }; + let expected = unsafe { x.__mul(&x) }; res.validate_in_field(); expected.validate_in_field(); assert(res == expected); @@ -171,11 +171,11 @@ fn test_evaluate_quadratic_expression_regression_div_extra_mod() { // compute honest z = x * y^{-1} (mod MOD) // Safety: test code - let mut z_: [u128; 4] = unsafe { x.__div(y).get_limbs() }; + let mut z_: [u128; 4] = unsafe { x.__div(&y).get_limbs() }; // Safety: test code - z_ = unsafe { __helper_add(z_, modulus) }; + z_ = unsafe { __helper_add(&z_, &modulus) }; // Safety: test code - z_ = unsafe { __helper_add(z_, modulus) }; + z_ = unsafe { __helper_add(&z_, &modulus) }; let z: BLS12_381_Fq = BLS12_381_Fq::from_limbs(z_); @@ -191,11 +191,11 @@ fn test_evaluate_quadratic_expression_add_extra_mod_fuzz(xseed: [u8; 4], yseed: // compute honest z = x + y (mod MOD) // Safety: test code - let mut z_: [u128; 3] = unsafe { x.__add(y).get_limbs() }; + let mut z_: [u128; 3] = unsafe { x.__add(&y).get_limbs() }; // Safety: test code - z_ = unsafe { __helper_add(z_, modulus) }; + z_ = unsafe { __helper_add(&z_, &modulus) }; // Safety: test code - z_ = unsafe { __helper_add(z_, modulus) }; + z_ = unsafe { __helper_add(&z_, &modulus) }; let z: BLS12_381_Fr = BLS12_381_Fr::from_limbs(z_); @@ -211,11 +211,11 @@ fn test_evaluate_quadratic_expression_mul_extra_mod_fuzz(xseed: [u8; 1], yseed: // compute honest z = x * y (mod MOD) // Safety: test code - let mut z_: [u128; 3] = unsafe { x.__mul(y).get_limbs() }; + let mut z_: [u128; 3] = unsafe { x.__mul(&y).get_limbs() }; // Safety: test code - z_ = unsafe { __helper_add(z_, modulus) }; + z_ = unsafe { __helper_add(&z_, &modulus) }; // Safety: test code - z_ = unsafe { __helper_add(z_, modulus) }; + z_ = unsafe { __helper_add(&z_, &modulus) }; let z: U256 = U256::from_limbs(z_); @@ -233,13 +233,13 @@ fn test_udiv_mod_extra_modulus_regression(xseed: [u8; 1], yseed: [u8; 2]) { // compute honest z = x * y (mod MOD) // Safety: test code - let (q, r): (U256, U256) = unsafe { x.__udiv_mod(y) }; + let (q, r): (U256, U256) = unsafe { x.__udiv_mod(&y) }; let mut q_: [u128; 3] = q.get_limbs(); // Safety: test code - q_ = unsafe { __helper_add(q_, modulus) }; + q_ = unsafe { __helper_add(&q_, &modulus) }; - validate_udiv_mod_expression::<3, 257>(x.get_limbs(), y.get_limbs(), q_, r.get_limbs()); + validate_udiv_mod_expression::<3, 257>(&x.get_limbs(), &y.get_limbs(), &q_, &r.get_limbs()); } // This tests that `validate_gt` no longer accepts identical inputs @@ -247,7 +247,9 @@ fn test_udiv_mod_extra_modulus_regression(xseed: [u8; 1], yseed: [u8; 2]) { fn test_validate_gt_regression_BN() { let x: BN254_Fq = BN254_Fq::derive_from_seed([1, 2, 3, 4]); let y: BN254_Fq = BN254_Fq::derive_from_seed([1, 2, 3, 4]); - validate_gt::<3, 254>(x.get_limbs(), y.get_limbs()); + let x_limbs = x.get_limbs(); + let y_limbs = y.get_limbs(); + validate_gt::<3, 254>(&x_limbs, &y_limbs); } // ------------------------------ BASIC TESTS ------------------------------ @@ -510,7 +512,7 @@ where let a: BN = BN::derive_from_seed([1, 2, 3, 4]); let b: BN = BN::derive_from_seed([4, 5, 6, 7]); - a.assert_is_not_equal(b); + a.assert_is_not_equal(&b); } fn test_assert_is_not_equal_fail() @@ -520,7 +522,7 @@ where let a: BN = BN::derive_from_seed([1, 2, 3, 4]); let b: BN = BN::derive_from_seed([1, 2, 3, 4]); - a.assert_is_not_equal(b); + a.assert_is_not_equal(&b); } fn test_assert_is_not_equal_overloaded_lhs_fail() @@ -535,8 +537,8 @@ where let t0 = a.get_limbs(); let t1 = modulus.get_limbs(); // Safety: test code - let a_plus_modulus: BN = BN::from_limbs(unsafe { __helper_add(t0, t1) }); - a_plus_modulus.assert_is_not_equal(b); + let a_plus_modulus: BN = BN::from_limbs(unsafe { __helper_add(&t0, &t1) }); + a_plus_modulus.assert_is_not_equal(&b); } fn test_assert_is_not_equal_overloaded_rhs_fail() @@ -551,8 +553,8 @@ where let t0 = b.get_limbs(); let t1 = modulus.get_limbs(); // Safety: test code - let b_plus_modulus = BN::from_limbs(unsafe { __helper_add(t0, t1) }); - a.assert_is_not_equal(b_plus_modulus); + let b_plus_modulus = BN::from_limbs(unsafe { __helper_add(&t0, &t1) }); + a.assert_is_not_equal(&b_plus_modulus); } fn test_assert_is_not_equal_overloaded_fail() @@ -568,10 +570,10 @@ where let t1 = b.get_limbs(); let t2 = modulus.get_limbs(); // Safety: test code - let a_plus_modulus: BN = BN::from_limbs(unsafe { __helper_add(t0, t2) }); + let a_plus_modulus: BN = BN::from_limbs(unsafe { __helper_add(&t0, &t2) }); // Safety: test code - let b_plus_modulus: BN = BN::from_limbs(unsafe { __helper_add(t1, t2) }); - a_plus_modulus.assert_is_not_equal(b_plus_modulus); + let b_plus_modulus: BN = BN::from_limbs(unsafe { __helper_add(&t1, &t2) }); + a_plus_modulus.assert_is_not_equal(&b_plus_modulus); } #[test] @@ -659,9 +661,9 @@ fn test_do_nothing() { let a: BN254_Fq = BN254_Fq::from_limbs([1, 2, 0]); let b: BN254_Fq = BN254_Fq::from_limbs([1, 2, 0]); // Safety: test code - let c: [u128; 3] = unsafe { __helper_add(a.get_limbs(), b.get_limbs()) }; + let c: [u128; 3] = unsafe { __helper_add(&a.get_limbs(), &b.get_limbs()) }; // Safety: test code - let d: BN254_Fq = unsafe { a.__add(b) }; + let d: BN254_Fq = unsafe { a.__add(&b) }; let e: BN254_Fq = a + b; assert(c == d.get_limbs()); assert(c == e.get_limbs()); @@ -763,7 +765,7 @@ fn test_add_modulus_overflow() { [0xffffffffffffffffffffffffffffff, 0xffffffffffffffffffffffffffffff, 0x3fff]; let one = [1, 0, 0]; // Safety: test code - let a: BN254_Fq = BN254_Fq::from_limbs(unsafe { __helper_add(p, one) }); + let a: BN254_Fq = BN254_Fq::from_limbs(unsafe { __helper_add(&p, &one) }); let b: BN254_Fq = BN254_Fq::from_limbs(two_pow_254_minus_1); let result = a + b; assert(result == b); @@ -855,7 +857,7 @@ where let inv_two: BN = unsafe { BN::from_limbs( __shr( - __helper_add(BN::modulus().get_limbs(), BN::one().get_limbs()), + &__helper_add(&BN::modulus().get_limbs(), &BN::one().get_limbs()), 1, ), ) @@ -877,7 +879,7 @@ where let two: BN = BN::one() + BN::one(); let c: BN = a / two; - let d: BN = a.udiv(two); + let d: BN = a.udiv(&two); assert(c == d); } @@ -890,7 +892,7 @@ where // Safety: test code let v: BN = unsafe { u.__invmod() }; // Safety: test code - let result: BN = unsafe { u.__mul(v) }; + let result: BN = unsafe { u.__mul(&v) }; let expected: BN = BN::one(); assert(result == expected); } @@ -940,10 +942,10 @@ fn test_udiv_mod_U256() { ]); let b: U256 = U256::from_limbs([12, 0, 0]); - let (q, r): (U256, U256) = a.udiv_mod(b); + let (q, r): (U256, U256) = a.udiv_mod(&b); // Safety: test code - let product: U256 = unsafe { q.__mul(b).__add(r) }; + let product: U256 = unsafe { q.__mul(&b).__add(&r) }; assert(product == a); } @@ -951,7 +953,8 @@ fn test_udiv_mod_U256() { fn test_1_udiv_mod_2() { let _0: U256 = U256::zero(); let _1: U256 = U256::one(); - assert(_1.udiv_mod(_1 + _1) == (_0, _1)); + let _2: U256 = _1 + _1; + assert(_1.udiv_mod(&_2) == (_0, _1)); } #[test] @@ -959,7 +962,8 @@ fn test_20_udiv_mod_11() { let _1: U256 = U256::one(); let _2_POW_120: U256 = U256::from_limbs([0, 1, 0]); let _2_POW_121: U256 = U256::from_limbs([0, 2, 0]); - assert(_2_POW_121.udiv_mod(_2_POW_120 + _1) == (_1, _2_POW_120 - _1)); + let divisor: U256 = _2_POW_120 + _1; + assert(_2_POW_121.udiv_mod(&divisor) == (_1, _2_POW_120 - _1)); } //// Set up parametrized tests @@ -1353,7 +1357,7 @@ fn test_expressions() { let y: BN254_Fq = BN254_Fq::from_limbs([0x1, 0x1, 0x0]); let z: BN254_Fq = BN254_Fq::from_limbs([0x2, 0x2, 0x0]); // Safety: test code - let yy: BN254_Fq = unsafe { y.__add(y) }; + let yy: BN254_Fq = unsafe { y.__add(&y) }; assert(yy.get_limbs() == z.get_limbs()); @@ -1379,13 +1383,13 @@ fn test_expressions() { 0x0000000000000000000000000000000000000000000000000000000000000f93, ]); // Safety: test code - let wx: BN254_Fq = unsafe { w.__mul(x) }; + let wx: BN254_Fq = unsafe { w.__mul(&x) }; // Safety: test code - let uv: BN254_Fq = unsafe { uu.__mul(vv) }; + let uv: BN254_Fq = unsafe { uu.__mul(&vv) }; // Safety: test code - let y: BN254_Fq = unsafe { (uv.__add(wx)).__neg() }; + let y: BN254_Fq = unsafe { (uv.__add(&wx)).__neg() }; // Safety: test code - let z: BN254_Fq = unsafe { uv.__add(wx) }; + let z: BN254_Fq = unsafe { uv.__add(&wx) }; evaluate_quadratic_expression( [[uu], [w]], @@ -1475,7 +1479,7 @@ fn test_2048_bit_quadratic_expression() { let b_bn: BN2048 = BN2048::from_limbs(b); // Safety: test code - let c_bn: BN2048 = unsafe { a_bn.__mul(b_bn) }; + let c_bn: BN2048 = unsafe { a_bn.__mul(&b_bn) }; assert(c_bn.limbs == c_expected); @@ -1513,7 +1517,7 @@ where BN: BigNum, { // Safety: test code - let qnr_limbs = unsafe { __quadratic_non_residue(BN::params()) }; + let qnr_limbs = unsafe { __quadratic_non_residue(&BN::params()) }; let g: BN = BN::from_limbs(qnr_limbs); // Safety: test code @@ -1540,12 +1544,14 @@ fn test_sqrt_equality_fuzz(seed: [u8; 3]) { let b: BLS12_381_Fq = a * a; // Safety: test code let c: BLS12_381_Fq = unsafe { - BLS12_381_Fq::from_limbs(__tonelli_shanks_sqrt(BLS12_381_Fq::params(), b.get_limbs()) + BLS12_381_Fq::from_limbs(__tonelli_shanks_sqrt(&BLS12_381_Fq::params(), &b.get_limbs()) .unwrap()) }; // Safety: test code let d: BLS12_381_Fq = unsafe { - BLS12_381_Fq::from_limbs(__easy_sqrt(BLS12_381_Fq::params(), b.get_limbs()).unwrap()) + BLS12_381_Fq::from_limbs( + __easy_sqrt(&BLS12_381_Fq::params(), &b.get_limbs()).unwrap(), + ) }; assert((c == d) | (c == -d)); } diff --git a/src/tests/runtime_bignum_test.nr b/src/tests/runtime_bignum_test.nr index 756587d4..422a7f24 100644 --- a/src/tests/runtime_bignum_test.nr +++ b/src/tests/runtime_bignum_test.nr @@ -130,7 +130,7 @@ fn test_add_modulus_limit() { let a: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: p, params }; // Safety: test code - let two_pow_modulus_bits_minus_one: [u128; $N] = unsafe { $__helper_sub($__shl(one, $MOD_BITS), one) }; + let two_pow_modulus_bits_minus_one: [u128; $N] = unsafe { $__helper_sub(&$__shl(&one, $MOD_BITS), &one) }; let b: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: two_pow_modulus_bits_minus_one, params }; let result = a + b; @@ -146,10 +146,10 @@ fn test_add_modulus_overflow() { let one = unsafe{$__one()}; // Safety: test code - let a: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(p, one) }, params }; + let a: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(&p, &one) }, params }; // Safety: test code - let two_pow_modulus_bits_minus_one: [u128; $N] = unsafe { $__helper_sub($__shl(one, $MOD_BITS), one) }; + let two_pow_modulus_bits_minus_one: [u128; $N] = unsafe { $__helper_sub(&$__shl(&one, $MOD_BITS), &one) }; let b: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: two_pow_modulus_bits_minus_one, params }; let result = a + b; @@ -221,7 +221,7 @@ fn assert_is_not_equal() { let a: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum::derive_from_seed(params, [1, 2, 3, 4]); let b: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum::derive_from_seed(params, [4, 5, 6, 7]); - a.assert_is_not_equal(b); + a.assert_is_not_equal(&b); } #[test(should_fail_with = "assert_is_not_equal fail")] @@ -231,7 +231,7 @@ fn assert_is_not_equal_fail() { let a: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum::derive_from_seed(params, [1, 2, 3, 4]); let b: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum::derive_from_seed(params, [1, 2, 3, 4]); - a.assert_is_not_equal(b); + a.assert_is_not_equal(&b); } #[test(should_fail_with = "assert_is_not_equal fail")] @@ -246,9 +246,9 @@ fn assert_is_not_equal_overloaded_lhs_fail() { let t0: [u128; $N] = a.limbs; let t1: [u128; $N] = modulus; // Safety: test code - let a_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(t0, t1) }, params }; + let a_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(&t0, &t1) }, params }; - a_plus_modulus.assert_is_not_equal(b); + a_plus_modulus.assert_is_not_equal(&b); } #[test(should_fail_with = "assert_is_not_equal fail")] @@ -263,9 +263,9 @@ fn assert_is_not_equal_overloaded_rhs_fail() { let t0: [u128; $N] = b.limbs; let t1: [u128; $N] = modulus; // Safety: test code - let b_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(t0, t1) }, params }; + let b_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(&t0, &t1) }, params }; - a.assert_is_not_equal(b_plus_modulus); + a.assert_is_not_equal(&b_plus_modulus); } #[test(should_fail_with = "assert_is_not_equal fail")] @@ -282,11 +282,11 @@ fn assert_is_not_equal_overloaded_fail() { let t2: [u128; $N] = modulus; // Safety: test code - let a_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(t0, t2) }, params }; + let a_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(&t0, &t2) }, params }; // Safety: test code - let b_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(t1, t2) }, params }; + let b_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(&t1, &t2) }, params }; - a_plus_modulus.assert_is_not_equal(b_plus_modulus); + a_plus_modulus.assert_is_not_equal(&b_plus_modulus); } #[test] @@ -303,7 +303,7 @@ fn test_eq() { let t1: [u128; $N] = b.limbs; // Safety: test code - let b_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(t0, t1) }, params }; + let b_plus_modulus: $RuntimeBigNum<$N, $MOD_BITS> = $RuntimeBigNum { limbs: unsafe { $__helper_add(&t0, &t1) }, params }; assert_eq(a, b); assert_eq(a, b_plus_modulus); @@ -390,10 +390,11 @@ fn test_cmp_fuzz(seed: [u8; 2]){ one[0] = 1; if modulus[0] % 2 == 0 { // if modulus is even, half_modulus is modulus / 2 - half_modulus = $RuntimeBigNum {limbs: $udiv::<$N, $MOD_BITS>(modulus, two), params}; + half_modulus = $RuntimeBigNum {limbs: $udiv::<$N, $MOD_BITS>(&modulus, &two), params}; } else { // if modulus is odd, half_modulus is (modulus - 1) / 2 - half_modulus = $RuntimeBigNum {limbs: $udiv::<$N, $MOD_BITS>($sub(params,modulus, one), two), params}; + let modulus_minus_one = $sub(¶ms,&modulus, &one); + half_modulus = $RuntimeBigNum {limbs: $udiv::<$N, $MOD_BITS>(&modulus_minus_one, &two), params}; } @@ -493,7 +494,7 @@ unconstrained fn test_invmod(params: BigNumParams let u: RuntimeBigNum = RuntimeBigNum::derive_from_seed(params, [1, 2, 3, 4]); for _ in 0..1 { let v = u.__invmod(); - let result = u.__mul(v); + let result = u.__mul(&v); let expected: RuntimeBigNum = RuntimeBigNum::one(params); assert(result.limbs == expected.limbs); } @@ -617,7 +618,7 @@ fn test_sqrt_BLS12_381_Fq() { assert((c == a) | (c == -a)); // Safety: test code - let qnr_limbs: [u128; 4] = unsafe { __quadratic_non_residue(params) }; + let qnr_limbs: [u128; 4] = unsafe { __quadratic_non_residue(¶ms) }; let g = RuntimeBigNum { params: params, limbs: qnr_limbs }; // Safety: test code let c = g.__sqrt(); @@ -678,6 +679,6 @@ fn test_barrett_reduction_fix() { // Safety: test code let (_quotient, remainder) = - unsafe { __barrett_reduction(to_reduce, params.redc_param, 1024, params.modulus) }; + unsafe { __barrett_reduction(&to_reduce, ¶ms.redc_param, 1024, ¶ms.modulus) }; assert((remainder[8] as Field).lt(params.modulus[8] as Field)); } diff --git a/src/utils/map.nr b/src/utils/map.nr index ca2360c5..84cc7f2c 100644 --- a/src/utils/map.nr +++ b/src/utils/map.nr @@ -1,17 +1,17 @@ // Copied from std::array, because I couldn't figure out how to import the `map` method of the weird trait for an array. // And the reason I wanted direct access to it, is because I couldn't figure out how to implement a double map. -pub(crate) fn map(arr: [T; N], f: fn[Env](T) -> U) -> [U; N] { +pub(crate) fn map(arr: &[T; N], f: fn[Env](&T) -> U) -> [U; N] { let mut ret: [U; N] = std::mem::zeroed(); for i in 0..N { - ret[i] = f(arr[i]); + ret[i] = f(&arr[i]); } ret } /// reverse an array -pub(crate) fn invert_array(array: [T; M]) -> [T; M] { +pub(crate) fn invert_array(array: &[T; M]) -> [T; M] { let mut ret: [T; M] = std::mem::zeroed(); for i in 0..M { diff --git a/src/utils/msb.nr b/src/utils/msb.nr index 0e3f27e0..319fbf18 100644 --- a/src/utils/msb.nr +++ b/src/utils/msb.nr @@ -121,27 +121,27 @@ mod tests { let x: Field = 0x8000000000000000; let arr: [u128; 4] = [0, 0, x as u128, 0]; let msb1: u32 = __get_msb64(arr); - let msb2: u32 = __get_msb(arr); + let msb2: u32 = __get_msb(&arr); assert_eq(msb1, msb2); // Test multiple limbs (120-bit number) let x: Field = 0x800000000000000000000000000000; // 120 bits number, 2^119 let arr: [u128; 4] = [0, 0, x as u128, 0]; let msb1: u32 = __get_msb64(arr); - let msb2: u32 = __get_msb(arr); + let msb2: u32 = __get_msb(&arr); assert_eq(msb1, msb2); // Test zero let arr: [u128; 4] = [0, 0, 0, 0]; let msb1: u32 = __get_msb64(arr); - let msb2: u32 = __get_msb(arr); + let msb2: u32 = __get_msb(&arr); assert_eq(msb1, msb2); // Test all bits set (120 bits) let x: Field = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF; // 120 bits, 2^120 - 1 let arr: [u128; 4] = [0, x as u128, 0, 0]; let msb1: u32 = __get_msb64(arr); - let msb2: u32 = __get_msb(arr); + let msb2: u32 = __get_msb(&arr); assert_eq(msb1, msb2); // Test systematic bit positions @@ -150,7 +150,7 @@ mod tests { let shifted: u128 = x << i; let arr: [u128; 4] = [0, shifted, 0, 0]; let msb1: u32 = __get_msb64(arr); - let msb2: u32 = __get_msb(arr); + let msb2: u32 = __get_msb(&arr); assert_eq(msb1, msb2); } @@ -167,7 +167,7 @@ mod tests { for i in 0..patterns.len() { let arr: [u128; 4] = [0, patterns[i] as u128, 0, 0]; let msb1: u32 = __get_msb64(arr); - let msb2: u32 = __get_msb(arr); + let msb2: u32 = __get_msb(&arr); assert_eq(msb1, msb2); } @@ -178,19 +178,19 @@ mod tests { let arr3: [u128; 4] = [0, 0, x as u128, 0]; let arr4: [u128; 4] = [0, 0, 0, x as u128]; let msb1_1: u32 = __get_msb64(arr1); - let msb2_1: u32 = __get_msb(arr1); + let msb2_1: u32 = __get_msb(&arr1); assert_eq(msb1_1, msb2_1); let msb1_2: u32 = __get_msb64(arr2); - let msb2_2: u32 = __get_msb(arr2); + let msb2_2: u32 = __get_msb(&arr2); assert_eq(msb1_2, msb2_2); let msb1_3: u32 = __get_msb64(arr3); - let msb2_3: u32 = __get_msb(arr3); + let msb2_3: u32 = __get_msb(&arr3); assert_eq(msb1_3, msb2_3); let msb1_4: u32 = __get_msb64(arr4); - let msb2_4: u32 = __get_msb(arr4); + let msb2_4: u32 = __get_msb(&arr4); assert_eq(msb1_4, msb2_4); } @@ -201,7 +201,7 @@ mod tests { seed_copy[i] = seed_copy[i] & (TWO_POW_120 - 1); } let msb1: u32 = __get_msb64(seed_copy); - let msb2: u32 = __get_msb(seed_copy); + let msb2: u32 = __get_msb(&seed_copy); assert_eq(msb1, msb2); } }