Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- Optimize `from_str_radix` ([#557])

[#557]: https://github.com/recmo/uint/pull/557

## [1.17.2] - 2025-12-28

### Fixed
Expand Down
43 changes: 18 additions & 25 deletions src/algorithms/mul_redc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,33 @@ pub fn mul_redc<const N: usize>(a: [u64; N], b: [u64; N], modulus: [u64; N], inv
// See <https://tches.iacr.org/index.php/TCHES/article/view/10972>
let mut result = [0; N];
let mut carry = false;
let has_top_carry = modulus[N - 1] >= 0x7fff_ffff_ffff_ffff;
for b in b {
let mut m = 0;
let mut carry_1 = 0;
let mut carry_2 = 0;
for i in 0..N {
// Add limb product
let mut carry_1;
let mut carry_2;

// i = 0: compute initial value and reduction factor.
let (value, next_carry) = carrying_mul_add(a[0], b, result[0], 0);
carry_1 = next_carry;
let m = value.wrapping_mul(inv);
let (value, next_carry) = carrying_mul_add(modulus[0], m, value, 0);
carry_2 = next_carry;
debug_assert_eq!(value, 0);

// i = 1..N
for i in 1..N {
let (value, next_carry) = carrying_mul_add(a[i], b, result[i], carry_1);
carry_1 = next_carry;

if i == 0 {
// Compute reduction factor
m = value.wrapping_mul(inv);
}

// Add m * modulus to acc to clear next_result[0]
let (value, next_carry) = carrying_mul_add(modulus[i], m, value, carry_2);
carry_2 = next_carry;

// Shift result
if i > 0 {
result[i - 1] = value;
} else {
debug_assert_eq!(value, 0);
}
result[i - 1] = value;
}

// Add carries
let (value, next_carry) = carrying_add(carry_1, carry_2, carry);
result[N - 1] = value;
if modulus[N - 1] >= 0x7fff_ffff_ffff_ffff {
if has_top_carry {
carry = next_carry;
} else {
debug_assert!(!next_carry);
Expand All @@ -74,8 +71,8 @@ pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) ->

let mut result = [0; N];
let mut carry_outer = 0;
let has_top_carry = modulus[N - 1] >= 0x3fff_ffff_ffff_ffff;
for i in 0..N {
// Add limb product
let (value, mut carry_lo) = carrying_mul_add(a[i], a[i], result[i], 0);
let mut carry_hi = false;
result[i] = value;
Expand All @@ -87,7 +84,6 @@ pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) ->
carry_hi = next_carry_hi;
}

// Add m times modulus to result and shift one limb
let m = result[0].wrapping_mul(inv);
let (value, mut carry) = carrying_mul_add(m, modulus[0], result[0], 0);
debug_assert_eq!(value, 0);
Expand All @@ -97,19 +93,16 @@ pub fn square_redc<const N: usize>(a: [u64; N], modulus: [u64; N], inv: u64) ->
carry = next_carry;
}

// Add carries
if modulus[N - 1] >= 0x3fff_ffff_ffff_ffff {
if has_top_carry {
let wide = (carry_outer as u128)
.wrapping_add(carry_lo as u128)
.wrapping_add((carry_hi as u128) << 64)
.wrapping_add(carry as u128);
result[N - 1] = wide as u64;

// Note carry_outer can be {0, 1, 2}.
carry_outer = (wide >> 64) as u64;
debug_assert!(carry_outer <= 2);
} else {
// `carry_outer` and `carry_hi` are always zero.
debug_assert!(!carry_hi);
debug_assert_eq!(carry_outer, 0);
let (value, carry) = carry_lo.overflowing_add(carry);
Expand Down
57 changes: 43 additions & 14 deletions src/base_convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,35 +193,64 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
base: u64,
digits: I,
) -> Result<Self, BaseConvertError> {
// OPT: Special handling of bases that divide 2^64, and bases that are
// powers of 2.
// OPT: Same trick as with `to_base_le`, find the largest power of base
// that fits `u64` and accumulate there first.
if base < 2 {
return Err(BaseConvertError::InvalidBase(base));
}

let chunk_base = crate::utils::max_pow_u64(base);
let mut chunk_power: usize = 1;
{
let mut p = base;
while p != chunk_base {
p *= base;
chunk_power += 1;
}
}

let mut result = Self::ZERO;
let mut chunk_val: u64 = 0;
let mut chunk_digits: usize = 0;
for digit in digits {
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
// Multiply by base.
// OPT: keep track of non-zero limbs and mul the minimum.
let mut carry = u128::from(digit);
#[allow(clippy::cast_possible_truncation)]
for limb in &mut result.limbs {
carry += u128::from(*limb) * u128::from(base);
*limb = carry as u64;
carry >>= 64;
chunk_val = chunk_val * base + digit;
chunk_digits += 1;
if chunk_digits == chunk_power {
Self::from_base_muladd(&mut result, chunk_base, chunk_val)?;
chunk_val = 0;
chunk_digits = 0;
}
if carry > 0 || (LIMBS != 0 && result.limbs[LIMBS - 1] > Self::MASK) {
return Err(BaseConvertError::Overflow);
}
if chunk_digits > 0 {
let mut tail_base = base;
for _ in 1..chunk_digits {
tail_base *= base;
}
Self::from_base_muladd(&mut result, tail_base, chunk_val)?;
}

Ok(result)
}

/// Updates `result` in place to `result * factor + addend`.
///
/// The multiplication is carried out limb by limb through a `u128`
/// accumulator whose upper half is the carry into the next limb.
///
/// # Errors
///
/// Returns [`BaseConvertError::Overflow`] when a carry remains after the
/// last limb, or when the last limb ends up greater than `Self::MASK`.
/// The limbs of `result` have already been overwritten at that point.
#[inline(always)]
#[allow(clippy::cast_possible_truncation)]
fn from_base_muladd(
    result: &mut Self,
    factor: u64,
    addend: u64,
) -> Result<(), BaseConvertError> {
    let wide_factor = u128::from(factor);
    let mut acc = u128::from(addend);
    for limb in result.limbs.iter_mut() {
        acc += u128::from(*limb) * wide_factor;
        // Keep the low 64 bits; the high 64 bits carry into the next limb.
        *limb = acc as u64;
        acc >>= 64;
    }

    let carried_out = acc > 0;
    // The `LIMBS != 0` guard keeps the index valid for zero-limb integers.
    let above_mask = LIMBS != 0 && result.limbs[LIMBS - 1] > Self::MASK;
    if carried_out || above_mask {
        Err(BaseConvertError::Overflow)
    } else {
        Ok(())
    }
}
}

struct SpigotLittle<const LIMBS: usize> {
Expand Down
64 changes: 64 additions & 0 deletions src/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,38 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[inline]
#[must_use]
pub const fn overflowing_shl(self, rhs: usize) -> (Self, bool) {
if LIMBS == 1 {
let (limbs, bits) = (rhs / 64, rhs % 64);
if limbs >= 1 {
return (Self::ZERO, self.limbs[0] != 0);
}
let x = self.limbs[0];
let carry = (x >> (63 - bits)) >> 1;
let mut r = Self::ZERO;
r.limbs[0] = (x << bits) & Self::MASK;
return (r, carry != 0);
}
if LIMBS == 2 {
let (limbs, bits) = (rhs / 64, rhs % 64);
if limbs >= 2 {
return (Self::ZERO, !self.const_is_zero());
}
let val = self.as_double_words()[0].get();
let shifted = val << bits;
if limbs == 0 {
let carry = (val >> (127 - bits)) >> 1;
let mut r = Self::ZERO;
r.limbs[0] = shifted as u64;
r.limbs[1] = (shifted >> 64) as u64 & Self::MASK;
return (r, carry != 0);
}
let x = self.limbs[0] as u128;
let carry = (x >> (63 - bits)) >> 1;
let mut r = Self::ZERO;
r.limbs[1] = (x << bits) as u64 & Self::MASK;
return (r, carry != 0);
}

let (limbs, bits) = (rhs / 64, rhs % 64);
if limbs >= LIMBS {
return (Self::ZERO, !self.const_is_zero());
Expand Down Expand Up @@ -410,6 +442,38 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[inline]
#[must_use]
pub const fn overflowing_shr(self, rhs: usize) -> (Self, bool) {
if LIMBS == 1 {
let (limbs, bits) = (rhs / 64, rhs % 64);
if limbs >= 1 {
return (Self::ZERO, self.limbs[0] != 0);
}
let x = self.limbs[0];
let carry = (x << (63 - bits)) << 1;
let mut r = Self::ZERO;
r.limbs[0] = x >> bits;
return (r, carry != 0);
}
if LIMBS == 2 {
let (limbs, bits) = (rhs / 64, rhs % 64);
if limbs >= 2 {
return (Self::ZERO, !self.const_is_zero());
}
let val = self.as_double_words()[0].get();
if limbs == 0 {
let carry = (val << (127 - bits)) << 1;
let shifted = val >> bits;
let mut r = Self::ZERO;
r.limbs[0] = shifted as u64;
r.limbs[1] = (shifted >> 64) as u64;
return (r, carry != 0);
}
let x = self.limbs[1];
let carry = (x << (63 - bits)) << 1;
let mut r = Self::ZERO;
r.limbs[0] = x >> bits;
return (r, carry != 0);
}

let (limbs, bits) = (rhs / 64, rhs % 64);
if limbs >= LIMBS {
return (Self::ZERO, !self.const_is_zero());
Expand Down
8 changes: 8 additions & 0 deletions src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ impl<const BITS: usize, const LIMBS: usize> PartialOrd for Uint<BITS, LIMBS> {
impl<const BITS: usize, const LIMBS: usize> Ord for Uint<BITS, LIMBS> {
    /// Total ordering of two integers of the same size.
    ///
    /// One- and two-limb integers are compared through a single primitive
    /// comparison; every other size goes through the generic limb-wise
    /// comparison in `crate::algorithms::cmp`.
    #[inline]
    fn cmp(&self, rhs: &Self) -> Ordering {
        match LIMBS {
            // Single limb: compare the two `u64` limbs directly.
            1 => self.limbs[0].cmp(&rhs.limbs[0]),
            // Two limbs: compare the first double word of each operand.
            2 => {
                let lhs_wide = self.as_double_words()[0].get();
                let rhs_wide = rhs.as_double_words()[0].get();
                lhs_wide.cmp(&rhs_wide)
            }
            _ => crate::algorithms::cmp(self.as_limbs(), rhs.as_limbs()),
        }
    }
}
Expand Down
2 changes: 2 additions & 0 deletions src/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
let q = &mut self.limbs[0];
let r = &mut rhs.limbs[0];
(*q, *r) = algorithms::div::div_1x1(*q, *r);
} else if LIMBS <= 4 {
algorithms::div::div_inlined(&mut self.limbs, &mut rhs.limbs);
} else {
Self::div_rem_by_ref(&mut self, &mut rhs);
}
Expand Down
56 changes: 52 additions & 4 deletions src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,51 @@ impl<const BITS: usize, const LIMBS: usize> fmt::Debug for Uint<BITS, LIMBS> {
}

impl_fmt!(fmt::Display; base::Decimal, "");
impl_fmt!(fmt::Binary; base::Binary, "b");
impl_fmt!(fmt::Octal; base::Octal, "o");
impl_fmt!(fmt::LowerHex; base::Hexadecimal, "x");
impl_fmt!(fmt::UpperHex; base::Hexadecimal, "X");

/// Implements a power-of-two radix formatting trait for `Uint`.
///
/// Arguments: the trait to implement, the base type that supplies `PREFIX`,
/// the number of bits per output digit, and whether hexadecimal letters are
/// upper case.
macro_rules! impl_fmt_pow2 {
    ($tr:path; $base:ty, $bits_per_digit:literal, $upper:literal) => {
        impl<const BITS: usize, const LIMBS: usize> $tr for Uint<BITS, LIMBS> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                // Values that fit a primitive are formatted by the primitive's
                // own implementation of the trait.
                if let Ok(narrow) = u64::try_from(self) {
                    return <u64 as $tr>::fmt(&narrow, f);
                }
                if let Ok(narrow) = u128::try_from(self) {
                    return <u128 as $tr>::fmt(&narrow, f);
                }

                const DIGIT_BITS: usize = $bits_per_digit;
                const DIGIT_MASK: u64 = (1 << DIGIT_BITS) - 1;
                let glyphs: &[u8; 16] = if $upper {
                    b"0123456789ABCDEF"
                } else {
                    b"0123456789abcdef"
                };

                let digit_count = self.bit_len().div_ceil(DIGIT_BITS);
                let mut out = StackString::<BITS>::new();

                // Emit digits from the most significant one down to digit 0.
                for position in (0..digit_count).rev() {
                    let offset = position * DIGIT_BITS;
                    let (limb, shift) = (offset / 64, offset % 64);

                    let low = self.limbs[limb] >> shift;
                    // A digit may straddle two limbs (only possible when the
                    // digit width does not divide 64); pull the missing high
                    // bits from the next limb when there is one.
                    let high = if shift + DIGIT_BITS > 64 && limb + 1 < LIMBS {
                        self.limbs[limb + 1] << (64 - shift)
                    } else {
                        0
                    };

                    let digit = (low | high) & DIGIT_MASK;
                    out.push_byte(glyphs[digit as usize]);
                }

                f.pad_integral(true, <$base>::PREFIX, out.as_str())
            }
        }
    };
}

impl_fmt_pow2!(fmt::Binary; base::Binary, 1, false);
impl_fmt_pow2!(fmt::Octal; base::Octal, 3, false);
impl_fmt_pow2!(fmt::LowerHex; base::Hexadecimal, 4, false);
impl_fmt_pow2!(fmt::UpperHex; base::Hexadecimal, 4, true);

/// A stack-allocated buffer that implements [`fmt::Write`].
pub(crate) struct StackString<const SIZE: usize> {
Expand Down Expand Up @@ -115,6 +156,13 @@ impl<const SIZE: usize> StackString<SIZE> {
/// Returns the first `self.len` bytes of `buf` as a byte slice.
const fn as_bytes(&self) -> &[u8] {
// SAFETY: NOTE(review) — assumes the first `self.len` bytes of `buf` are
// initialized and that `self.len` never exceeds the buffer size; confirm
// against the struct definition and every writer of `len`.
unsafe { core::slice::from_raw_parts(self.buf.as_ptr().cast(), self.len) }
}

/// Appends the single byte `b` at offset `self.len` and increments `self.len`.
///
/// The capacity check is a `debug_assert!` only, so in release builds the
/// caller is responsible for never pushing more than `SIZE` bytes.
#[inline]
fn push_byte(&mut self, b: u8) {
debug_assert!(self.len < SIZE);
// SAFETY: relies on `self.len < SIZE` (asserted above in debug builds only)
// so that the write stays inside `buf`. NOTE(review): assumes `buf` holds
// `SIZE` byte-sized slots — confirm against the struct definition.
unsafe { self.buf.as_mut_ptr().add(self.len).cast::<u8>().write(b) };
self.len += 1;
}
}

impl<const SIZE: usize> fmt::Write for StackString<SIZE> {
Expand Down
6 changes: 5 additions & 1 deletion src/modular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
// Reuse `div_rem` if we don't need an extra limb.
if const { crate::nlimbs(BITS + 1) == LIMBS } {
let numerator = unsafe { &mut *numerator.as_mut_ptr().cast::<Self>() };
Self::div_rem_by_ref(numerator, &mut modulus);
if LIMBS <= 4 {
algorithms::div::div_inlined(&mut numerator.limbs, &mut modulus.limbs);
} else {
Self::div_rem_by_ref(numerator, &mut modulus);
}
} else {
Self::div_rem_bits_plus_one(numerator.as_mut_ptr(), &mut modulus);
}
Expand Down
Loading
Loading