Skip to content

Commit

Permalink
Faster vartime division for BoxedUint (#626)
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Whitehead <cywolf@gmail.com>
  • Loading branch information
andrewwhitehead authored Aug 2, 2024
1 parent e948539 commit 6370b08
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 73 deletions.
98 changes: 96 additions & 2 deletions benches/boxed_uint.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use crypto_bigint::{BoxedUint, RandomBits};
use crypto_bigint::{BoxedUint, Limb, NonZero, RandomBits};
use rand_core::OsRng;

/// Size of `BoxedUint` to use in benchmark.
Expand Down Expand Up @@ -43,6 +43,100 @@ fn bench_shifts(c: &mut Criterion) {
group.finish();
}

fn bench_division(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");

group.bench_function("boxed_div_rem", |b| {
b.iter_batched(
|| {
(
BoxedUint::max(UINT_BITS),
NonZero::new(BoxedUint::random_bits_with_precision(
&mut OsRng,
UINT_BITS / 2,
UINT_BITS,
))
.unwrap(),
)
},
|(x, y)| black_box(x.div_rem(&y)),
BatchSize::SmallInput,
)
});

group.bench_function("boxed_div_rem_vartime", |b| {
b.iter_batched(
|| {
(
BoxedUint::max(UINT_BITS),
NonZero::new(BoxedUint::random_bits_with_precision(
&mut OsRng,
UINT_BITS / 2,
UINT_BITS,
))
.unwrap(),
)
},
|(x, y)| black_box(x.div_rem_vartime(&y)),
BatchSize::SmallInput,
)
});

group.bench_function("boxed_div_rem_limb", |b| {
b.iter_batched(
|| (BoxedUint::max(UINT_BITS), NonZero::new(Limb::ONE).unwrap()),
|(x, y)| black_box(x.div_rem_limb(y)),
BatchSize::SmallInput,
)
});

group.bench_function("boxed_rem", |b| {
b.iter_batched(
|| {
(
BoxedUint::max(UINT_BITS),
NonZero::new(BoxedUint::random_bits_with_precision(
&mut OsRng,
UINT_BITS / 2,
UINT_BITS,
))
.unwrap(),
)
},
|(x, y)| black_box(x.rem(&y)),
BatchSize::SmallInput,
)
});

group.bench_function("boxed_rem_vartime", |b| {
b.iter_batched(
|| {
(
BoxedUint::max(UINT_BITS),
NonZero::new(BoxedUint::random_bits_with_precision(
&mut OsRng,
UINT_BITS / 2,
UINT_BITS,
))
.unwrap(),
)
},
|(x, y)| black_box(x.rem_vartime(&y)),
BatchSize::SmallInput,
)
});

group.bench_function("boxed_rem_limb", |b| {
b.iter_batched(
|| (BoxedUint::max(UINT_BITS), NonZero::new(Limb::ONE).unwrap()),
|(x, y)| black_box(x.rem_limb(y)),
BatchSize::SmallInput,
)
});

group.finish();
}

fn bench_boxed_sqrt(c: &mut Criterion) {
let mut group = c.benchmark_group("boxed_sqrt");
group.bench_function("boxed_sqrt, 4096", |b| {
Expand All @@ -62,6 +156,6 @@ fn bench_boxed_sqrt(c: &mut Criterion) {
});
}

criterion_group!(benches, bench_shifts, bench_boxed_sqrt);
criterion_group!(benches, bench_division, bench_shifts, bench_boxed_sqrt);

criterion_main!(benches);
215 changes: 160 additions & 55 deletions src/uint/boxed/div.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
//! [`BoxedUint`] division operations.
use crate::{
uint::boxed, BoxedUint, CheckedDiv, ConstantTimeSelect, DivRemLimb, Limb, NonZero, Reciprocal,
RemLimb, Wrapping,
uint::{boxed, div_limb::div2by1},
BoxedUint, CheckedDiv, ConstChoice, ConstantTimeSelect, DivRemLimb, Limb, NonZero, Reciprocal,
RemLimb, WideWord, Word, Wrapping,
};
use core::ops::{Div, DivAssign, Rem, RemAssign};
use subtle::{Choice, ConstantTimeEq, ConstantTimeLess, CtOption};
use subtle::{Choice, ConstantTimeLess, CtOption};

impl BoxedUint {
/// Computes `self / rhs` using a pre-made reciprocal,
Expand Down Expand Up @@ -46,34 +47,53 @@ impl BoxedUint {
///
/// Variable-time with respect to `rhs`
pub fn div_rem_vartime(&self, rhs: &NonZero<Self>) -> (Self, Self) {
// NOTE(review): the next two lines look like the pre-change body kept by the diff view
// (they call `div_rem_vartime_unchecked`); confirm they are not part of the new code.
// Since `rhs` is nonzero, this should always hold.
self.div_rem_vartime_unchecked(rhs.as_ref())
// Number of limbs needed to hold the significant bits of `rhs`: `bits_vartime()` divided
// by `Limb::BITS`, rounded up.
let yc = ((rhs.0.bits_vartime() + Limb::BITS - 1) / Limb::BITS) as usize;

match yc {
// No significant limbs: the divisor value is zero.
0 => panic!("zero divisor"),
1 => {
// Perform limb division
let (quo, rem_limb) =
self.div_rem_limb(rhs.0.limbs[0].to_nz().expect("zero divisor"));
// The remainder is returned at the precision of `rhs`, with the limb result in limb 0.
let mut rem = Self::zero_with_precision(rhs.bits_precision());
rem.limbs[0] = rem_limb;
(quo, rem)
}
_ => {
// The in-place helper leaves the quotient in its first argument and the remainder in
// its second, so start from copies of `self` and `rhs`. Only the `yc` significant limbs
// of the divisor are passed, so the last limb handed over is non-zero.
let mut quo = self.clone();
let mut rem = rhs.0.clone();
div_rem_vartime_in_place(&mut quo.limbs, &mut rem.limbs[..yc]);
(quo, rem)
}
}
}

/// Computes self % rhs, returns the remainder.
///
/// Variable-time with respect to `rhs`.
///
/// # Panics
///
/// Panics if `self` and `rhs` have different precisions.
// TODO(tarcieri): handle different precisions without panicking
pub fn rem_vartime(&self, rhs: &NonZero<Self>) -> Self {
// NOTE(review): this span interleaves what looks like the previous shift-and-subtract loop
// (`mb`, `bd`, `c`, `loop { .. }`) with the new `match yc` body; confirm against the
// actual file which lines are live.
debug_assert_eq!(self.bits_precision(), rhs.bits_precision());
let mb = rhs.bits();
let mut bd = self.bits_precision() - mb;
let mut rem = self.clone();
// Will not overflow since `bd < bits_precision`
let mut c = rhs.shl_vartime(bd).expect("shift within range");

loop {
let (r, borrow) = rem.sbb(&c, Limb::ZERO);
rem = Self::ct_select(&r, &rem, !borrow.ct_eq(&Limb::ZERO));
if bd == 0 {
break rem;
// Number of limbs needed to hold the significant bits of `rhs`: `bits_vartime()` divided
// by `Limb::BITS`, rounded up.
let yc = ((rhs.0.bits_vartime() + Limb::BITS - 1) / Limb::BITS) as usize;

match yc {
// No significant limbs: the divisor value is zero.
0 => panic!("zero divisor"),
1 => {
// Perform limb division
let rem_limb = self.rem_limb(rhs.0.limbs[0].to_nz().expect("zero divisor"));
// The remainder is returned at the precision of `rhs`, with the limb result in limb 0.
let mut rem = Self::zero_with_precision(rhs.bits_precision());
rem.limbs[0] = rem_limb;
rem
}
// The divisor has more significant limbs than `self` has limbs at all, so `self` is the
// smaller value and is its own remainder; copy it into a value of `rhs`'s precision.
_ if yc > self.limbs.len() => {
let mut rem = Self::zero_with_precision(rhs.bits_precision());
rem.limbs[..self.limbs.len()].copy_from_slice(&self.limbs);
rem
}
_ => {
// The in-place helper needs scratch space for the quotient, so `self` is cloned even
// though only the remainder (left in `rem`) is returned.
let mut quo = self.clone();
let mut rem = rhs.0.clone();
div_rem_vartime_in_place(&mut quo.limbs, &mut rem.limbs[..yc]);
rem
}
bd -= 1;
c.shr1_assign();
}
}

Expand Down Expand Up @@ -134,36 +154,6 @@ impl BoxedUint {

(quo, rem)
}

/// Divides `self` by `rhs` one bit at a time and returns `(quotient, remainder)`.
///
/// `rhs` is not checked for zero here; that is the caller's responsibility.
///
/// This function operates in variable-time.
fn div_rem_vartime_unchecked(&self, rhs: &Self) -> (Self, Self) {
    debug_assert_eq!(self.bits_precision(), rhs.bits_precision());

    // How far the divisor can be shifted left before its top set bit leaves the precision.
    let top_shift = self.bits_precision() - rhs.bits_vartime();

    let mut rem = self.clone();
    let mut quo = Self::zero_with_precision(self.bits_precision());
    // Will not overflow since `top_shift < bits_precision`
    let mut shifted = rhs.shl_vartime(top_shift).expect("shift within range");

    // One trial subtraction per shift position, from the largest shift down to zero.
    for step in (0..=top_shift).rev() {
        let (diff, borrow) = rem.sbb(&shifted, Limb::ZERO);
        let borrowed = Choice::from(borrow.0 as u8 & 1);
        rem = Self::ct_select(&diff, &rem, borrowed);

        let with_low_bit = &quo | Self::one();
        quo = Self::ct_select(&with_low_bit, &quo, borrowed);

        // Move on to the next lower bit, except after the final position.
        if step != 0 {
            shifted.shr1_assign();
            quo.shl1_assign();
        }
    }

    (quo, rem)
}
}

impl CheckedDiv for BoxedUint {
Expand Down Expand Up @@ -320,9 +310,116 @@ impl RemLimb for BoxedUint {
}
}

/// Computes `x` / `y` in place, returning the quotient in `x` and the remainder in `y`.
///
/// Both slices are indexed with the least significant limb first; `y[yc - 1]` is treated as
/// the leading word of the divisor. The routine is a limb-wise long division: the divisor is
/// shifted left until its leading word has no leading zeros, each quotient limb is estimated
/// from the top dividend words with `div2by1`, corrected, and the remainder is shifted back
/// at the end.
///
/// This function operates in variable-time.
///
/// # Panics
///
/// Panics if `y` is empty or the leading word of `y` is zero. When `x.len() >= y.len()` the
/// main loop also indexes `y[yc - 2]`, so `y` must then hold at least two limbs.
pub(crate) fn div_rem_vartime_in_place(x: &mut [Limb], y: &mut [Limb]) {
let xc = x.len();
let yc = y.len();
assert!(
yc > 0 && y[yc - 1].0 != 0,
"divisor must have a non-zero leading word"
);

if xc == 0 {
// If the dividend is empty, set the remainder to zero and return.
y.fill(Limb::ZERO);
return;
} else if yc > xc {
// The divisor has more limbs than the dividend (and a non-zero leading word), so it is
// the larger value. Return zero as the quotient and the dividend, zero-extended to the
// length of `y`, as the remainder.
y[..xc].copy_from_slice(&x[..xc]);
y[xc..].fill(Limb::ZERO);
x.fill(Limb::ZERO);
return;
}

// `lshift` normalizes the divisor; `rshift` is the complementary shift that moves the bits
// pushed out of one limb into the next. `rshift` is only used when `lshift != 0`, which
// avoids a shift by the full limb width.
let lshift = y[yc - 1].leading_zeros();
let rshift = if lshift == 0 { 0 } else { Limb::BITS - lshift };
// Extra high word of the dividend: first the bits shifted out of `x`, later the top word of
// the running remainder.
let mut x_hi = Limb::ZERO;
let mut carry;

if lshift != 0 {
// Shift divisor such that it has no leading zeros
// This means that div2by1 requires no extra shifts, and ensures that the high word >= b/2
carry = Limb::ZERO;
for i in 0..yc {
(y[i], carry) = (Limb((y[i].0 << lshift) | carry.0), Limb(y[i].0 >> rshift));
}

// Shift the dividend to match
carry = Limb::ZERO;
for i in 0..xc {
(x[i], carry) = (Limb((x[i].0 << lshift) | carry.0), Limb(x[i].0 >> rshift));
}
x_hi = carry;
}

let reciprocal = Reciprocal::new(y[yc - 1].to_nz().expect("zero divisor"));

// Produce one quotient limb per iteration, from the most significant position downwards.
for xi in (yc - 1..xc).rev() {
// Divide high dividend words by the high divisor word to estimate the quotient word
// NOTE(review): presumably `div2by1` expects the high word `x_hi` to be smaller than the
// divisor word; confirm the behaviour when `x_hi == y[yc - 1]`.
let (mut quo, mut rem) = div2by1(x_hi.0, x[xi].0, &reciprocal);

// Refine the estimate with the second divisor word, at most twice.
for _ in 0..2 {
let qy = (quo as WideWord) * (y[yc - 2].0 as WideWord);
let rx = ((rem as WideWord) << Word::BITS) | (x[xi - 1].0 as WideWord);
// Constant-time check for (rem, x[xi - 1]) < quo * y[yc - 2], i.e. `rx < qy`, based on
// ConstChoice::from_word_lt. When it holds the estimate is too large: `quo` is
// decremented and `rem` is increased by the leading divisor word.
let diff = ConstChoice::from_word_lsb(
((((!rx) & qy) | (((!rx) | qy) & (rx.wrapping_sub(qy)))) >> (WideWord::BITS - 1))
as Word,
);
quo = diff.select_word(quo, quo.saturating_sub(1));
rem = diff.select_word(rem, rem.saturating_add(y[yc - 1].0));
}

// Subtract q*divisor from the dividend
// The `yc` limbs touched are `x[xi + 1 - yc..=xi]`, followed by the extra high word `x_hi`.
carry = Limb::ZERO;
let mut borrow = Limb::ZERO;
let mut tmp;
for i in 0..yc {
(tmp, carry) = Limb::ZERO.mac(y[i], Limb(quo), carry);
(x[xi + i + 1 - yc], borrow) = x[xi + i + 1 - yc].sbb(tmp, borrow);
}
(_, borrow) = x_hi.sbb(carry, borrow);

// If the subtraction borrowed, then decrement q and add back the divisor
// The probability of this being needed is very low, about 2/(Limb::MAX+1)
// The add-back loop always runs; it adds zero limbs when there was no borrow.
let ct_borrow = ConstChoice::from_word_mask(borrow.0);
carry = Limb::ZERO;
for i in 0..yc {
(x[xi + i + 1 - yc], carry) =
x[xi + i + 1 - yc].adc(Limb::select(Limb::ZERO, y[i], ct_borrow), carry);
}
quo = ct_borrow.select_word(quo, quo.saturating_sub(1));

// Store the quotient within dividend and set x_hi to the current highest word
x_hi = x[xi];
x[xi] = Limb(quo);
}

// Copy the remainder to divisor
// Its low `yc - 1` limbs are still in `x`; its top limb is held in `x_hi`.
y[..yc - 1].copy_from_slice(&x[..yc - 1]);
y[yc - 1] = x_hi;

// Unshift the remainder from the earlier adjustment
if lshift != 0 {
carry = Limb::ZERO;
for i in (0..yc).rev() {
(y[i], carry) = (Limb((y[i].0 >> lshift) | carry.0), Limb(y[i].0 << rshift));
}
}

// Shift the quotient to the low limbs within dividend
// The quotient limbs were stored in `x[yc - 1..xc]` (`xc - yc + 1` limbs); move them to the
// front and clear everything above them.
x.copy_within(yc - 1..xc, 0);
x[xc - yc + 1..].fill(Limb::ZERO);
}

#[cfg(test)]
mod tests {
use super::{BoxedUint, NonZero};
use super::{BoxedUint, Limb, NonZero};

#[test]
fn rem() {
Expand All @@ -337,4 +434,12 @@ mod tests {
let p = NonZero::new(BoxedUint::from(997u128)).unwrap();
assert_eq!(BoxedUint::from(648u128), n.rem_vartime(&p));
}

#[test]
fn rem_limb() {
    // Reducing by 997 as a single limb must agree with the low limb of the full-width
    // remainder for the same modulus.
    let dividend = BoxedUint::from(0xFFEECCBBAA99887766u128);
    let limb_modulus = NonZero::new(Limb(997)).unwrap();
    let wide_modulus = NonZero::new(BoxedUint::from(997u128)).unwrap();

    let wide_remainder = dividend.rem(&wide_modulus);
    let limb_remainder = dividend.rem_limb(limb_modulus);

    assert_eq!(wide_remainder.limbs[0], limb_remainder);
}
}
Loading

0 comments on commit 6370b08

Please sign in to comment.