Implement wNAF-based MSM derived from Gemini #539

Merged
merged 4 commits on Dec 10, 2022
3 changes: 1 addition & 2 deletions .github/workflows/ci.yml
@@ -110,14 +110,13 @@ jobs:

- name: Test assembly on nightly
env:
RUSTFLAGS: -C target-cpu=native
RUSTFLAGS: -C target-cpu=native -Z macro-backtrace
uses: actions-rs/cargo@v1
with:
command: test
args: "--workspace \
--package ark-test-curves \
--all-features
-- -Z macro-backtrace
"
if: matrix.rust == 'nightly'

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -164,6 +164,7 @@
- Add constructor `new_coset`.
- Add convenience method `get_coset`.
- Add methods `coset_offset`, `coset_offset_inv` and `coset_offset_pow_size`.
- [\#539](https://github.com/arkworks-rs/algebra/pull/539) (`ark-ec`) Implement wNAF-based MSM, resulting in 5-10% speedups.

### Improvements

1 change: 1 addition & 0 deletions ec/src/models/short_weierstrass/group.rs
@@ -628,6 +628,7 @@ where

impl<P: SWCurveConfig> ScalarMul for Projective<P> {
type MulBase = Affine<P>;
const NEGATION_IS_CHEAP: bool = true;

fn batch_convert_to_mul_base(bases: &[Self]) -> Vec<Self::MulBase> {
Self::normalize_batch(bases)
1 change: 1 addition & 0 deletions ec/src/models/twisted_edwards/group.rs
@@ -481,6 +481,7 @@ where

impl<P: TECurveConfig> ScalarMul for Projective<P> {
type MulBase = Affine<P>;
const NEGATION_IS_CHEAP: bool = true;

fn batch_convert_to_mul_base(bases: &[Self]) -> Vec<Self::MulBase> {
Self::normalize_batch(bases)
1 change: 1 addition & 0 deletions ec/src/pairing.rs
@@ -304,6 +304,7 @@ impl<P: Pairing> Group for PairingOutput<P> {

impl<P: Pairing> crate::ScalarMul for PairingOutput<P> {
type MulBase = Self;
const NEGATION_IS_CHEAP: bool = P::TargetField::INVERSE_IS_FAST;

fn batch_convert_to_mul_base(bases: &[Self]) -> Vec<Self::MulBase> {
bases.to_vec()
11 changes: 9 additions & 2 deletions ec/src/scalar_mul/mod.rs
@@ -6,7 +6,7 @@ pub mod variable_base;

use crate::Group;
use ark_std::{
ops::{Add, AddAssign, Mul},
ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign},
vec::Vec,
};

@@ -25,6 +25,10 @@ pub trait ScalarMul:
+ AddAssign<Self::MulBase>
+ for<'a> Add<&'a Self::MulBase, Output = Self>
+ for<'a> AddAssign<&'a Self::MulBase>
+ Sub<Self::MulBase, Output = Self>
+ SubAssign<Self::MulBase>
+ for<'a> Sub<&'a Self::MulBase, Output = Self>
+ for<'a> SubAssign<&'a Self::MulBase>
+ From<Self::MulBase>
{
type MulBase: Send
@@ -33,7 +37,10 @@ pub trait ScalarMul:
+ Eq
+ core::hash::Hash
+ Mul<Self::ScalarField, Output = Self>
+ for<'a> Mul<&'a Self::ScalarField, Output = Self>;
+ for<'a> Mul<&'a Self::ScalarField, Output = Self>
+ Neg<Output = Self::MulBase>;

const NEGATION_IS_CHEAP: bool;

fn batch_convert_to_mul_base(bases: &[Self]) -> Vec<Self::MulBase>;
}
296 changes: 207 additions & 89 deletions ec/src/scalar_mul/variable_base/mod.rs
@@ -42,96 +42,11 @@ pub trait VariableBaseMSM: ScalarMul {
bases: &[Self::MulBase],
bigints: &[<Self::ScalarField as PrimeField>::BigInt],
) -> Self {
let size = ark_std::cmp::min(bases.len(), bigints.len());
let scalars = &bigints[..size];
let bases = &bases[..size];
let scalars_and_bases_iter = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());

let c = if size < 32 {
3
if Self::NEGATION_IS_CHEAP {
msm_bigint_wnaf(bases, bigints)
} else {
super::ln_without_floats(size) + 2
};

let num_bits = Self::ScalarField::MODULUS_BIT_SIZE as usize;
let fr_one = Self::ScalarField::one().into_bigint();

let zero = Self::zero();
let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();

// Each window is of size `c`.
// We divide up the bits 0..num_bits into windows of size `c`, and
// in parallel process each such window.
let window_sums: Vec<_> = ark_std::cfg_into_iter!(window_starts)
.map(|w_start| {
let mut res = zero;
// We don't need the "zero" bucket, so we only have 2^c - 1 buckets.
let mut buckets = vec![zero; (1 << c) - 1];
// This clone is cheap, because the iterator contains just a
// pointer and an index into the original vectors.
scalars_and_bases_iter.clone().for_each(|(&scalar, base)| {
if scalar == fr_one {
// We only process unit scalars once in the first window.
if w_start == 0 {
res += base;
}
} else {
let mut scalar = scalar;

// We right-shift by w_start, thus getting rid of the
// lower bits.
scalar.divn(w_start as u32);

// We mod the remaining bits by 2^{window size}, thus taking `c` bits.
let scalar = scalar.as_ref()[0] % (1 << c);

// If the scalar is non-zero, we update the corresponding
// bucket.
// (Recall that `buckets` doesn't have a zero bucket.)
if scalar != 0 {
buckets[(scalar - 1) as usize] += base;
}
}
});

// Compute sum_{i in 0..num_buckets} (sum_{j in i..num_buckets} bucket[j])
// This is computed below for b buckets, using 2b curve additions.
//
// We could first normalize `buckets` and then use mixed-addition
// here, but that's slower for the kinds of groups we care about
// (Short Weierstrass curves and Twisted Edwards curves).
// In the case of Short Weierstrass curves,
// mixed addition saves ~4 field multiplications per addition.
// However normalization (with the inversion batched) takes ~6
// field multiplications per element,
// hence batch normalization is a slowdown.

// `running_sum` = sum_{j in i..num_buckets} bucket[j],
// where we iterate backward from i = num_buckets to 0.
let mut running_sum = Self::zero();
buckets.into_iter().rev().for_each(|b| {
running_sum += &b;
res += &running_sum;
});
res
})
.collect();

// We store the sum for the lowest window.
let lowest = *window_sums.first().unwrap();

// We're traversing windows from high to low.
lowest
+ &window_sums[1..]
.iter()
.rev()
.fold(zero, |mut total, sum_i| {
total += sum_i;
for _ in 0..c {
total.double_in_place();
}
total
})
msm_bigint(bases, bigints)
}
}

/// Streaming multi-scalar multiplication algorithm with hard-coded chunk
@@ -169,3 +84,206 @@ pub trait VariableBaseMSM: ScalarMul {
result
}
}

// Compute the MSM using a windowed non-adjacent form (signed-digit) decomposition of the scalars.
fn msm_bigint_wnaf<V: VariableBaseMSM>(
bases: &[V::MulBase],
bigints: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
let size = ark_std::cmp::min(bases.len(), bigints.len());
let scalars = &bigints[..size];
let bases = &bases[..size];

let c = if size < 32 {
3
} else {
super::ln_without_floats(size) + 2
};

let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
let digits_count = (num_bits + c - 1) / c;
let scalar_digits = scalars
.iter()
.flat_map(|s| make_digits(s, c, num_bits))
.collect::<Vec<_>>();
let zero = V::zero();
let window_sums: Vec<_> = ark_std::cfg_into_iter!(0..digits_count)
.map(|i| {
let mut buckets = vec![zero; 1 << c];
for (digits, base) in scalar_digits.chunks(digits_count).zip(bases) {
use ark_std::cmp::Ordering;
// `digits` is the signed-digit decomposition of the scalar paired with `base`;
// pick out the digit for window `i`.
let scalar = digits[i];
match 0.cmp(&scalar) {
Ordering::Less => buckets[(scalar - 1) as usize] += base,
Ordering::Greater => buckets[(-scalar - 1) as usize] -= base,
Ordering::Equal => (),
}
}

let mut running_sum = V::zero();
let mut res = V::zero();
buckets.into_iter().rev().for_each(|b| {
running_sum += &b;
res += &running_sum;
});
res
})
.collect();

// We store the sum for the lowest window.
let lowest = *window_sums.first().unwrap();

// We're traversing windows from high to low.
lowest
+ &window_sums[1..]
.iter()
.rev()
.fold(zero, |mut total, sum_i| {
total += sum_i;
for _ in 0..c {
total.double_in_place();
}
total
})
}
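Note: the final fold above evaluates result = sum_j W_j * 2^(j*c) Horner-style, combining the per-window sums from highest to lowest with `c` doublings between consecutive windows. Below is a minimal sketch of that combination step with `u64` values standing in for group elements (illustrative only, not part of the PR; the function name is made up):

/// Combines per-window sums W_0..W_{n-1} into sum_j W_j * 2^(j*c),
/// mirroring the fold over `window_sums` above (group doubling replaced
/// by multiplication by 2 for illustration).
fn combine_windows(window_sums: &[u64], c: usize) -> u64 {
    let lowest = window_sums[0];
    lowest
        + window_sums[1..].iter().rev().fold(0u64, |mut total, &sum_i| {
            total += sum_i;
            // Doubling `c` times multiplies the running total by 2^c.
            for _ in 0..c {
                total *= 2;
            }
            total
        })
}

fn main() {
    // With c = 2 and window sums [1, 2, 3]: 1 + 2*4 + 3*16 = 57.
    assert_eq!(combine_windows(&[1, 2, 3], 2), 57);
    println!("ok");
}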

/// Optimized implementation of multi-scalar multiplication.
fn msm_bigint<V: VariableBaseMSM>(
bases: &[V::MulBase],
bigints: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
let size = ark_std::cmp::min(bases.len(), bigints.len());
let scalars = &bigints[..size];
let bases = &bases[..size];
let scalars_and_bases_iter = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());

let c = if size < 32 {
3
} else {
super::ln_without_floats(size) + 2
};

let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
let one = V::ScalarField::one().into_bigint();

let zero = V::zero();
let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();

// Each window is of size `c`.
// We divide up the bits 0..num_bits into windows of size `c`, and
// in parallel process each such window.
let window_sums: Vec<_> = ark_std::cfg_into_iter!(window_starts)
.map(|w_start| {
let mut res = zero;
// We don't need the "zero" bucket, so we only have 2^c - 1 buckets.
let mut buckets = vec![zero; (1 << c) - 1];
// This clone is cheap, because the iterator contains just a
// pointer and an index into the original vectors.
scalars_and_bases_iter.clone().for_each(|(&scalar, base)| {
if scalar == one {
// We only process unit scalars once in the first window.
if w_start == 0 {
res += base;
}
} else {
let mut scalar = scalar;

// We right-shift by w_start, thus getting rid of the
// lower bits.
scalar.divn(w_start as u32);

// We mod the remaining bits by 2^{window size}, thus taking `c` bits.
let scalar = scalar.as_ref()[0] % (1 << c);

// If the scalar is non-zero, we update the corresponding
// bucket.
// (Recall that `buckets` doesn't have a zero bucket.)
if scalar != 0 {
buckets[(scalar - 1) as usize] += base;
}
}
});

// Compute sum_{i in 0..num_buckets} (sum_{j in i..num_buckets} bucket[j])
// This is computed below for b buckets, using 2b curve additions.
//
// We could first normalize `buckets` and then use mixed-addition
// here, but that's slower for the kinds of groups we care about
// (Short Weierstrass curves and Twisted Edwards curves).
// In the case of Short Weierstrass curves,
// mixed addition saves ~4 field multiplications per addition.
// However normalization (with the inversion batched) takes ~6
// field multiplications per element,
// hence batch normalization is a slowdown.

// `running_sum` = sum_{j in i..num_buckets} bucket[j],
// where we iterate backward from i = num_buckets to 0.
let mut running_sum = V::zero();
buckets.into_iter().rev().for_each(|b| {
running_sum += &b;
res += &running_sum;
});
res
})
.collect();

// We store the sum for the lowest window.
let lowest = *window_sums.first().unwrap();

// We're traversing windows from high to low.
lowest
+ &window_sums[1..]
.iter()
.rev()
.fold(zero, |mut total, sum_i| {
total += sum_i;
for _ in 0..c {
total.double_in_place();
}
total
})
}
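Note: both `msm_bigint` and `msm_bigint_wnaf` aggregate their buckets with the running-sum trick described in the comments above: sum_{i=1..b} i * bucket_i is computed as sum_{i=1..b} (sum_{j=i..b} bucket_j), i.e. 2b group additions instead of b scalar multiplications. A minimal sketch with `u64` values in place of group elements (illustrative only; the helper name is made up):

/// Computes sum_{i=1..b} i * buckets[i-1] with the running-sum trick,
/// mirroring the `running_sum` / `res` loop in the MSM routines above.
fn aggregate_buckets(buckets: &[u64]) -> u64 {
    let mut running_sum = 0u64;
    let mut res = 0u64;
    for &b in buckets.iter().rev() {
        running_sum += b;
        res += running_sum;
    }
    res
}

fn main() {
    // Buckets for digit values 1, 2, 3; expected 1*5 + 2*7 + 3*11 = 52.
    assert_eq!(aggregate_buckets(&[5, 7, 11]), 52);
    println!("ok");
}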

// From: https://github.com/arkworks-rs/gemini/blob/main/src/kzg/msm/variable_base.rs#L20
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> Vec<i64> {
let scalar = a.as_ref();
let radix: u64 = 1 << w;
let window_mask: u64 = radix - 1;

let mut carry = 0u64;
let num_bits = if num_bits == 0 {
a.num_bits() as usize
} else {
num_bits
};
let digits_count = (num_bits + w - 1) / w;
let mut digits = vec![0i64; digits_count];
for (i, digit) in digits.iter_mut().enumerate() {
// Construct a buffer of bits of the scalar, starting at `bit_offset`.
let bit_offset = i * w;
let u64_idx = bit_offset / 64;
let bit_idx = bit_offset % 64;
// Read the bits from the scalar
let bit_buf: u64;
if bit_idx < 64 - w || u64_idx == scalar.len() - 1 {
// This window's bits are contained in a single u64,
// or it's the last u64 anyway.
bit_buf = scalar[u64_idx] >> bit_idx;
} else {
// Combine the current u64's bits with the bits from the next u64
bit_buf = (scalar[u64_idx] >> bit_idx) | (scalar[1 + u64_idx] << (64 - bit_idx));
}

// Read the actual coefficient value from the window
let coef = carry + (bit_buf & window_mask); // coef is in [0, 2^w]

// Recenter coefficients from [0,2^w) to [-2^w/2, 2^w/2)
carry = (coef + radix / 2) >> w;
*digit = (coef as i64) - (carry << w) as i64;
}

digits[digits_count - 1] += (carry << w) as i64;

digits
}
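Note: `make_digits` recenters each w-bit window digit from [0, 2^w) into [-2^(w-1), 2^(w-1)), carrying 1 into the next window whenever a digit is pushed negative, so that sum_i digits[i] * 2^(i*w) still reconstructs the original scalar. Below is a self-contained sketch of the same recentering on a plain `u64` scalar (illustrative; the function name and the `u64` restriction are assumptions, not part of the PR):

/// Signed-digit decomposition of a u64 scalar, mirroring `make_digits`:
/// every digit lies in [-2^(w-1), 2^(w-1)) and the digits reconstruct the scalar.
/// Unlike `make_digits`, which emits a fixed number of digits and folds the
/// final carry into the last one, this sketch keeps emitting digits until the
/// carry is consumed.
fn signed_digits_u64(mut scalar: u64, w: usize) -> Vec<i64> {
    let radix: u64 = 1 << w;
    let mut digits = Vec::new();
    let mut carry = 0u64;
    while scalar != 0 || carry != 0 {
        let coef = carry + (scalar & (radix - 1));
        // Recenter: digits of at least 2^(w-1) become negative, carrying 1 upward.
        carry = (coef + radix / 2) >> w;
        digits.push(coef as i64 - (carry << w) as i64);
        scalar >>= w;
    }
    digits
}

fn main() {
    let (scalar, w) = (0xDEAD_BEEFu64, 4);
    let digits = signed_digits_u64(scalar, w);
    // Reconstruct: sum_i digits[i] * 2^(i*w) must give back the scalar.
    let recovered: i128 = digits
        .iter()
        .enumerate()
        .map(|(i, &d)| (d as i128) << (i * w))
        .sum();
    assert_eq!(recovered, scalar as i128);
    println!("digits = {:?}", digits);
}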