Skip to content
This repository has been archived by the owner on Nov 6, 2020. It is now read-only.

ethcore: minor optimization of modexp by using LR exponentiation #9697

Merged
merged 1 commit into from
Oct 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions ethcore/benches/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ extern crate ethereum_types;
extern crate parity_bytes as bytes;
extern crate rustc_hex;

use std::collections::BTreeMap;

use bytes::BytesRef;
use ethcore::builtin::Builtin;
use ethcore::machine::EthereumMachine;
use ethereum_types::{Address, U256};
use ethereum_types::U256;
use ethcore::ethereum::new_byzantium_test_machine;
use rustc_hex::FromHex;
use self::test::Bencher;
Expand Down
68 changes: 44 additions & 24 deletions ethcore/src/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,35 +311,51 @@ impl Impl for Ripemd160 {
}
}

// calculate modexp: exponentiation by squaring. the `num` crate has pow, but not modular.
fn modexp(mut base: BigUint, mut exp: BigUint, modulus: BigUint) -> BigUint {
use num::Integer;
// calculate modexp: left-to-right binary exponentiation to keep multiplicands lower
fn modexp(mut base: BigUint, exp: Vec<u8>, modulus: BigUint) -> BigUint {
const BITS_PER_DIGIT: usize = 8;

if modulus <= BigUint::one() { // n^m % 0 || n^m % 1
// n^m % 0 || n^m % 1
if modulus <= BigUint::one() {
return BigUint::zero();
}

if exp.is_zero() { // n^0 % m
// normalize exponent
let mut exp = exp.into_iter().skip_while(|d| *d == 0).peekable();

// n^0 % m
if let None = exp.peek() {
return BigUint::one();
}

if base.is_zero() { // 0^n % m, n>0
// 0^n % m, n > 0
if base.is_zero() {
return BigUint::zero();
}

let mut result = BigUint::one();
base = base % &modulus;

// fast path for base divisible by modulus.
// Fast path for base divisible by modulus.
if base.is_zero() { return BigUint::zero() }
while !exp.is_zero() {
if exp.is_odd() {
result = (result * &base) % &modulus;
}

exp = exp >> 1;
base = (base.clone() * base) % &modulus;
// Left-to-right binary exponentiation (Handbook of Applied Cryptography - Algorithm 14.79).
// http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
let mut result = BigUint::one();

for digit in exp {
let mut mask = 1 << (BITS_PER_DIGIT - 1);

for _ in 0..BITS_PER_DIGIT {
result = &result * &result % &modulus;

if digit & mask > 0 {
result = result * &base % &modulus;
}

mask >>= 1;
}
}

result
}

Expand All @@ -366,15 +382,19 @@ impl Impl for ModexpImpl {
} else {
// read the numbers themselves.
let mut buf = vec![0; max(mod_len, max(base_len, exp_len))];
let mut read_num = |len| {
let mut read_num = |reader: &mut io::Chain<&[u8], io::Repeat>, len: usize| {
reader.read_exact(&mut buf[..len]).expect("reading from zero-extended memory cannot fail; qed");
BigUint::from_bytes_be(&buf[..len])
};

let base = read_num(base_len);
let exp = read_num(exp_len);
let modulus = read_num(mod_len);
modexp(base, exp, modulus)
let base = read_num(&mut reader, base_len);

let mut exp_buf = vec![0; exp_len];
reader.read_exact(&mut exp_buf[..exp_len]).expect("reading from zero-extended memory cannot fail; qed");

let modulus = read_num(&mut reader, mod_len);

modexp(base, exp_buf, modulus)
};

// write output to given memory, left padded and same length as the modulus.
Expand Down Expand Up @@ -551,31 +571,31 @@ mod tests {
let mut base = BigUint::parse_bytes(b"12345", 10).unwrap();
let mut exp = BigUint::zero();
let mut modulus = BigUint::parse_bytes(b"789", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::one());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::one());

// 0^n % m == 0
base = BigUint::zero();
exp = BigUint::parse_bytes(b"12345", 10).unwrap();
modulus = BigUint::parse_bytes(b"789", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::zero());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::zero());

// n^m % 1 == 0
base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::one();
assert_eq!(me(base, exp, modulus), BigUint::zero());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::zero());

// if n % d == 0, then n^m % d == 0
base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::parse_bytes(b"15", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::zero());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::zero());

// others
base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::parse_bytes(b"97", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::parse_bytes(b"55", 10).unwrap());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::parse_bytes(b"55", 10).unwrap());
}

#[test]
Expand Down