Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- Optimize `from_str_radix` ([#557])

[#557]: https://github.com/recmo/uint/pull/557

## [1.17.2] - 2025-12-28

### Fixed
Expand Down
205 changes: 177 additions & 28 deletions src/string.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(clippy::missing_inline_in_public_items)] // allow format functions

use crate::{Uint, base_convert::BaseConvertError};
use crate::{Uint, algorithms::DoubleWord, base_convert::BaseConvertError};
use core::{fmt, str::FromStr};

/// Error for [`from_str_radix`](Uint::from_str_radix).
Expand Down Expand Up @@ -44,6 +44,42 @@ impl fmt::Display for ParseError {
}
}

/// Returns `(base, power)` where `base = radix^power` is the largest power of
/// `radix` that fits in a `u64`.
const fn radix_base(radix: u64) -> (u64, usize) {
debug_assert!(radix >= 2);
let mut power: usize = 1;
let mut base = radix;
loop {
match base.checked_mul(radix) {
Some(n) => {
base = n;
power += 1;
}
None => return (base, power),
}
}
}

/// Decode an ASCII byte as a digit for radix <= 36.
/// Case-insensitive 0-9, a-z. Underscores are skipped.
#[inline(always)]
fn decode_digit(b: u8, radix: u64) -> Result<Option<u64>, ParseError> {
let digit = match b {
b'0'..=b'9' => b - b'0',
b'a'..=b'z' => b - b'a' + 10,
b'A'..=b'Z' => b - b'A' + 10,
b'_' => return Ok(None),
_ => return Err(ParseError::InvalidDigit(b as char)),
};
let digit = u64::from(digit);
if digit < radix {
Ok(Some(digit))
} else {
Err(ParseError::InvalidDigit(b as char))
}
}

impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
/// Parse a string into a [`Uint`].
///
Expand All @@ -59,47 +95,131 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
/// * [`ParseError::BaseConvertError`] if [`Uint::from_base_be`] fails.
// FEATURE: Support proper unicode. Ignore zero-width spaces, joiners, etc.
// Recognize digits from other alphabets.
#[inline]
pub fn from_str_radix(src: &str, radix: u64) -> Result<Self, ParseError> {
if radix > 64 {
return Err(ParseError::InvalidRadix(radix));
match radix {
// Specialize for the common cases.
2 => Self::from_str_radix_pow2(src, 2),
8 => Self::from_str_radix_pow2(src, 8),
10 => Self::from_str_radix_chunked(src, 10),
16 => Self::from_str_radix_pow2(src, 16),

65.. => Err(ParseError::InvalidRadix(radix)),
37.. => Self::from_str_radix_slow(src, radix),
r if r.is_power_of_two() => Self::from_str_radix_pow2(src, radix),
_ => Self::from_str_radix_chunked(src, radix),
}
}

/// Fallback for radix > 36 (base-64 alphabet). Not perf-critical.
#[cold]
fn from_str_radix_slow(src: &str, radix: u64) -> Result<Self, ParseError> {
let mut err = None;
let digits = src.chars().filter_map(|c| {
if err.is_some() {
return None;
}
let digit = if radix <= 36 {
// Case insensitive 0—9, a—z.
match c {
'0'..='9' => u64::from(c) - u64::from('0'),
'a'..='z' => u64::from(c) - u64::from('a') + 10,
'A'..='Z' => u64::from(c) - u64::from('A') + 10,
'_' => return None, // Ignored character.
_ => {
err = Some(ParseError::InvalidDigit(c));
return None;
}
}
} else {
// The Base-64 alphabets
match c {
'A'..='Z' => u64::from(c) - u64::from('A'),
'a'..='f' => u64::from(c) - u64::from('a') + 26,
'0'..='9' => u64::from(c) - u64::from('0') + 52,
'+' | '-' => 62,
'/' | ',' | '_' => 63,
'=' | '\r' | '\n' => return None, // Ignored characters.
_ => {
err = Some(ParseError::InvalidDigit(c));
return None;
}
let digit = match c {
'A'..='Z' => u64::from(c) - u64::from('A'),
'a'..='f' => u64::from(c) - u64::from('a') + 26,
'0'..='9' => u64::from(c) - u64::from('0') + 52,
'+' | '-' => 62,
'/' | ',' | '_' => 63,
'=' | '\r' | '\n' => return None,
_ => {
err = Some(ParseError::InvalidDigit(c));
return None;
}
};
Some(digit)
});
let value = Self::from_base_be(radix, digits)?;
err.map_or(Ok(value), Err)
}

/// Power-of-2 radix: shift digits directly into limbs, no multiplication.
#[inline]
fn from_str_radix_pow2(src: &str, radix: u64) -> Result<Self, ParseError> {
debug_assert!(radix.is_power_of_two());
let bits_per_digit = radix.trailing_zeros() as usize;
let mut result = Self::ZERO;
let mut total_bits = 0usize;
for &b in src.as_bytes().iter().rev() {
let digit = match decode_digit(b, radix) {
Ok(None) => continue,
Ok(Some(d)) => d,
Err(e) => return Err(e),
};
if total_bits >= BITS {
if digit != 0 {
return Err(BaseConvertError::Overflow.into());
}
continue;
}
let limb_idx = total_bits / 64;
let bit_idx = total_bits % 64;
result.limbs[limb_idx] |= digit << bit_idx;
if bit_idx + bits_per_digit > 64 {
let hi = digit >> (64 - bit_idx);
if limb_idx + 1 < LIMBS {
result.limbs[limb_idx + 1] |= hi;
} else if hi != 0 {
return Err(BaseConvertError::Overflow.into());
}
}
total_bits += bits_per_digit;
}
if LIMBS > 0 && result.limbs[LIMBS - 1] > Self::MASK {
return Err(BaseConvertError::Overflow.into());
}
Ok(result)
}

/// Non-power-of-2 radix: accumulate chunks of digits into a u64, then do
/// one widening multiply per chunk instead of per digit.
#[allow(clippy::cast_possible_truncation)]
#[inline]
fn from_str_radix_chunked(src: &str, radix: u64) -> Result<Self, ParseError> {
let (base, power) = radix_base(radix);
let mut result = Self::ZERO;
let mut chunk_val: u64 = 0;
let mut chunk_digits: usize = 0;
for &b in src.as_bytes() {
let digit = match decode_digit(b, radix) {
Ok(None) => continue,
Ok(Some(d)) => d,
Err(e) => return Err(e),
};
chunk_val = chunk_val * radix + digit;
chunk_digits += 1;
if chunk_digits == power {
Self::muladd_limbs(&mut result.limbs, base, chunk_val)?;
chunk_val = 0;
chunk_digits = 0;
}
}
if chunk_digits > 0 {
let mut tail_base = radix;
for _ in 1..chunk_digits {
tail_base *= radix;
}
Self::muladd_limbs(&mut result.limbs, tail_base, chunk_val)?;
}
Ok(result)
}

/// `limbs = limbs * factor + addend`, returning overflow error.
#[inline(always)]
fn muladd_limbs(limbs: &mut [u64; LIMBS], factor: u64, addend: u64) -> Result<(), ParseError> {
let mut carry = addend;
for limb in limbs.iter_mut() {
(*limb, carry) = u128::muladd(*limb, factor, carry).split();
}
if carry > 0 || (LIMBS != 0 && limbs[LIMBS - 1] > Self::MASK) {
return Err(BaseConvertError::Overflow.into());
}
Ok(())
}
}

impl<const BITS: usize, const LIMBS: usize> FromStr for Uint<BITS, LIMBS> {
Expand All @@ -125,6 +245,35 @@ mod tests {
use super::*;
use proptest::{prop_assert_eq, proptest};

#[test]
fn test_pow2_overflow() {
type U8 = Uint<8, 1>;
assert_eq!(U8::from_str("0xff"), Ok(U8::from(255)));
assert_eq!(
U8::from_str("0x1ff"),
Err(ParseError::BaseConvertError(BaseConvertError::Overflow))
);
assert_eq!(
U8::from_str("0x100"),
Err(ParseError::BaseConvertError(BaseConvertError::Overflow))
);

type U7 = Uint<7, 1>;
assert_eq!(U7::from_str("0x7f"), Ok(U7::from(127)));
assert_eq!(
U7::from_str("0xff"),
Err(ParseError::BaseConvertError(BaseConvertError::Overflow))
);

// Octal: 0o777 = 511, which overflows U8 (max 255).
assert_eq!(
U8::from_str("0o777"),
Err(ParseError::BaseConvertError(BaseConvertError::Overflow))
);
// Octal: 0o377 = 255, fits U8.
assert_eq!(U8::from_str("0o377"), Ok(U8::from(255)));
}

#[test]
fn test_parse() {
proptest!(|(value: u128)| {
Expand Down
Loading