Skip to content

Commit

Permalink
Improve field macro
Browse files Browse the repository at this point in the history
Compile time Montgomery multiplication!
  • Loading branch information
Pratyush committed Nov 29, 2020
1 parent 9bc5417 commit 0ec1bde
Show file tree
Hide file tree
Showing 18 changed files with 632 additions and 331 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"serialize",
"serialize-derive",

"ff-macros",
"ff-asm",
"ff",

Expand Down
22 changes: 22 additions & 0 deletions ff-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "ark-ff-macros"
version = "0.1.0"
authors = [ "arkworks contributors" ]
description = "A library for generating x86-64 assembly for finite field multiplication"
homepage = "https://arworks.rs"
repository = "https://github.com/arkworks/algebra"
documentation = "https://docs.rs/ark-ff-asm/"
keywords = ["cryptography", "finite fields", "assembly" ]
categories = ["cryptography"]
include = ["Cargo.toml", "src", "README.md", "LICENSE-APACHE", "LICENSE-MIT"]
license = "MIT/Apache-2.0"
edition = "2018"

[dependencies]
quote = "1.0.0"
syn = { version = "1.0.0", features = ["full", "parsing", "extra-traits"]}
num-bigint = { version = "0.3", default-features = false }
num-traits = { version = "0.2", default-features = false }

[lib]
proc-macro = true
1 change: 1 addition & 0 deletions ff-macros/LICENSE-APACHE
1 change: 1 addition & 0 deletions ff-macros/LICENSE-MIT
88 changes: 88 additions & 0 deletions ff-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#![deny(
warnings,
unused,
future_incompatible,
nonstandard_style,
rust_2018_idioms
)]
#![forbid(unsafe_code)]
#![recursion_limit = "128"]

use num_bigint::{BigInt, Sign};
use proc_macro::TokenStream;
use std::str::FromStr;
use syn::{Expr, Lit};

fn parse_string(input: TokenStream) -> Option<String> {
let input: Expr = syn::parse(input).unwrap();
let input = if let Expr::Group(syn::ExprGroup { expr, .. }) = input {
expr
} else {
panic!("could not parse");
};
match *input {
Expr::Lit(expr_lit) => match expr_lit.lit {
Lit::Str(s) => Some(s.value()),
_ => None,
},
_ => None,
}
}

fn str_to_limbs(num: &str) -> (bool, Vec<String>) {
let (sign, digits) = BigInt::from_str(num)
.expect("could not parse to bigint")
.to_radix_le(16);
let limbs = digits
.chunks(16)
.map(|chunk| {
let mut this = 0u64;
for (i, hexit) in chunk.iter().enumerate() {
this += (*hexit as u64) << (4 * i);
}
format!("{}u64", this)
})
.collect::<Vec<_>>();

let sign_is_positive = sign != Sign::Minus;
(sign_is_positive, limbs)
}

#[proc_macro]
pub fn to_sign_and_limbs(input: TokenStream) -> TokenStream {
let num = parse_string(input).expect("expected decimal string");
let (is_positive, limbs) = str_to_limbs(&num);

let limbs: String = limbs.join(", ");
let limbs_and_sign = format!("({}", is_positive) + ", [" + &limbs + "])";
let tuple: Expr = syn::parse_str(&limbs_and_sign).unwrap();
quote::quote!(#tuple).into()
}

#[test]
fn test_str_to_limbs() {
let (is_positive, limbs) = str_to_limbs("-5");
assert!(!is_positive);
assert_eq!(&limbs, &["5u64".to_string()]);

let (is_positive, limbs) = str_to_limbs("100");
assert!(is_positive);
assert_eq!(&limbs, &["100u64".to_string()]);

let large_num = -((1i128 << 64) + 101234001234i128);
let (is_positive, limbs) = str_to_limbs(&large_num.to_string());
assert!(!is_positive);
assert_eq!(&limbs, &["101234001234u64".to_string(), "1u64".to_string()]);

let num = "80949648264912719408558363140637477264845294720710499478137287262712535938301461879813459410946";
let (is_positive, limbs) = str_to_limbs(&num.to_string());
assert!(is_positive);
let expected_limbs = [
format!("{}u64", 0x8508c00000000002u64),
format!("{}u64", 0x452217cc90000000u64),
format!("{}u64", 0xc5ed1347970dec00u64),
format!("{}u64", 0x619aaf7d34594aabu64),
format!("{}u64", 0x9b3af05dd14f6ecu64),
];
assert_eq!(&limbs, &expected_limbs);
}
2 changes: 1 addition & 1 deletion ff/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ build = "build.rs"

[dependencies]
ark-ff-asm = { path = "../ff-asm" }
ark-ff-macros = { path = "../ff-macros" }
ark-std = { git = "https://github.com/arkworks-rs/utils", default-features = false }
ark-serialize = { path = "../serialize", default-features = false }
derivative = { version = "2", features = ["use_core"] }
Expand All @@ -27,7 +28,6 @@ rustc_version = "0.3"

[dev-dependencies]
rand_xorshift = "0.2"
ark-test-curves = { path = "../test-curves", default-features = false, features = [ "bls12_381_curve"] }

[features]
default = []
Expand Down
106 changes: 106 additions & 0 deletions ff/src/biginteger/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use ark_std::vec::Vec;

/// Calculate a + b + carry, returning the sum and modifying the
/// carry value.
macro_rules! adc {
($a:expr, $b:expr, &mut $carry:expr$(,)?) => {{
let tmp = ($a as u128) + ($b as u128) + ($carry as u128);

$carry = (tmp >> 64) as u64;

tmp as u64
}};
}

/// Calculate a + (b * c) + carry, returning the least significant digit
/// and setting carry to the most significant digit.
macro_rules! mac_with_carry {
($a:expr, $b:expr, $c:expr, &mut $carry:expr$(,)?) => {{
let tmp = ($a as u128) + ($b as u128 * $c as u128) + ($carry as u128);

$carry = (tmp >> 64) as u64;

tmp as u64
}};
}

/// Calculate a - b - borrow, returning the result and modifying
/// the borrow value.
macro_rules! sbb {
($a:expr, $b:expr, &mut $borrow:expr$(,)?) => {{
let tmp = (1u128 << 64) + ($a as u128) - ($b as u128) - ($borrow as u128);

$borrow = if tmp >> 64 == 0 { 1 } else { 0 };

tmp as u64
}};
}

#[inline(always)]
pub(crate) fn mac(a: u64, b: u64, c: u64, carry: &mut u64) -> u64 {
let tmp = (u128::from(a)) + u128::from(b) * u128::from(c);

*carry = (tmp >> 64) as u64;

tmp as u64
}

#[inline(always)]
pub(crate) fn mac_discard(a: u64, b: u64, c: u64, carry: &mut u64) {
let tmp = (u128::from(a)) + u128::from(b) * u128::from(c);

*carry = (tmp >> 64) as u64;
}

pub fn find_wnaf(num: &[u64]) -> Vec<i64> {
let is_zero = |num: &[u64]| num.iter().all(|x| *x == 0u64);
let is_odd = |num: &[u64]| num[0] & 1 == 1;
let sub_noborrow = |num: &mut [u64], z: u64| {
let mut other = vec![0u64; num.len()];
other[0] = z;
let mut borrow = 0;

for (a, b) in num.iter_mut().zip(other) {
*a = sbb!(*a, b, &mut borrow);
}
};
let add_nocarry = |num: &mut [u64], z: u64| {
let mut other = vec![0u64; num.len()];
other[0] = z;
let mut carry = 0;

for (a, b) in num.iter_mut().zip(other) {
*a = adc!(*a, b, &mut carry);
}
};
let div2 = |num: &mut [u64]| {
let mut t = 0;
for i in num.iter_mut().rev() {
let t2 = *i << 63;
*i >>= 1;
*i |= t;
t = t2;
}
};

let mut num = num.to_vec();
let mut res = vec![];

while !is_zero(&num) {
let z: i64;
if is_odd(&num) {
z = 2 - (num[0] % 4) as i64;
if z >= 0 {
sub_noborrow(&mut num, z as u64)
} else {
add_nocarry(&mut num, (-z) as u64)
}
} else {
z = 0;
}
res.push(z);
div2(&mut num);
}

res
}
4 changes: 2 additions & 2 deletions ff/src/biginteger/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ macro_rules! bigint_impl {
let mut carry = 0;

for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
*a = arithmetic::adc(*a, *b, &mut carry);
*a = adc!(*a, *b, &mut carry);
}

carry != 0
Expand All @@ -28,7 +28,7 @@ macro_rules! bigint_impl {
let mut borrow = 0;

for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
*a = arithmetic::sbb(*a, *b, &mut borrow);
*a = sbb!(*a, *b, &mut borrow);
}

borrow != 0
Expand Down
108 changes: 2 additions & 106 deletions ff/src/biginteger/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use rand::{
distributions::{Distribution, Standard},
Rng,
};

#[macro_use]
pub mod arithmetic;
#[macro_use]
mod macros;

Expand Down Expand Up @@ -117,108 +118,3 @@ pub trait BigInteger:
Ok(())
}
}

pub mod arithmetic {
use ark_std::vec::Vec;
pub fn find_wnaf(num: &[u64]) -> Vec<i64> {
let is_zero = |num: &[u64]| num.iter().all(|x| *x == 0u64);
let is_odd = |num: &[u64]| num[0] & 1 == 1;
let sub_noborrow = |num: &mut [u64], z: u64| {
let mut other = vec![0u64; num.len()];
other[0] = z;
let mut borrow = 0;

for (a, b) in num.iter_mut().zip(other) {
*a = sbb(*a, b, &mut borrow);
}
};
let add_nocarry = |num: &mut [u64], z: u64| {
let mut other = vec![0u64; num.len()];
other[0] = z;
let mut carry = 0;

for (a, b) in num.iter_mut().zip(other) {
*a = adc(*a, b, &mut carry);
}
};
let div2 = |num: &mut [u64]| {
let mut t = 0;
for i in num.iter_mut().rev() {
let t2 = *i << 63;
*i >>= 1;
*i |= t;
t = t2;
}
};

let mut num = num.to_vec();
let mut res = vec![];

while !is_zero(&num) {
let z: i64;
if is_odd(&num) {
z = 2 - (num[0] % 4) as i64;
if z >= 0 {
sub_noborrow(&mut num, z as u64)
} else {
add_nocarry(&mut num, (-z) as u64)
}
} else {
z = 0;
}
res.push(z);
div2(&mut num);
}

res
}

/// Calculate a + b + carry, returning the sum and modifying the
/// carry value.
#[inline(always)]
pub(crate) fn adc(a: u64, b: u64, carry: &mut u64) -> u64 {
let tmp = u128::from(a) + u128::from(b) + u128::from(*carry);

*carry = (tmp >> 64) as u64;

tmp as u64
}

/// Calculate a - b - borrow, returning the result and modifying
/// the borrow value.
#[inline(always)]
pub(crate) fn sbb(a: u64, b: u64, borrow: &mut u64) -> u64 {
let tmp = (1u128 << 64) + u128::from(a) - u128::from(b) - u128::from(*borrow);

*borrow = if tmp >> 64 == 0 { 1 } else { 0 };

tmp as u64
}

/// Calculate a + (b * c) + carry, returning the least significant digit
/// and setting carry to the most significant digit.
#[inline(always)]
pub(crate) fn mac_with_carry(a: u64, b: u64, c: u64, carry: &mut u64) -> u64 {
let tmp = (u128::from(a)) + u128::from(b) * u128::from(c) + u128::from(*carry);

*carry = (tmp >> 64) as u64;

tmp as u64
}

#[inline(always)]
pub(crate) fn mac(a: u64, b: u64, c: u64, carry: &mut u64) -> u64 {
let tmp = (u128::from(a)) + u128::from(b) * u128::from(c);

*carry = (tmp >> 64) as u64;

tmp as u64
}

#[inline(always)]
pub(crate) fn mac_discard(a: u64, b: u64, c: u64, carry: &mut u64) {
let tmp = (u128::from(a)) + u128::from(b) * u128::from(c);

*carry = (tmp >> 64) as u64;
}
}
Loading

0 comments on commit 0ec1bde

Please sign in to comment.