Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions ml-dsa/src/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,65 @@ pub(crate) trait Decompose {
fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem);
}

/// Division by a compile-time constant without the hardware divide instruction.
///
/// Hardware division can take a data-dependent number of cycles, so this trait
/// replaces `x / M` with a Barrett-style multiply-and-shift against a
/// precomputed reciprocal, where `M` is the constant encoded by `Self`.
pub(crate) trait ConstantTimeDiv: Unsigned {
    /// Right shift applied to the widened product in [`Self::ct_div`].
    const CT_DIV_SHIFT: usize;
    /// Precomputed reciprocal approximation: `ceil(2^CT_DIV_SHIFT / M)`.
    const CT_DIV_MULTIPLIER: u64;

    /// Computes `floor(x / M)` as `(x * CT_DIV_MULTIPLIER) >> CT_DIV_SHIFT`.
    ///
    /// Requires: `x < Q` (the field modulus, ~2^23).
    ///
    /// With `CT_DIV_MULTIPLIER = ceil(2^CT_DIV_SHIFT / M)` the result equals the
    /// true quotient whenever `x * M <= 2^CT_DIV_SHIFT`. The product is formed
    /// in `u64`, so `x * CT_DIV_MULTIPLIER` must also stay below `2^64`; callers
    /// with a very small `M` (large multiplier) have to check this themselves.
    #[allow(clippy::inline_always)] // Required for constant-time guarantees in crypto code
    #[inline(always)]
    fn ct_div(x: u32) -> u32 {
        // Widen first so the multiplication happens in 64 bits.
        let scaled = u64::from(x) * Self::CT_DIV_MULTIPLIER;
        let quotient = scaled >> Self::CT_DIV_SHIFT;
        // The narrowing below is lossless: the multiplier is at most
        // 2^CT_DIV_SHIFT (M >= 1), so the quotient is <= x, which is a u32.
        // The attribute sits on a `let` because attributes on a tail
        // expression are not accepted on stable Rust.
        #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
        let narrowed = quotient as u32;
        narrowed
    }
}

impl<M: Unsigned> ConstantTimeDiv for M {
    // The shift fixes the precision of the reciprocal. For an exact quotient
    // we need SHIFT >= log2(x) + log2(M); with x < Q < 2^24 and M < 2^20 a
    // shift of 48 leaves ample headroom.
    const CT_DIV_SHIFT: usize = 48;

    // Reciprocal rounded *up* (via `div_ceil`), evaluated at compile time, so
    // the multiply-and-shift never underestimates the quotient.
    // NOTE(review): the allow below is kept from the original; `div_ceil` is a
    // method call, so the lint may not fire here at all - confirm before removing.
    #[allow(clippy::integer_division_remainder_used)]
    const CT_DIV_MULTIPLIER: u64 = (1u64 << Self::CT_DIV_SHIFT).div_ceil(M::U64);
}

impl Decompose for Elem {
// Algorithm 36 Decompose
//
// This implementation uses constant-time division to avoid timing side-channels.
// The original algorithm used hardware division which has variable timing based
// on operand values, potentially leaking secret information during signing.
fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem) {
    // Returns the pair `(r1, r0)` for the modulus parameter `TwoGamma2`:
    // `r0` comes from `mod_plus_minus::<TwoGamma2>()` and `r1` is the
    // remaining difference divided by `TwoGamma2`.
    let r_plus = self.clone();
    let r0 = r_plus.mod_plus_minus::<TwoGamma2>();

    // Computed once; both the corner-case test and the quotient need it.
    let diff = r_plus - r0;

    if diff == Elem::new(BaseField::Q - 1) {
        // Corner case: the high part wraps to zero and the low part is
        // decremented by one.
        (Elem::new(0), r0 - Elem::new(1))
    } else {
        // Only the multiply-and-shift division is used here. The former
        // `r1.0 /= TwoGamma2::U32` hardware division must not be kept
        // alongside it, since its timing can depend on the operand.
        let r1 = Elem::new(TwoGamma2::ct_div(diff.0));
        (r1, r0)
    }
}
Expand Down
99 changes: 69 additions & 30 deletions ml-dsa/src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,46 @@ pub(crate) trait Ntt {
fn ntt(&self) -> Self::Output;
}

/// Applies one forward-NTT butterfly layer to `w` in place.
///
/// The layer is made of `ITERATIONS` consecutive blocks of `2 * LEN`
/// coefficients. For each block, `m` is advanced by one and the twiddle factor
/// `ZETA_POW_BITREV[*m]` is used to combine the lower half of the block with
/// the upper half.
///
/// `LEN` and `ITERATIONS` are const generics so that every loop bound is a
/// compile-time constant; the intent is to avoid the UDIV instructions that a
/// runtime `step_by` computation could produce.
#[allow(clippy::inline_always)] // Required for constant-time guarantees in crypto code
#[inline(always)]
fn ntt_layer<const LEN: usize, const ITERATIONS: usize>(w: &mut [Elem; 256], m: &mut usize) {
    for block in 0..ITERATIONS {
        // Each block consumes the next twiddle factor in table order.
        *m += 1;
        let zeta = ZETA_POW_BITREV[*m];

        let base = block * 2 * LEN;
        for lo in base..(base + LEN) {
            let hi = lo + LEN;
            let twisted = zeta * w[hi];
            let kept = w[lo];
            w[hi] = kept - twisted;
            w[lo] = kept + twisted;
        }
    }
}

impl Ntt for Polynomial {
    type Output = NttPolynomial;

    // Algorithm 41 NTT
    //
    // The eight butterfly layers are spelled out as const-generic calls so
    // that every loop bound is a compile-time constant (no runtime `step_by`,
    // hence no potential UDIV instruction). Each call is
    // `ntt_layer::<LEN, ITERATIONS>` with `2 * LEN * ITERATIONS == 256`, and
    // `m` carries the running twiddle-table index from one layer to the next.
    fn ntt(&self) -> Self::Output {
        // Work on a plain fixed-size array copy of the coefficients.
        let mut w: [Elem; 256] = self.0.clone().into();
        let mut m = 0;

        ntt_layer::<128, 1>(&mut w, &mut m);
        ntt_layer::<64, 2>(&mut w, &mut m);
        ntt_layer::<32, 4>(&mut w, &mut m);
        ntt_layer::<16, 8>(&mut w, &mut m);
        ntt_layer::<8, 16>(&mut w, &mut m);
        ntt_layer::<4, 32>(&mut w, &mut m);
        ntt_layer::<2, 64>(&mut w, &mut m);
        ntt_layer::<1, 128>(&mut w, &mut m);

        NttPolynomial::new(w.into())
    }
}

Expand All @@ -89,30 +107,51 @@ pub(crate) trait NttInverse {
fn ntt_inverse(&self) -> Self::Output;
}

/// Applies one inverse-NTT butterfly layer to `w` in place.
///
/// The layer is made of `ITERATIONS` consecutive blocks of `2 * LEN`
/// coefficients. For each block, `m` is stepped back by one and the negated
/// twiddle factor `-ZETA_POW_BITREV[*m]` scales the difference written to the
/// upper half, while the sum goes to the lower half.
///
/// `LEN` and `ITERATIONS` are const generics so that every loop bound is a
/// compile-time constant; the intent is to avoid the UDIV instructions that a
/// runtime `step_by` computation could produce.
#[allow(clippy::inline_always)] // Required for constant-time guarantees in crypto code
#[inline(always)]
fn ntt_inverse_layer<const LEN: usize, const ITERATIONS: usize>(
    w: &mut [Elem; 256],
    m: &mut usize,
) {
    for block in 0..ITERATIONS {
        // The twiddle table is walked backwards for the inverse transform.
        *m -= 1;
        let neg_zeta = -ZETA_POW_BITREV[*m];

        let base = block * 2 * LEN;
        for lo in base..(base + LEN) {
            let hi = lo + LEN;
            let (a, b) = (w[lo], w[hi]);
            w[lo] = a + b;
            w[hi] = neg_zeta * (a - b);
        }
    }
}

impl NttInverse for NttPolynomial {
    type Output = Polynomial;

    // Algorithm 42 NTT^{−1}
    //
    // The eight butterfly layers are spelled out as const-generic calls so
    // that every loop bound is a compile-time constant (no runtime `step_by`,
    // hence no potential UDIV instruction). Each call is
    // `ntt_inverse_layer::<LEN, ITERATIONS>` with `2 * LEN * ITERATIONS == 256`;
    // `m` starts at 256 and is decremented once per block across all layers.
    fn ntt_inverse(&self) -> Self::Output {
        // Scaling factor applied to every coefficient after the last layer.
        const INVERSE_256: Elem = Elem::new(8_347_681);

        // Work on a plain fixed-size array copy of the coefficients.
        let mut w: [Elem; 256] = self.0.clone().into();
        let mut m = 256;

        ntt_inverse_layer::<1, 128>(&mut w, &mut m);
        ntt_inverse_layer::<2, 64>(&mut w, &mut m);
        ntt_inverse_layer::<4, 32>(&mut w, &mut m);
        ntt_inverse_layer::<8, 16>(&mut w, &mut m);
        ntt_inverse_layer::<16, 8>(&mut w, &mut m);
        ntt_inverse_layer::<32, 4>(&mut w, &mut m);
        ntt_inverse_layer::<64, 2>(&mut w, &mut m);
        ntt_inverse_layer::<128, 1>(&mut w, &mut m);

        INVERSE_256 * &Polynomial::new(w.into())
    }
}

Expand Down