Skip to content

Commit 28d98d1

Browse files
authored
feat(fft): add fast mul for fft friendly fields (#997)
* feat(fft): add fast multiplication * chore: add fast_mul benchmark * impl suggestion * chore: impl suggestion
1 parent 3e41387 commit 28d98d1

File tree

5 files changed

+89
-5
lines changed

5 files changed

+89
-5
lines changed

crates/math/benches/polynomials/polynomial.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
use super::utils::{rand_field_elements, rand_poly, FE};
1+
use super::utils::{rand_complex_mersenne_poly, rand_field_elements, rand_poly, FE};
22
use const_random::const_random;
33
use core::hint::black_box;
44
use criterion::Criterion;
5-
use lambdaworks_math::polynomial::Polynomial;
5+
use lambdaworks_math::{
6+
field::fields::mersenne31::extensions::Degree2ExtensionField, polynomial::Polynomial,
7+
};
68

79
pub fn polynomial_benchmarks(c: &mut Criterion) {
810
let mut group = c.benchmark_group("Polynomial");
@@ -43,6 +45,20 @@ pub fn polynomial_benchmarks(c: &mut Criterion) {
4345
bench.iter(|| black_box(&x_poly) * black_box(&y_poly));
4446
});
4547

48+
let big_order = 9;
49+
let x_poly = rand_complex_mersenne_poly(big_order);
50+
let y_poly = rand_complex_mersenne_poly(big_order);
51+
group.bench_function("fast_mul big poly", |bench| {
52+
bench.iter(|| {
53+
black_box(&x_poly)
54+
.fast_fft_multiplication::<Degree2ExtensionField>(black_box(&y_poly))
55+
.unwrap()
56+
});
57+
});
58+
group.bench_function("slow mul big poly", |bench| {
59+
bench.iter(|| black_box(&x_poly) * black_box(&y_poly));
60+
});
61+
4662
group.bench_function("div", |bench| {
4763
let x_poly = rand_poly(order);
4864
let y_poly = rand_poly(order);

crates/math/benches/polynomials/utils.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
use const_random::const_random;
22
use lambdaworks_math::{
3-
field::fields::u64_prime_field::{U64FieldElement, U64PrimeField},
3+
field::{
4+
element::FieldElement,
5+
fields::{
6+
mersenne31::{extensions::Degree2ExtensionField, field::Mersenne31Field},
7+
u64_prime_field::{U64FieldElement, U64PrimeField},
8+
},
9+
},
410
polynomial::{
511
dense_multilinear_poly::DenseMultilinearPolynomial,
612
sparse_multilinear_poly::SparseMultilinearPolynomial, Polynomial,
@@ -36,6 +42,28 @@ pub fn rand_field_elements(order: u64) -> Vec<FE> {
3642
pub fn rand_poly(order: u64) -> Polynomial<FE> {
3743
Polynomial::new(&rand_field_elements(order))
3844
}
45+
#[allow(dead_code)]
46+
#[inline(never)]
47+
#[export_name = "u64_utils::rand_complex_mersenne_field_elements"]
48+
pub fn rand_complex_mersenne_field_elements(
49+
order: u32,
50+
) -> Vec<FieldElement<Degree2ExtensionField>> {
51+
let mut result = Vec::with_capacity(1 << order);
52+
for _ in 0..result.capacity() {
53+
result.push(FieldElement::<Degree2ExtensionField>::new([
54+
FieldElement::<Mersenne31Field>::new(random()),
55+
FieldElement::<Mersenne31Field>::new(random()),
56+
]));
57+
}
58+
result
59+
}
60+
61+
#[allow(dead_code)]
62+
#[inline(never)]
63+
#[export_name = "u64_utils::rand_complex_mersenne_poly"]
64+
pub fn rand_complex_mersenne_poly(order: u32) -> Polynomial<FieldElement<Degree2ExtensionField>> {
65+
Polynomial::new(&rand_complex_mersenne_field_elements(order))
66+
}
3967

4068
#[allow(dead_code)]
4169
#[inline(never)]

crates/math/src/fft/errors.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ pub enum FFTError {
1010
RootOfUnityError(u64),
1111
InputError(usize),
1212
OrderError(u64),
13+
DomainSizeError(usize),
1314
#[cfg(feature = "cuda")]
1415
CudaError(CudaError),
1516
}
@@ -24,6 +25,9 @@ impl Display for FFTError {
2425
FFTError::OrderError(v) => {
2526
write!(f, "Order should be less than or equal to 63, but is {v}")
2627
}
28+
FFTError::DomainSizeError(_) => {
29+
write!(f, "Domain size exceeds two adicity of the field")
30+
}
2731
#[cfg(feature = "cuda")]
2832
FFTError::CudaError(_) => {
2933
write!(f, "A CUDA related error has ocurred")

crates/math/src/fft/polynomial.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ impl<E: IsField> Polynomial<FieldElement<E>> {
2727
) -> Result<Vec<FieldElement<E>>, FFTError> {
2828
let domain_size = domain_size.unwrap_or(0);
2929
let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor;
30-
30+
if len.trailing_zeros() as u64 > F::TWO_ADICITY {
31+
return Err(FFTError::DomainSizeError(len.trailing_zeros() as usize));
32+
}
3133
if poly.coefficients().is_empty() {
3234
return Ok(vec![FieldElement::zero(); len]);
3335
}
@@ -97,6 +99,22 @@ impl<E: IsField> Polynomial<FieldElement<E>> {
9799
let scaled = Polynomial::interpolate_fft::<F>(fft_evals)?;
98100
Ok(scaled.scale(&offset.inv().unwrap()))
99101
}
102+
103+
/// Multiplies two polynomials using FFT.
104+
/// It's faster than naive multiplication when the degree of the polynomials is large enough (>=2**6).
105+
/// This works best with polynomials whose highest degree is equal to a power of 2 - 1.
106+
/// Will return an error if the degree of the resulting polynomial is greater than 2**63.
107+
pub fn fast_fft_multiplication<F: IsFFTField + IsSubFieldOf<E>>(
108+
&self,
109+
other: &Self,
110+
) -> Result<Self, FFTError> {
111+
let domain_size = self.degree() + other.degree() + 1;
112+
let p = Polynomial::evaluate_fft::<F>(self, 1, Some(domain_size))?;
113+
let q = Polynomial::evaluate_fft::<F>(other, 1, Some(domain_size))?;
114+
let r = p.into_iter().zip(q).map(|(a, b)| a * b).collect::<Vec<_>>();
115+
116+
Polynomial::interpolate_fft::<F>(&r)
117+
}
100118
}
101119

102120
pub fn compose_fft<F, E>(
@@ -313,6 +331,11 @@ mod tests {
313331

314332
prop_assert_eq!(poly, new_poly);
315333
}
334+
335+
#[test]
336+
fn test_fft_multiplication_works(poly in poly(7), other in poly(7)) {
337+
prop_assert_eq!(poly.fast_fft_multiplication::<F>(&other).unwrap(), poly * other);
338+
}
316339
}
317340

318341
#[test]
@@ -408,6 +431,11 @@ mod tests {
408431
let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly);
409432
prop_assert_eq!(poly, new_poly);
410433
}
434+
435+
#[test]
436+
fn test_fft_multiplication_works(poly in poly(7), other in poly(7)) {
437+
prop_assert_eq!(poly.fast_fft_multiplication::<F>(&other).unwrap(), poly * other);
438+
}
411439
}
412440
}
413441

crates/math/src/field/fields/mersenne31/extensions.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use super::field::Mersenne31Field;
22
use crate::field::{
33
element::FieldElement,
44
errors::FieldError,
5-
traits::{IsField, IsSubFieldOf},
5+
traits::{IsFFTField, IsField, IsSubFieldOf},
66
};
77
#[cfg(feature = "alloc")]
88
use alloc::vec::Vec;
@@ -93,6 +93,14 @@ impl IsField for Degree2ExtensionField {
9393
}
9494
}
9595

96+
impl IsFFTField for Degree2ExtensionField {
97+
// Values taken from stwo
98+
// https://github.com/starkware-libs/stwo/blob/dev/crates/prover/src/core/circle.rs#L203-L209
99+
const TWO_ADICITY: u64 = 31;
100+
const TWO_ADIC_PRIMITVE_ROOT_OF_UNITY: Self::BaseType =
101+
[FpE::const_from_raw(2), FpE::const_from_raw(1268011823)];
102+
}
103+
96104
impl IsSubFieldOf<Degree2ExtensionField> for Mersenne31Field {
97105
fn add(
98106
a: &Self::BaseType,

0 commit comments

Comments
 (0)