Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 186 additions & 62 deletions ff_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream {

let mut gen = proc_macro2::TokenStream::new();

let (constants_impl, sqrt_impl) =
let (constants_impl, sqrt_impl, sqrt_ratio_impl) =
prime_field_constants_and_sqrt(&ast.ident, &modulus, limbs, generator);

gen.extend(constants_impl);
Expand All @@ -176,6 +176,7 @@ pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
&endianness,
limbs,
sqrt_impl,
sqrt_ratio_impl,
));

// Return the generated impl
Expand Down Expand Up @@ -462,12 +463,13 @@ fn test_exp() {
);
}


fn prime_field_constants_and_sqrt(
name: &syn::Ident,
modulus: &BigUint,
limbs: usize,
generator: BigUint,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream, proc_macro2::TokenStream) {
let bytes = limbs * 8;
let modulus_num_bits = biguint_num_bits(modulus.clone());

Expand Down Expand Up @@ -498,63 +500,25 @@ fn prime_field_constants_and_sqrt(

// Compute 2^s root of unity given the generator
let root_of_unity = exp(generator.clone(), &t, &modulus);
let root_of_unity_inv = biguint_to_u64_vec(to_mont(invert(root_of_unity.clone())), limbs);
let root_of_unity = biguint_to_u64_vec(to_mont(root_of_unity), limbs);
let delta = biguint_to_u64_vec(
to_mont(exp(generator.clone(), &(BigUint::one() << s), &modulus)),
limbs,
);
let generator = biguint_to_u64_vec(to_mont(generator), limbs);

let sqrt_impl =
if (modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
// Addition chain for (r + 1) // 4
let mod_plus_1_over_4 = pow_fixed::generate(
&quote! {self},
(modulus + BigUint::from_str("1").unwrap()) >> 2,
);
let root_of_unity_inv = biguint_to_u64_vec(to_mont(invert(root_of_unity.clone())), limbs);

quote! {
use ::ff::derive::subtle::ConstantTimeEq;

// Because r = 3 (mod 4)
// sqrt can be done with only one exponentiation,
// via the computation of self^((r + 1) // 4) (mod r)
let sqrt = {
#mod_plus_1_over_4
};

::ff::derive::subtle::CtOption::new(
sqrt,
(sqrt * &sqrt).ct_eq(self), // Only return Some if it's the square root.
)
}
} else {
// Addition chain for (t - 1) // 2
let t_minus_1_over_2 = if t == BigUint::one() {
quote!( #name::ONE )
} else {
pow_fixed::generate(&quote! {self}, (&t - BigUint::one()) >> 1)
};

quote! {
// Tonelli-Shanks algorithm works for every remaining odd prime.
// https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
// Tonelli shanks logic starts here
fn generate_tonelli_shanks_loop(name: &syn::Ident) -> proc_macro2::TokenStream {
quote! {
/// The loop takes in x and the `projenator` w = x^((t - 1)/2). The function
/// the modifies x, and the final value for x is sqrt(x) (iff x is a QR).
fn tonelli_shanks_loop(x: &mut #name, w: &#name) {
use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq};

// w = self^((t - 1) // 2)
let w = {
#t_minus_1_over_2
};

let mut v = S;
let mut x = *self * &w;
let mut b = x * &w;
use ::ff::PrimeField;

let mut v = #name::S;
*x *= w;
let mut b = *x * w;

// Initialize z as the 2^S root of unity.
let mut z = ROOT_OF_UNITY;
let mut z = #name::ROOT_OF_UNITY;

for max_v in (1..=S).rev() {
for max_v in (1..=#name::S).rev() {
let mut k = 1;
let mut tmp = b.square();
let mut j_less_than_v: ::ff::derive::subtle::Choice = 1.into();
Expand All @@ -569,20 +533,177 @@ fn prime_field_constants_and_sqrt(
z = #name::conditional_select(&z, &new_z, j_less_than_v);
}

let result = x * &z;
x = #name::conditional_select(&result, &x, b.ct_eq(&#name::ONE));
let result = *x * &z;
*x = #name::conditional_select(&result, x, b.ct_eq(&#name::ONE));
z = z.square();
b *= &z;
v = k;
}

::ff::derive::subtle::CtOption::new(
x,
(x * &x).ct_eq(self), // Only return Some if it's the square root.
)
}
}
}

// Recall p - 1 = 2^s * t
// Addition chain for (t - 1) // 2

let t_minus_1_over_2 = if t == BigUint::one() {
quote!( #name::ONE )
} else {
pow_fixed::generate(&quote! {x}, (&t - BigUint::one()) >> 1)
};

// Tonelli--Shanks inner loop
let tonelli_shanks_loop = generate_tonelli_shanks_loop(&name);

let sqrt_impl = if (modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
// Addition chain for (r + 1) // 4
let mod_plus_1_over_4 = pow_fixed::generate(
&quote! {self},
(modulus + BigUint::from_str("1").unwrap()) >> 2,
);

quote! {
use ::ff::derive::subtle::ConstantTimeEq;

// Because r = 3 (mod 4)
// sqrt can be done with only one exponentiation,
// via the computation of self^((r + 1) // 4) (mod r)
let sqrt = {
#mod_plus_1_over_4
};

::ff::derive::subtle::CtOption::new(
sqrt,
(sqrt * &sqrt).ct_eq(self), // Only return Some if it's the square root.
)
}
} else {
quote! {
// Remark: The Tonelli-Shanks algorithm works for every odd prime.
// However, leave the above 3 mod 4.
// https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq};
#tonelli_shanks_loop

// w = self^((t - 1) // 2)
let mut x = *self;

let w = {
#t_minus_1_over_2
};

tonelli_shanks_loop(&mut x, &w);

::ff::derive::subtle::CtOption::new(
x,
(x * &x).ct_eq(self), // Only return Some if it's the square root.
)
}
};

// Generates an implimentation of sqrt(num/div) using the merged inverse-and-sqrt
// version of the Tonelli--Shanks algorithm combining Scott's `Tricks of the trade` paper
// Section 2 and 2.1 (see https://eprint.iacr.org/2020/1497)
// This is a more general version of p.15 of https://eprint.iacr.org/2011/368.pdf
let sqrt_ratio_impl = {
// setup some compile-time constants
let zeta_2 = biguint_to_u64_vec(
to_mont(exp(root_of_unity.clone(), &BigUint::from_str("2").unwrap(), &modulus)),
limbs,
);
let tw = exp(root_of_unity.clone(), &BigUint::from_str("3").unwrap(), &modulus);
let tw_inv = biguint_to_u64_vec(
to_mont(invert(tw.clone())),
limbs
);
let tw_proj = biguint_to_u64_vec(
to_mont(exp(tw.clone(), &((&t - BigUint::one()) >> 1), &modulus)),
limbs,
);
let tw = biguint_to_u64_vec(
to_mont(tw),
limbs
);

// Fixed exponentiation (x^2*w^4)^(2^(s-1) - 1)
let two_s_m1 = BigUint::one() << (&s - 1);
let two_s_minus_1_m1 = if two_s_m1 == BigUint::one() {
quote!( #name::ONE )
} else {
pow_fixed::generate(&quote! {x2w4}, &two_s_m1 - BigUint::one())
};

quote! {
use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq, CtOption};
use ::ff::PrimeField;
#tonelli_shanks_loop

const Z2: #name = #name(#zeta_2); // #name::ROOT_OF_UNITY^2
const TW: #name = #name(#tw); // #name::ROOT_OF_UNITY^3
const TW_INV: #name = #name(#tw_inv); // TW^(-1)
const TW_PROJ: #name = #name(#tw_proj); // projenator of twist TW^((t - 1)/2)

let num_is_zero = num.is_zero();
let div_is_zero = div.is_zero();

let mut x = num.cube() * div;
let mut sqrtx = x.clone();

let mut tw_x = x.clone();
tw_x *= TW;
let mut tw_sqrtx = tw_x.clone();

let w = {
#t_minus_1_over_2
};
let tw_w = TW_PROJ * &w;

tonelli_shanks_loop(&mut sqrtx, &w); //sqrtx = sqrt(x) now
tonelli_shanks_loop(&mut tw_sqrtx, &tw_w); //tw_sqrtx = sqrt(tw_x) now
// Remark: One can avoid a second call to the loop when p = 3 (4) or p = 5 (8)
// since then tw_sqrtx = tw * sqrtx (cf. p15 of https://eprint.iacr.org/2011/368.pdf)

// x <- x^(-1) = (x^2 * w^4)^(2^(s-1) - 1) * (x * w^4)
let xw4 = w.square().square() * &x;
let mut x2w4 = x * &xw4;
x2w4 = {
#two_s_minus_1_m1
};
x = x2w4 * xw4;

// tx_x <- tw_x^(-1) = tw^(-1) * x^(-1)
let mut tw_x = TW_INV * &x;

// x = sqrt(num/div)
let n2 = num.square();
x = x * sqrtx * &n2;

// tw_x = sqrt(zeta * num / div)
tw_x = Z2 * tw_x * tw_sqrtx * n2;

let tw_num = #name::ROOT_OF_UNITY * num;

let is_square = (x.square() * div).ct_eq(num);
let is_nonsquare = (tw_x.square() * div).ct_eq(&tw_num);

assert!(bool::from(
num_is_zero | div_is_zero | (is_square ^ is_nonsquare)
));
(
is_square & (num_is_zero | !div_is_zero),
#name::conditional_select(&tw_x, &x, is_square),
)
}
};

// Some more constants
let root_of_unity = biguint_to_u64_vec(to_mont(root_of_unity), limbs);
let delta = biguint_to_u64_vec(
to_mont(exp(generator.clone(), &(BigUint::one() << s), &modulus)),
limbs,
);
let generator = biguint_to_u64_vec(to_mont(generator), limbs);

// Compute R^2 mod m
let r2 = biguint_to_u64_vec((&r * &r) % modulus, limbs);

Expand Down Expand Up @@ -649,6 +770,7 @@ fn prime_field_constants_and_sqrt(
const DELTA: #name = #name(#delta);
},
sqrt_impl,
sqrt_ratio_impl,
)
}

Expand All @@ -660,6 +782,7 @@ fn prime_field_impl(
endianness: &ReprEndianness,
limbs: usize,
sqrt_impl: proc_macro2::TokenStream,
sqrt_ratio_impl: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
// Returns r{n} as an ident.
fn get_temp(n: usize) -> syn::Ident {
Expand Down Expand Up @@ -880,6 +1003,7 @@ fn prime_field_impl(
}
}


let squaring_impl = sqr_impl(quote! {self}, limbs);
let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs);
let invert_impl = inv_impl(quote! {self}, modulus);
Expand Down Expand Up @@ -1317,7 +1441,7 @@ fn prime_field_impl(
}

fn sqrt_ratio(num: &Self, div: &Self) -> (::ff::derive::subtle::Choice, Self) {
::ff::helpers::sqrt_ratio_generic(num, div)
#sqrt_ratio_impl
}

fn sqrt(&self) -> ::ff::derive::subtle::CtOption<Self> {
Expand Down
38 changes: 38 additions & 0 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,41 @@ fn sqrt() {
use rand::rngs::OsRng;
test(Fp::random(OsRng));
}

#[test]
fn sqrt_ratio_test() {
use ff::{Field, PrimeField};

#[derive(PrimeField)]
#[PrimeFieldModulus = "357686312646216567629137"]
#[PrimeFieldGenerator = "5"]
#[PrimeFieldReprEndianness = "little"]
struct Fp([u64; 2]);

fn test(num: Fp, div: Fp) {
let (choice, sqrt) = Fp::sqrt_ratio(&num, &div);

if bool::from(choice) {
assert!(div != Fp::ZERO);
let div_inv = div.invert().unwrap();
let expected = num * div_inv;
assert_eq!(sqrt.square(), expected);
} else if div != Fp::ZERO {
let div_inv = div.invert().unwrap();
let expected = Fp::ROOT_OF_UNITY * num * div_inv;
assert_eq!(sqrt.square(), expected);
} else {
assert_eq!(sqrt.square(), Fp::ZERO);
}
}

// Easy cases
test(Fp::ZERO, Fp::ONE); // sqrt(0/1) = (true, 0)
test(Fp::ONE, Fp::ZERO); // sqrt(1/0) = (false, 0)

// Random case
use rand::rngs::OsRng;
let a = Fp::random(&mut OsRng);
let b = Fp::random(&mut OsRng);
test(a, b);
}