Skip to content

Commit

Permalink
Make BernsteinYangInverter::invert into a const fn
Browse files Browse the repository at this point in the history
  • Loading branch information
tarcieri committed Dec 3, 2023
1 parent f0cbca4 commit b78e966
Showing 1 changed file with 80 additions and 158 deletions.
238 changes: 80 additions & 158 deletions src/modular/bernstein_yang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@
#![allow(clippy::needless_range_loop)]

use core::{
cmp::PartialEq,
ops::{Add, Mul, Neg, Sub},
};

/// Type of the modular multiplicative inverter based on the Bernstein-Yang method.
/// The inverter can be created for a specified modulus M and adjusting parameter A
/// to compute the adjusted multiplicative inverses of positive integers, i.e. for
Expand Down Expand Up @@ -63,21 +58,22 @@ impl<const L: usize> BernsteinYangInverter<L> {

/// Returns either the adjusted modular multiplicative inverse for the argument or None
/// depending on invertibility of the argument, i.e. its coprimality with the modulus
pub fn invert<const S: usize>(&self, value: &[u64]) -> Option<[u64; S]> {
let (mut d, mut e) = (CInt::ZERO, self.adjuster.clone());
pub const fn invert<const S: usize>(&self, value: &[u64]) -> Option<[u64; S]> {
let (mut d, mut e) = (CInt::ZERO, self.adjuster);
let mut g = CInt::<62, L>(Self::convert::<64, 62, L>(value));
let (mut delta, mut f) = (1, self.modulus.clone());
let (mut delta, mut f) = (1, self.modulus);
let mut matrix;
while g != CInt::ZERO {

while !g.eq(&CInt::ZERO) {
(delta, matrix) = Self::jump(&f, &g, delta);
(f, g) = Self::fg(f, g, matrix);
(d, e) = self.de(d, e, matrix);
}
// At this point the absolute value of "f" equals the greatest common divisor
// of the integer to be inverted and the modulus the inverter was created for.
// Thus, if "f" is neither 1 nor -1, then the sought inverse does not exist
let antiunit = f == CInt::MINUS_ONE;
if (f != CInt::ONE) && !antiunit {
let antiunit = f.eq(&CInt::MINUS_ONE);
if !f.eq(&CInt::ONE) && !antiunit {
return None;
}
Some(Self::convert::<62, 64, S>(&self.norm(d, antiunit).0))
Expand All @@ -86,12 +82,21 @@ impl<const L: usize> BernsteinYangInverter<L> {
/// Returns the Bernstein-Yang transition matrix multiplied by 2^62 and the new value
/// of the delta variable for the 62 basic steps of the Bernstein-Yang method, which
/// are to be performed sequentially for specified initial values of f, g and delta
fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) {
const fn jump(f: &CInt<62, L>, g: &CInt<62, L>, mut delta: i64) -> (i64, Matrix) {
// This function is defined because the method "min" of the i64 type is not constant
const fn min(a: i64, b: i64) -> i64 {
if a > b {
b
} else {
a
}
}

let (mut steps, mut f, mut g) = (62, f.lowest() as i64, g.lowest() as i128);
let mut t: Matrix = [[1, 0], [0, 1]];

loop {
let zeros = steps.min(g.trailing_zeros() as i64);
let zeros = min(steps, g.trailing_zeros() as i64);
(steps, delta, g) = (steps - zeros, delta + zeros, g >> zeros);
t[0] = [t[0][0] << zeros, t[0][1] << zeros];

Expand All @@ -106,7 +111,7 @@ impl<const L: usize> BernsteinYangInverter<L> {
// The formula (3 * x) xor 28 = -1 / x (mod 32) for an odd integer x
// in the two's complement code has been derived from the formula
// (3 * x) xor 2 = 1 / x (mod 32) attributed to Peter Montgomery
let mask = (1 << steps.min(1 - delta).min(5)) - 1;
let mask = (1 << min(min(steps, 1 - delta), 5)) - 1;
let w = (g as i64).wrapping_mul(f.wrapping_mul(3) ^ 28) & mask;

t[1] = [t[0][0] * w + t[1][0], t[0][1] * w + t[1][1]];
Expand All @@ -118,18 +123,18 @@ impl<const L: usize> BernsteinYangInverter<L> {

/// Returns the updated values of the variables f and g for specified initial ones and Bernstein-Yang transition
/// matrix multiplied by 2^62. The returned vector is "matrix * (f, g)' / 2^62", where "'" is the transpose operator
fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
const fn fg(f: CInt<62, L>, g: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
(
(t[0][0] * &f + t[0][1] * &g).shift(),
(t[1][0] * &f + t[1][1] * &g).shift(),
f.mul(t[0][0]).add(&g.mul(t[0][1])).shift(),
f.mul(t[1][0]).add(&g.mul(t[1][1])).shift(),
)
}

/// Returns the updated values of the variables d and e for specified initial ones and Bernstein-Yang transition
/// matrix multiplied by 2^62. The returned vector is congruent modulo M to "matrix * (d, e)' / 2^62 (mod M)",
/// where M is the modulus the inverter was created for and "'" stands for the transpose operator. Both the input
/// and output values lie in the interval (-2 * M, M)
fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
const fn de(&self, d: CInt<62, L>, e: CInt<62, L>, t: Matrix) -> (CInt<62, L>, CInt<62, L>) {
let mask = CInt::<62, L>::MASK as i64;
let mut md = t[0][0] * d.is_negative() as i64 + t[0][1] * e.is_negative() as i64;
let mut me = t[1][0] * d.is_negative() as i64 + t[1][1] * e.is_negative() as i64;
Expand All @@ -146,26 +151,32 @@ impl<const L: usize> BernsteinYangInverter<L> {
md -= (self.inverse.wrapping_mul(cd).wrapping_add(md)) & mask;
me -= (self.inverse.wrapping_mul(ce).wrapping_add(me)) & mask;

let cd = t[0][0] * &d + t[0][1] * &e + md * &self.modulus;
let ce = t[1][0] * &d + t[1][1] * &e + me * &self.modulus;
let cd = d
.mul(t[0][0])
.add(&e.mul(t[0][1]))
.add(&self.modulus.mul(md));
let ce = d
.mul(t[1][0])
.add(&e.mul(t[1][1]))
.add(&self.modulus.mul(me));

(cd.shift(), ce.shift())
}

/// Returns either "value (mod M)" or "-value (mod M)", where M is the modulus the
/// inverter was created for, depending on "negate", which determines the presence
/// of "-" in the used formula. The input integer lies in the interval (-2 * M, M)
fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> {
const fn norm(&self, mut value: CInt<62, L>, negate: bool) -> CInt<62, L> {
if value.is_negative() {
value = value + &self.modulus;
value = value.add(&self.modulus);
}

if negate {
value = -value;
value = value.neg();
}

if value.is_negative() {
value = value + &self.modulus;
value = value.add(&self.modulus);
}

value
Expand Down Expand Up @@ -222,7 +233,7 @@ impl<const L: usize> BernsteinYangInverter<L> {
/// numbers in the two's complement code as arrays of B-bit chunks.
/// The ordering of the chunks in these arrays is little-endian.
/// The arithmetic operations for this type are wrapping ones.
#[derive(Clone, Debug)]
#[derive(Clone, Copy, Debug)]
struct CInt<const B: usize, const L: usize>(pub [u64; L]);

impl<const B: usize, const L: usize> CInt<B, L> {
Expand All @@ -244,150 +255,78 @@ impl<const B: usize, const L: usize> CInt<B, L> {

/// Returns the result of applying B-bit right
/// arithmetical shift to the current number
pub fn shift(&self) -> Self {
pub const fn shift(&self) -> Self {
let mut data = [0; L];
if self.is_negative() {
data[L - 1] = Self::MASK;
}
data[..L - 1].copy_from_slice(&self.0[1..]);

let mut i = 0;
while i < L - 1 {
data[i] = self.0[i + 1];
i += 1;
}

Self(data)
}

/// Returns the lowest B bits of the current number
pub fn lowest(&self) -> u64 {
pub const fn lowest(&self) -> u64 {
self.0[0]
}

/// Returns "true" iff the current number is negative
pub fn is_negative(&self) -> bool {
pub const fn is_negative(&self) -> bool {
self.0[L - 1] > (Self::MASK >> 1)
}
}

impl<const B: usize, const L: usize> PartialEq for CInt<B, L> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
/// Const fn equivalent for `PartialEq::eq`
pub const fn eq(&self, other: &Self) -> bool {
let mut ret = true;
let mut i = 0;

impl<const B: usize, const L: usize> Add for &CInt<B, L> {
type Output = CInt<B, L>;
fn add(self, other: Self) -> Self::Output {
let (mut data, mut carry) = ([0; L], 0);
for i in 0..L {
let sum = self.0[i] + other.0[i] + carry;
data[i] = sum & CInt::<B, L>::MASK;
carry = sum >> B;
while i < L {
ret &= self.0[i] == other.0[i];
i += 1;
}
CInt(data)
}
}

impl<const B: usize, const L: usize> Add<&CInt<B, L>> for CInt<B, L> {
type Output = CInt<B, L>;
fn add(self, other: &Self) -> Self::Output {
&self + other
ret
}
}

impl<const B: usize, const L: usize> Add for CInt<B, L> {
type Output = CInt<B, L>;
fn add(self, other: Self) -> Self::Output {
&self + &other
}
}
/// Const fn equivalent for `Add::add`
pub const fn add(&self, other: &Self) -> Self {
let (mut data, mut carry) = ([0; L], 0);
let mut i = 0;

impl<const B: usize, const L: usize> Sub for &CInt<B, L> {
type Output = CInt<B, L>;
fn sub(self, other: Self) -> Self::Output {
// For the two's complement code the additive negation is the result of
// adding 1 to the bitwise inverted argument's representation. Thus, for
// any encoded integers x and y we have x - y = x + !y + 1, where "!" is
// the bitwise inversion and addition is done according to the rules of
// the code. The algorithm below uses this formula and is the modified
// addition algorithm, where the carry flag is initialized with 1 and
// the chunks of the second argument are bitwise inverted
let (mut data, mut carry) = ([0; L], 1);
for i in 0..L {
let sum = self.0[i] + (other.0[i] ^ CInt::<B, L>::MASK) + carry;
while i < L {
let sum = self.0[i] + other.0[i] + carry;
data[i] = sum & CInt::<B, L>::MASK;
carry = sum >> B;
i += 1;
}
CInt(data)
}
}

impl<const B: usize, const L: usize> Sub<&CInt<B, L>> for CInt<B, L> {
type Output = CInt<B, L>;
fn sub(self, other: &Self) -> Self::Output {
&self - other
}
}

impl<const B: usize, const L: usize> Sub for CInt<B, L> {
type Output = CInt<B, L>;
fn sub(self, other: Self) -> Self::Output {
&self - &other
Self(data)
}
}

impl<const B: usize, const L: usize> Neg for &CInt<B, L> {
type Output = CInt<B, L>;
fn neg(self) -> Self::Output {
/// Const fn equivalent for `Neg::neg`
pub const fn neg(&self) -> Self {
// For the two's complement code the additive negation is the result
// of adding 1 to the bitwise inverted argument's representation
let (mut data, mut carry) = ([0; L], 1);
for i in 0..L {
let sum = (self.0[i] ^ CInt::<B, L>::MASK) + carry;
data[i] = sum & CInt::<B, L>::MASK;
carry = sum >> B;
}
CInt(data)
}
}
let mut i = 0;

impl<const B: usize, const L: usize> Neg for CInt<B, L> {
type Output = CInt<B, L>;
fn neg(self) -> Self::Output {
-&self
}
}

impl<const B: usize, const L: usize> Mul for &CInt<B, L> {
type Output = CInt<B, L>;
fn mul(self, other: Self) -> Self::Output {
let mut data = [0; L];
for i in 0..L {
let mut carry = 0;
for k in 0..(L - i) {
let sum = (data[i + k] as u128)
+ (carry as u128)
+ (self.0[i] as u128) * (other.0[k] as u128);
data[i + k] = sum as u64 & CInt::<B, L>::MASK;
carry = (sum >> B) as u64;
}
while i < L {
let sum = (self.0[i] ^ Self::MASK) + carry;
data[i] = sum & Self::MASK;
carry = sum >> B;
i += 1;
}
CInt(data)
}
}

impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for CInt<B, L> {
type Output = CInt<B, L>;
fn mul(self, other: &Self) -> Self::Output {
&self * other
}
}

impl<const B: usize, const L: usize> Mul for CInt<B, L> {
type Output = CInt<B, L>;
fn mul(self, other: Self) -> Self::Output {
&self * &other
Self(data)
}
}

impl<const B: usize, const L: usize> Mul<i64> for &CInt<B, L> {
type Output = CInt<B, L>;
fn mul(self, other: i64) -> Self::Output {
/// Const fn equivalent for `Mul::<i64>::mul`
pub const fn mul(&self, other: i64) -> Self {
let mut data = [0; L];
// If the short multiplicand is non-negative, the standard multiplication
// algorithm is performed. Otherwise, the product of the additively negated
Expand All @@ -402,36 +341,19 @@ impl<const B: usize, const L: usize> Mul<i64> for &CInt<B, L> {
// where the carry flag is initialized with the additively negated short
// multiplicand and the chunks of the long multiplicand are bitwise inverted
let (other, mut carry, mask) = if other < 0 {
(-other, -other as u64, CInt::<B, L>::MASK)
(-other, -other as u64, Self::MASK)
} else {
(other, 0, 0)
};
for i in 0..L {

let mut i = 0;
while i < L {
let sum = (carry as u128) + ((self.0[i] ^ mask) as u128) * (other as u128);
data[i] = sum as u64 & CInt::<B, L>::MASK;
data[i] = sum as u64 & Self::MASK;
carry = (sum >> B) as u64;
i += 1;
}
CInt(data)
}
}

impl<const B: usize, const L: usize> Mul<i64> for CInt<B, L> {
type Output = CInt<B, L>;
fn mul(self, other: i64) -> Self::Output {
&self * other
}
}

impl<const B: usize, const L: usize> Mul<&CInt<B, L>> for i64 {
type Output = CInt<B, L>;
fn mul(self, other: &CInt<B, L>) -> Self::Output {
other * self
}
}

impl<const B: usize, const L: usize> Mul<CInt<B, L>> for i64 {
type Output = CInt<B, L>;
fn mul(self, other: CInt<B, L>) -> Self::Output {
other * self
Self(data)
}
}

0 comments on commit b78e966

Please sign in to comment.