Skip to content

Commit b3d1d95

Browse files
committed
Faster polynomial evaluation in Lagrange basis with rhizomes algorithm.
Faster polynomial evaluation on flp.Query. Polynomials are directly evaluated in the Lagrange basis. Uses the batched algorithm from the rhizomes paper (https://ia.cr/2025/1727).
1 parent cad5eb9 commit b3d1d95

File tree

6 files changed

+198
-13
lines changed

6 files changed

+198
-13
lines changed

src/field.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ pub trait FieldElement:
109109
/// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined.
110110
fn inv(&self) -> Self;
111111

112+
/// Returns 1/2.
113+
fn half() -> Self;
114+
112115
/// Interprets the next [`Self::ENCODED_SIZE`] bytes from the input slice as an element of the
113116
/// field. Any of the most significant bits beyond the bit length of the modulus will be
114117
/// cleared, in order to minimize the amount of rejection sampling needed.
@@ -740,6 +743,10 @@ macro_rules! make_field {
740743
fn one() -> Self {
741744
Self($fp::ROOTS[0])
742745
}
746+
747+
fn half() -> Self {
748+
Self($fp::HALF)
749+
}
743750
}
744751

745752
impl FieldElementWithInteger for $elem {
@@ -1001,6 +1008,7 @@ pub(crate) mod test_utils {
10011008
let int_one = F::TestInteger::try_from(1).unwrap();
10021009
let zero = F::zero();
10031010
let one = F::one();
1011+
let half = F::half();
10041012
let two = F::from(F::TestInteger::try_from(2).unwrap());
10051013
let four = F::from(F::TestInteger::try_from(4).unwrap());
10061014

@@ -1045,6 +1053,7 @@ pub(crate) mod test_utils {
10451053
// mul
10461054
assert_eq!(two * two, four);
10471055
assert_eq!(two * one, two);
1056+
assert_eq!(two * half, one);
10481057
assert_eq!(two * zero, zero);
10491058
assert_eq!(one * F::from(int_modulus.clone()), zero);
10501059

src/field/field255.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,16 @@ impl FieldElement for Field255 {
309309
fn one() -> Self {
310310
Field255(fiat_25519_tight_field_element([1, 0, 0, 0, 0]))
311311
}
312+
313+
fn half() -> Self {
314+
Field255(fiat_25519_tight_field_element([
315+
2251799813685239,
316+
2251799813685247,
317+
2251799813685247,
318+
2251799813685247,
319+
1125899906842623,
320+
]))
321+
}
312322
}
313323

314324
impl Default for Field255 {

src/flp.rs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ use crate::dp::DifferentialPrivacyStrategy;
5151
use crate::field::{FieldElement, FieldElementWithInteger, FieldError, NttFriendlyFieldElement};
5252
use crate::fp::log2;
5353
use crate::ntt::{ntt, ntt_inv_finish, NttError};
54-
use crate::polynomial::poly_eval;
54+
use crate::polynomial::{nth_root_powers, poly_eval, poly_eval_batched};
5555
use std::any::Any;
5656
use std::convert::TryFrom;
5757
use std::fmt::Debug;
@@ -466,16 +466,13 @@ pub trait Flp: Sized + Eq + Clone + Debug {
466466
// Reconstruct the wire polynomials `f[0], ..., f[g_arity-1]` and evaluate each wire
467467
// polynomial at query randomness value.
468468
let m = (1 + gadget.calls()).next_power_of_two();
469-
let m_inv = Self::Field::from(
470-
<Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap(),
471-
)
472-
.inv();
473-
let mut f = vec![Self::Field::zero(); m];
474-
for wire in 0..gadget.arity() {
475-
ntt(&mut f, &gadget.f_vals[wire], m)?;
476-
ntt_inv_finish(&mut f, m, m_inv);
477-
verifier.push(poly_eval(&f, *query_rand_val));
478-
}
469+
470+
// Evaluates a batch of polynomials in the Lagrange basis.
471+
// This avoids using NTTs to convert them to the monomial basis.
472+
let roots = nth_root_powers(m);
473+
let polynomials = &gadget.f_vals[..gadget.arity()];
474+
let mut evals = poly_eval_batched(polynomials, &roots, *query_rand_val);
475+
verifier.append(&mut evals);
479476

480477
// Add the value of the gadget polynomial evaluated at the query randomness value.
481478
verifier.push(gadget.p_at_r);

src/fp.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ impl FieldParameters<u32> for FP32 {
3030
1534972560, 3732920810, 3229320047, 2836564014, 2170197442, 3760663902, 2144268387,
3131
3849278021, 1395394315, 574397626, 125025876, 3755041587, 2680072542, 3903828692,
3232
];
33+
const HALF: u32 = 2147483648;
3334
#[cfg(test)]
3435
const LOG2_BASE: usize = 32;
3536
#[cfg(test)]
@@ -72,6 +73,7 @@ impl FieldParameters<u64> for FP64 {
7273
10135969988448727187,
7374
6815045114074884550,
7475
];
76+
const HALF: u64 = 9223372036854775808;
7577
#[cfg(test)]
7678
const LOG2_BASE: usize = 64;
7779
#[cfg(test)]
@@ -114,6 +116,7 @@ impl FieldParameters<u128> for FP128 {
114116
258279638927684931537542082169183965856,
115117
148221243758794364405224645520862378432,
116118
];
119+
const HALF: u128 = 170141183460469231731687303715884105728;
117120
#[cfg(test)]
118121
const LOG2_BASE: usize = 64;
119122
#[cfg(test)]

src/fp/ops.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ pub trait FieldParameters<W: Word> {
6464
/// `ROOTS[l]` has order `2^l` in the multiplicative group.
6565
/// `ROOTS[0]` is equal to one by definition.
6666
const ROOTS: [W; MAX_ROOTS + 1];
67+
/// The multiplicative inverse of 2.
68+
const HALF: W;
6769
/// The log2(base) for the base used for multiprecision arithmetic.
6870
/// So, `LOG2_BASE ≤ 64` as processors have at most a 64-bit
6971
/// integer multiplier.

src/polynomial.rs

Lines changed: 166 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33

44
//! Functions for polynomial interpolation and evaluation
55
6-
use crate::field::NttFriendlyFieldElement;
76
#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))]
87
use crate::ntt::{ntt, ntt_inv_finish};
8+
use crate::{
9+
field::{FieldElement, NttFriendlyFieldElement},
10+
fp::log2,
11+
};
912

1013
use std::convert::TryFrom;
1114

@@ -60,6 +63,90 @@ pub fn poly_interpret_eval<F: NttFriendlyFieldElement>(
6063
poly_eval(&tmp_coeffs[..points.len()], eval_at)
6164
}
6265

66+
/// Returns the element `1/n` on `F`, where `n` must be a power of two.
67+
#[inline]
68+
fn inv_pow2<F: FieldElement>(n: usize) -> F {
69+
let log2_n = usize::try_from(log2(n as u128)).unwrap();
70+
assert_eq!(n, 1 << log2_n);
71+
72+
let half = F::half();
73+
let mut x = F::one();
74+
for _ in 0..log2_n {
75+
x *= half
76+
}
77+
x
78+
}
79+
80+
/// Evaluates multiple polynomials given in the Lagrange basis.
81+
///
82+
/// This is Algorithm 7 of rhizomes paper.
83+
/// <https://eprint.iacr.org/2025/1727>.
84+
pub(crate) fn poly_eval_batched<F: FieldElement>(
85+
polynomials: &[Vec<F>],
86+
roots: &[F],
87+
x: F,
88+
) -> Vec<F> {
89+
let mut l = F::one();
90+
let mut u = Vec::with_capacity(polynomials.len());
91+
u.extend(polynomials.iter().map(|poly| poly[0]));
92+
let mut d = roots[0] - x;
93+
for (i, wn_i) in (1..).zip(&roots[1..]) {
94+
l *= d;
95+
d = *wn_i - x;
96+
let t = l * *wn_i;
97+
for (u_j, poly) in u.iter_mut().zip(polynomials) {
98+
*u_j *= d;
99+
if let Some(yi) = poly.get(i) {
100+
*u_j += t * *yi;
101+
}
102+
}
103+
}
104+
105+
if roots.len() > 1 {
106+
let num_roots_inv = -inv_pow2::<F>(roots.len());
107+
u.iter_mut().for_each(|u_j| *u_j *= num_roots_inv);
108+
}
109+
110+
u
111+
}
112+
113+
/// Generates the powers of the primitive n-th root of unity.
114+
///
115+
/// Returns
116+
/// roots\[i\] = w_n^i for 0 ≤ i < n,
117+
/// where
118+
/// w_n is the primitive n-th root of unity in `F`, and
119+
/// `n` must be a power of two.
120+
pub(crate) fn nth_root_powers<F: NttFriendlyFieldElement>(n: usize) -> Vec<F> {
121+
let log2_n = usize::try_from(log2(n as u128)).unwrap();
122+
assert_eq!(n, 1 << log2_n);
123+
124+
let mut roots = vec![F::zero(); n];
125+
roots[0] = F::one();
126+
if n > 1 {
127+
roots[1] = -F::one();
128+
for i in 2..=log2_n {
129+
let mid = 1 << (i - 1);
130+
// Due to w_{2n}^{2j} = w_{n}^j
131+
for j in (1..mid).rev() {
132+
roots[j << 1] = roots[j]
133+
}
134+
135+
let wn = F::root(i).unwrap();
136+
roots[1] = wn;
137+
roots[1 + mid] = -wn;
138+
139+
// Due to w_{n}^{j} = -w_{n}^{j+n/2}
140+
for j in (3..mid).step_by(2) {
141+
roots[j] = wn * roots[j - 1];
142+
roots[j + mid] = -roots[j]
143+
}
144+
}
145+
}
146+
147+
roots
148+
}
149+
63150
/// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`. Otherwise,
64151
/// the output is not `0`.
65152
pub(crate) fn poly_range_check<F: NttFriendlyFieldElement>(start: usize, end: usize) -> Vec<F> {
@@ -76,7 +163,11 @@ pub(crate) fn poly_range_check<F: NttFriendlyFieldElement>(start: usize, end: us
76163
mod tests {
77164
use crate::{
78165
field::{Field64, FieldElement, FieldPrio2, NttFriendlyFieldElement},
79-
polynomial::{poly_deg, poly_eval, poly_mul, poly_range_check},
166+
fp::log2,
167+
polynomial::{
168+
nth_root_powers, poly_deg, poly_eval, poly_eval_batched, poly_interpret_eval, poly_mul,
169+
poly_range_check,
170+
},
80171
};
81172
use std::convert::TryFrom;
82173

@@ -154,4 +245,77 @@ mod tests {
154245
let y = poly_eval(&p, x);
155246
assert_ne!(y, Field64::zero());
156247
}
248+
249+
/// Generates the powers of the primitive n-th root of unity.
250+
///
251+
/// Returns
252+
/// roots\[i\] = w_n^i for 0 ≤ i < n,
253+
/// where
254+
/// w_n is the primitive n-th root of unity in `F`, and
255+
/// `n` must be a power of two.
256+
///
257+
/// This is the iterative method.
258+
fn nth_root_powers_slow<F: NttFriendlyFieldElement>(n: usize) -> Vec<F> {
259+
let log2_n = usize::try_from(log2(n as u128)).unwrap();
260+
let wn = F::root(log2_n).unwrap();
261+
core::iter::successors(Some(F::one()), |&x| Some(x * wn))
262+
.take(n)
263+
.collect()
264+
}
265+
266+
#[test]
267+
fn test_nth_root_powers() {
268+
for i in 0..8 {
269+
assert_eq!(
270+
nth_root_powers::<Field64>(1 << i),
271+
nth_root_powers_slow::<Field64>(1 << i)
272+
);
273+
}
274+
}
275+
276+
#[test]
277+
fn test_poly_eval_batched_ones() {
278+
test_poly_eval_batched(&[1]);
279+
test_poly_eval_batched(&[1, 1]);
280+
}
281+
282+
#[test]
283+
fn test_poly_eval_batched_powers() {
284+
test_poly_eval_batched(&[1, 2, 4, 16, 64]);
285+
}
286+
287+
#[test]
288+
fn test_poly_eval_batched_arbitrary() {
289+
test_poly_eval_batched(&[1, 6, 3, 9]);
290+
}
291+
292+
fn test_poly_eval_batched(lengths: &[usize]) {
293+
let sizes = lengths
294+
.iter()
295+
.map(|s| s.next_power_of_two())
296+
.collect::<Vec<_>>();
297+
298+
let polynomials = sizes
299+
.iter()
300+
.map(|&size| Field64::random_vector(size))
301+
.collect::<Vec<_>>();
302+
let x = Field64::random_vector(1)[0];
303+
304+
let &n = sizes.iter().max().unwrap();
305+
let mut ntt_mem = vec![Field64::zero(); n];
306+
let roots = nth_root_powers(n);
307+
308+
// Evaluates several polynomials converting them to the monomial basis (iteratively).
309+
let want = polynomials
310+
.iter()
311+
.map(|poly| {
312+
let extended_poly = [poly.clone(), vec![Field64::zero(); n - poly.len()]].concat();
313+
poly_interpret_eval(&extended_poly, x, &mut ntt_mem)
314+
})
315+
.collect::<Vec<_>>();
316+
317+
// Simultaneouly evaluates several polynomials directly in the Lagrange basis (batched).
318+
let got = poly_eval_batched(&polynomials, &roots, x);
319+
assert_eq!(got, want, "sizes: {sizes:?} x: {x} P: {polynomials:?}");
320+
}
157321
}

0 commit comments

Comments
 (0)