Skip to content

mul_mod for 2^31 < m < 2^32 #111

Open
@mizar

Description

@mizar

It seems that mul_mod can easily be adapted to the case 2^31 < m < 2^32 simply by improving the borrow check on the final subtraction.

(current code: ac-library)

https://github.com/atcoder/ac-library/blob/6c88a70c8f95fef575af354900d107fbd0db1a12/atcoder/internal_math.hpp#L22-L62

(current code: ac-library-rs)

/// Precomputed state for fast modular multiplication by Barrett reduction.
///
/// Holds a modulus together with a 64-bit multiplier derived from it, so that
/// repeated products modulo the same `m` need no hardware division.
///
/// Reference: https://en.wikipedia.org/wiki/Barrett_reduction
/// NOTE: reconsider after Ice Lake
pub(crate) struct Barrett {
/// The modulus `m`.
pub(crate) _m: u32,
/// The multiplier paired with `_m`; `mul_mod` expects it to be
/// `ceil(2^64 / m)` (which wraps to `0` when `m == 1`).
pub(crate) im: u64,
}
impl Barrett {
    /// Builds the reduction state for the modulus `m`.
    ///
    /// # Arguments
    /// * `m` `1 <= m`
    ///   (Note: `m <= 2^31` should also hold, which is undocumented in the original library.
    ///   See the [pull request comment](https://github.com/rust-lang-ja/ac-library-rs/pull/3#discussion_r484661007)
    ///   for more details.)
    ///
    /// # Panics
    /// Panics with a division by zero when `m == 0`.
    pub(crate) fn new(m: u32) -> Self {
        // floor((2^64 - 1) / m) + 1 == ceil(2^64 / m); the addition wraps to 0 when m == 1.
        let reciprocal = (u64::MAX / u64::from(m)).wrapping_add(1);
        Self {
            _m: m,
            im: reciprocal,
        }
    }

    /// # Returns
    /// `m`
    pub(crate) fn umod(&self) -> u32 {
        self._m
    }

    /// Multiplies two residues modulo `m`.
    ///
    /// # Parameters
    /// * `a` `0 <= a < m`
    /// * `b` `0 <= b < m`
    ///
    /// # Returns
    /// a * b % m
    pub(crate) fn mul(&self, a: u32, b: u32) -> u32 {
        mul_mod(a, b, self.umod(), self.im)
    }
}
/// Computes `a * b % m` by Barrett reduction with a precomputed multiplier.
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m <= 2^31`
/// * `im` = ceil(2^64 / `m`)
pub(crate) fn mul_mod(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // Case m = 1: a = b = im = 0, so every intermediate below is 0.
    //
    // Case m >= 2: write im * m = 2^64 + r (0 <= r < m) and
    // a * b = c * m + d (0 <= c, d < m). Then
    //   a * b * im = c * 2^64 + c * r + d * im,  where  c * r + d * im < 2 * 2^64,
    // hence the upper 64 bits of a * b * im are either c or c + 1.
    let product = u64::from(a) * u64::from(b);
    let quotient = ((u128::from(product) * u128::from(im)) >> 64) as u64;

    // Low 32 bits of `product - quotient * m`. This is `d` when the quotient estimate
    // was exact; when it was one too large the wrapped value is `d - m + 2^32`, which
    // is >= m as long as m <= 2^31, so the comparison below spots it.
    let diff = product.wrapping_sub(quotient.wrapping_mul(u64::from(m))) as u32;
    if diff < m {
        diff
    } else {
        diff.wrapping_add(m)
    }
}

  • $m\in\text{自然数(natural number)}\mathbb{N},\quad 1\le m\lt 2^{32}$
  • $\lbrace a,b\rbrace\in\text{整数(integer)}\mathbb{Z},\quad 0\le \lbrace a, b\rbrace\lt m$
  • $\displaystyle\bar{m'} = \left\lfloor\frac{2^{64}-1}{m}\right\rfloor+1\mod2^{64}=\left\lceil\frac{2^{64}}{m}\right\rceil\mod 2^{64}$
  • $\displaystyle x=\left\lfloor\frac{ab\bar{m'}}{2^{64}}\right\rfloor$
  • $ab\mod m=ab-xm\quad(ab\ge xm)$
  • $ab\mod m=ab-xm+m\quad(ab\lt xm)$

(proof)

  1. when $m=1$, $a=b=\bar{m'}=0$, so okay
  2. when $2\le m\lt 2^{32}$,
    • $2^{32}+2=\left\lceil\frac{2^{64}}{2^{32}-1}\right\rceil\le\bar{m'}=\left\lceil\frac{2^{64}}{m}\right\rceil\le \left\lceil\frac{2^{64}}{2}\right\rceil=2^{63}$
    • $\bar{m'}\hspace{.1em}m=2^{64}+r\quad(0\le r\lt m)$
    • $z = ab = cm + d\quad(0\le\lbrace c,d\rbrace\lt m)$
    • $z\hspace{.1em}\bar{m'}=ab\hspace{.1em}\bar{m'}=(cm+d)\hspace{.1em}\bar{m'}=c(\bar{m'}\hspace{.1em}m)+d\hspace{.1em}\bar{m'}=2^{64}c+c\hspace{.1em}r+d\hspace{.1em}\bar{m'}$
    • $2^{64}c\le z\hspace{.1em}\bar{m'}\lt 2^{64}(c+2)$
      • $z\hspace{.1em}\bar{m'}=2^{64}c+c\hspace{.1em}r+d\hspace{.1em}\bar{m'}$
      • $0\le c\hspace{.1em}r\le (m-1)^2\le(2^{32}-2)^2=2^{64}-2^{34}+4$
      • $0\le d\hspace{.1em}\bar{m'}\le\bar{m'}\hspace{.1em}(m-1)=2^{64}+r-\bar{m'}\le 2^{64}+(2^{32}-2)-(2^{32}+2)=2^{64}-4$
    • $x=\left\lfloor\frac{ab\hspace{.1em}\bar{m'}}{2^{64}}\right\rfloor=\lbrace c$ or $(c+1)\rbrace$
    • $z-xm=ab-\left\lfloor\frac{ab\hspace{.1em}\bar{m'}}{2^{64}}\right\rfloor m=\lbrace d$ or $(d-m)\rbrace$

(C++: $1\le m\lt 2^{32}$ draft code)

https://godbolt.org/z/9Gz1oGrTa

#ifdef _MSC_VER
#include <intrin.h>
#endif

// Barrett-reduction modular multiplication, as it stands before the proposed change.
//
// @param a `0 <= a < m`
// @param b `0 <= b < m`
// @param _m the modulus `m`
// @param im `ceil(2^64 / m)` (wraps to 0 for `m = 1`)
// @return `a * b % m`
unsigned int barrett_mul_before(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // Case m = 1: a = b = im = 0, so every intermediate below is 0.
    //
    // Case m >= 2: write im * m = 2^64 + r (0 <= r < m) and
    // a * b = c * m + d (0 <= c, d < m). Then
    //   a * b * im = c * 2^64 + c * r + d * im,  where  c * r + d * im < 2 * 2^64,
    // hence the upper 64 bits of a * b * im are either c or c + 1.
    const unsigned long long product = (unsigned long long)a * b;
#ifdef _MSC_VER
    unsigned long long quotient;
    _umul128(product, im, &quotient);
#else
    const unsigned long long quotient =
        (unsigned long long)(((unsigned __int128)(product)*im) >> 64);
#endif
    // Low 32 bits of `product - quotient * m`: `d` when the estimate was exact,
    // `d - m + 2^32` when it was one too large. The comparison with `_m` only
    // tells those two apart while m <= 2^31.
    const unsigned int diff = (unsigned int)(product - quotient * _m);
    return (_m <= diff) ? diff + _m : diff;
}

// Barrett-reduction modular multiplication with the borrow taken from the
// full 64-bit subtraction, so that it also covers 2^31 < m < 2^32.
//
// @param a `0 <= a < m`
// @param b `0 <= b < m`
// @param _m the modulus `m`, `1 <= m < 2^32`
// @param im `ceil(2^64 / m)` (wraps to 0 for `m = 1`)
// @return `a * b % m`
unsigned int barrett_mul_after(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // Case m = 1: a = b = im = 0, so every intermediate below is 0.
    //
    // Case m >= 2: write im * m = 2^64 + r (0 <= r < m) and
    // a * b = c * m + d (0 <= c, d < m). Then
    //   a * b * im = c * 2^64 + c * r + d * im,  where  c * r + d * im < 2 * 2^64,
    // hence the upper 64 bits of a * b * im are either c or c + 1.
    const unsigned long long product = (unsigned long long)a * b;
#ifdef _MSC_VER
    unsigned long long quotient;
    _umul128(product, im, &quotient);
#else
    const unsigned long long quotient =
        (unsigned long long)(((unsigned __int128)(product)*im) >> 64);
#endif
    const unsigned long long subtrahend = quotient * _m;
    unsigned int remainder = (unsigned int)(product - subtrahend);
    // A borrow in the 64-bit subtraction means the quotient estimate was c + 1,
    // i.e. the difference is d - m; adding m back (mod 2^32) recovers d.
    if (product < subtrahend) {
        remainder += _m;
    }
    return remainder;
}

(Rust: $1\le m\lt 2^{32}$ draft code)

https://rust.godbolt.org/z/7P5rjahMn

/// Computes `a * b % m` by Barrett reduction (the version before the proposed change).
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m <= 2^31`
/// * `im` = ceil(2^64 / `m`)
pub fn mul_mod_before(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // Case m = 1: a = b = im = 0, so every intermediate below is 0.
    //
    // Case m >= 2: write im * m = 2^64 + r (0 <= r < m) and
    // a * b = c * m + d (0 <= c, d < m). Then
    //   a * b * im = c * 2^64 + c * r + d * im,  where  c * r + d * im < 2 * 2^64,
    // hence the upper 64 bits of a * b * im are either c or c + 1.
    let product = u64::from(a) * u64::from(b);
    let quotient = ((u128::from(product) * u128::from(im)) >> 64) as u64;

    // Low 32 bits of `product - quotient * m`. This is `d` when the quotient estimate
    // was exact; when it was one too large the wrapped value is `d - m + 2^32`, which
    // is >= m as long as m <= 2^31, so the comparison below spots it.
    let diff = product.wrapping_sub(quotient.wrapping_mul(u64::from(m))) as u32;
    if diff < m {
        diff
    } else {
        diff.wrapping_add(m)
    }
}

/// Computes `a * b % m` by Barrett reduction, using the borrow of the full 64-bit
/// subtraction so that moduli up to `2^32 - 1` are handled.
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m < 2^32`
/// * `im` = ceil(2^64 / `m`) = floor((2^64 - 1) / `m`) + 1
pub fn mul_mod_after(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // Case m = 1: a = b = im = 0, so every intermediate below is 0.
    //
    // Case m >= 2: write im * m = 2^64 + r (0 <= r < m) and
    // a * b = c * m + d (0 <= c, d < m). Then
    //   a * b * im = c * 2^64 + c * r + d * im,  where  c * r + d * im < 2 * 2^64,
    // hence the upper 64 bits of a * b * im are either c or c + 1.
    let product = u64::from(a) * u64::from(b);
    let quotient = ((u128::from(product) * u128::from(im)) >> 64) as u64;

    let (diff, borrowed) = product.overflowing_sub(quotient.wrapping_mul(u64::from(m)));
    let low = diff as u32;
    if borrowed {
        // The estimate was c + 1, so `diff` is d - m wrapped; adding m (mod 2^32) gives d.
        low.wrapping_add(m)
    } else {
        low
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions