From 22d0756b68f4b7c9f7a68a06e39d06354009603e Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Fri, 24 Mar 2023 12:29:15 +0000 Subject: [PATCH] Uniform sampling: use Canon's method, Lemire's method (#1287) Also: * Add uniform distribution benchmarks * Add "unbiased" feature flag * Fix feature simd_support * Uniform: impl PartialEq, Eq where possible * CI: benches now require small_rng; build-test unbiased --- .github/workflows/test.yml | 4 +- CHANGELOG.md | 1 + Cargo.toml | 11 +- benches/uniform.rs | 78 ++++++++++++++ src/distributions/uniform.rs | 193 +++++++++++++++++++---------------- src/seq/index.rs | 10 +- src/seq/mod.rs | 38 +++---- 7 files changed, 215 insertions(+), 120 deletions(-) create mode 100644 benches/uniform.rs diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a35e82db3d..14639f24d3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -79,13 +79,13 @@ jobs: run: | cargo test --target ${{ matrix.target }} --features=nightly cargo test --target ${{ matrix.target }} --all-features - cargo test --target ${{ matrix.target }} --benches --features=nightly + cargo test --target ${{ matrix.target }} --benches --features=small_rng,nightly cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --benches cargo test --target ${{ matrix.target }} --lib --tests --no-default-features - name: Test rand run: | cargo test --target ${{ matrix.target }} --lib --tests --no-default-features - cargo build --target ${{ matrix.target }} --no-default-features --features alloc,getrandom,small_rng + cargo build --target ${{ matrix.target }} --no-default-features --features alloc,getrandom,small_rng,unbiased cargo test --target ${{ matrix.target }} --lib --tests --no-default-features --features=alloc,getrandom,small_rng cargo test --target ${{ matrix.target }} --examples - name: Test rand (all stable features) diff --git a/CHANGELOG.md b/CHANGELOG.md index 083b7dbbeb..8706df8c63 100644 --- a/CHANGELOG.md 
+++ b/CHANGELOG.md @@ -12,6 +12,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. ### Distributions - `{Uniform, UniformSampler}::{new, new_inclusive}` return a `Result` (instead of potentially panicking) (#1229) - `Uniform` implements `TryFrom` instead of `From` for ranges (#1229) +- `Uniform` now uses Canon's method (single sampling) / Lemire's method (distribution sampling) for faster sampling (breaks value stability; #1287) ### Other - Simpler and faster implementation of Floyd's F2 (#1277). This diff --git a/Cargo.toml b/Cargo.toml index ec0e4d7767..f85e126d73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,11 @@ std_rng = ["rand_chacha"] # Option: enable SmallRng small_rng = [] +# Option: use unbiased sampling for algorithms supporting this option: Uniform distribution. +# By default, bias affecting no more than one in 2^48 samples is accepted. +# Note: enabling this option is expected to affect reproducibility of results. +unbiased = [] + [workspace] members = [ "rand_core", @@ -76,6 +81,10 @@ bincode = "1.2.1" rayon = "1.5.3" criterion = { version = "0.4" } +[[bench]] +name = "uniform" +harness = false + [[bench]] name = "seq_choose" path = "benches/seq_choose.rs" @@ -84,4 +93,4 @@ harness = false [[bench]] name = "shuffle" path = "benches/shuffle.rs" -harness = false \ No newline at end of file +harness = false diff --git a/benches/uniform.rs b/benches/uniform.rs new file mode 100644 index 0000000000..d0128d5a48 --- /dev/null +++ b/benches/uniform.rs @@ -0,0 +1,78 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! 
Implement benchmarks for uniform distributions over integer types + +use core::time::Duration; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::distributions::uniform::{SampleRange, Uniform}; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +use rand_pcg::{Pcg32, Pcg64}; + +const WARM_UP_TIME: Duration = Duration::from_millis(1000); +const MEASUREMENT_TIME: Duration = Duration::from_secs(3); +const SAMPLE_SIZE: usize = 100_000; +const N_RESAMPLES: usize = 10_000; + +macro_rules! sample { + ($R:ty, $T:ty, $U:ty, $g:expr) => { + $g.bench_function(BenchmarkId::new(stringify!($R), "single"), |b| { + let mut rng = <$R>::from_entropy(); + let x = rng.gen::<$U>(); + let bits = (<$T>::BITS / 2); + let mask = (1 as $U).wrapping_neg() >> bits; + let range = (x >> bits) * (x & mask); + let low = <$T>::MIN; + let high = low.wrapping_add(range as $T); + + b.iter(|| (low..=high).sample_single(&mut rng)); + }); + + $g.bench_function(BenchmarkId::new(stringify!($R), "distr"), |b| { + let mut rng = <$R>::from_entropy(); + let x = rng.gen::<$U>(); + let bits = (<$T>::BITS / 2); + let mask = (1 as $U).wrapping_neg() >> bits; + let range = (x >> bits) * (x & mask); + let low = <$T>::MIN; + let high = low.wrapping_add(range as $T); + let dist = Uniform::<$T>::new_inclusive(<$T>::MIN, high).unwrap(); + + b.iter(|| dist.sample(&mut rng)); + }); + }; + + ($c:expr, $T:ty, $U:ty) => {{ + let mut g = $c.benchmark_group(concat!("sample", stringify!($T))); + g.sample_size(SAMPLE_SIZE); + g.warm_up_time(WARM_UP_TIME); + g.measurement_time(MEASUREMENT_TIME); + g.nresamples(N_RESAMPLES); + sample!(SmallRng, $T, $U, g); + sample!(ChaCha8Rng, $T, $U, g); + sample!(Pcg32, $T, $U, g); + sample!(Pcg64, $T, $U, g); + g.finish(); + }}; +} + +fn sample(c: &mut Criterion) { + sample!(c, i8, u8); + sample!(c, i16, u16); + sample!(c, i32, u32); + sample!(c, i64, u64); + sample!(c, i128, u128); +} + +criterion_group! 
{ + name = benches; + config = Criterion::default(); + targets = sample +} +criterion_main!(benches); diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index b4856ff613..326d1ed4e8 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -195,7 +195,7 @@ use serde::{Serialize, Deserialize}; /// [`new`]: Uniform::new /// [`new_inclusive`]: Uniform::new_inclusive /// [`Rng::gen_range`]: Rng::gen_range -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))] #[cfg_attr(feature = "serde1", serde(bound(deserialize = "X::Sampler: Deserialize<'de>")))] @@ -444,21 +444,21 @@ impl SampleRange for RangeInclusive { /// use `u32` for our `zone` and samples (because it's not slower and because /// it reduces the chance of having to reject a sample). In this case we cannot /// store `zone` in the target type since it is too large, however we know -/// `ints_to_reject < range <= $unsigned::MAX`. +/// `ints_to_reject < range <= $uty::MAX`. /// /// An alternative to using a modulus is widening multiply: After a widening /// multiply by `range`, the result is in the high word. Then comparing the low /// word against `zone` makes sure our distribution is uniform. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct UniformInt { low: X, range: X, - z: X, // either ints_to_reject or zone depending on implementation + thresh: X, // effectively 2.pow(max(64, uty_bits)) % range } macro_rules! uniform_int_impl { - ($ty:ty, $unsigned:ident, $u_large:ident) => { + ($ty:ty, $uty:ty, $sample_ty:ident) => { impl SampleUniform for $ty { type Sampler = UniformInt<$ty>; } @@ -466,7 +466,7 @@ macro_rules! 
uniform_int_impl { impl UniformSampler for UniformInt<$ty> { // We play free and fast with unsigned vs signed here // (when $ty is signed), but that's fine, since the - // contract of this macro is for $ty and $unsigned to be + // contract of this macro is for $ty and $uty to be // "bit-equal", so casting between them is a no-op. type X = $ty; @@ -498,41 +498,38 @@ macro_rules! uniform_int_impl { if !(low <= high) { return Err(Error::EmptyRange); } - let unsigned_max = ::core::$u_large::MAX; - let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned; - let ints_to_reject = if range > 0 { - let range = $u_large::from(range); - (unsigned_max - range + 1) % range + let range = high.wrapping_sub(low).wrapping_add(1) as $uty; + let thresh = if range > 0 { + let range = $sample_ty::from(range); + (range.wrapping_neg() % range) } else { 0 }; Ok(UniformInt { low, - // These are really $unsigned values, but store as $ty: - range: range as $ty, - z: ints_to_reject as $unsigned as $ty, + range: range as $ty, // type: $uty + thresh: thresh as $uty as $ty, // type: $sample_ty }) } + /// Sample from distribution, Lemire's method, unbiased #[inline] fn sample(&self, rng: &mut R) -> Self::X { - let range = self.range as $unsigned as $u_large; - if range > 0 { - let unsigned_max = ::core::$u_large::MAX; - let zone = unsigned_max - (self.z as $unsigned as $u_large); - loop { - let v: $u_large = rng.gen(); - let (hi, lo) = v.wmul(range); - if lo <= zone { - return self.low.wrapping_add(hi as $ty); - } - } - } else { - // Sample from the entire integer range. - rng.gen() + let range = self.range as $uty as $sample_ty; + if range == 0 { + return rng.gen(); } + + let thresh = self.thresh as $uty as $sample_ty; + let hi = loop { + let (hi, lo) = rng.gen::<$sample_ty>().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as $ty) } #[inline] @@ -549,8 +546,15 @@ macro_rules! 
uniform_int_impl { Self::sample_single_inclusive(low, high - 1, rng) } + /// Sample single value, Canon's method, biased + /// + /// In the worst case, bias affects 1 in `2^n` samples where n is + /// 56 (`i8`), 48 (`i16`), 96 (`i32`), 64 (`i64`), 128 (`i128`). + #[cfg(not(feature = "unbiased"))] #[inline] - fn sample_single_inclusive(low_b: B1, high_b: B2, rng: &mut R) -> Result + fn sample_single_inclusive( + low_b: B1, high_b: B2, rng: &mut R, + ) -> Result where B1: SampleBorrow + Sized, B2: SampleBorrow + Sized, @@ -560,33 +564,72 @@ macro_rules! uniform_int_impl { if !(low <= high) { return Err(Error::EmptyRange); } - let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large; - // If the above resulted in wrap-around to 0, the range is $ty::MIN..=$ty::MAX, - // and any integer will do. + let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; if range == 0 { + // Range is MAX+1 (unrepresentable), so we need a special case return Ok(rng.gen()); } - let zone = if ::core::$unsigned::MAX <= ::core::u16::MAX as $unsigned { - // Using a modulus is faster than the approximation for - // i8 and i16. I suppose we trade the cost of one - // modulus for near-perfect branch prediction. - let unsigned_max: $u_large = ::core::$u_large::MAX; - let ints_to_reject = (unsigned_max - range + 1) % range; - unsigned_max - ints_to_reject - } else { - // conservative but fast approximation. `- 1` is necessary to allow the - // same comparison without bias. - (range << range.leading_zeros()).wrapping_sub(1) - }; + // generate a sample using a sensible integer type + let (mut result, lo_order) = rng.gen::<$sample_ty>().wmul(range); - loop { - let v: $u_large = rng.gen(); - let (hi, lo) = v.wmul(range); - if lo <= zone { - return Ok(low.wrapping_add(hi as $ty)); + // if the sample is biased... + if lo_order > range.wrapping_neg() { + // ...generate a new sample to reduce bias... 
+ let (new_hi_order, _) = (rng.gen::<$sample_ty>()).wmul(range as $sample_ty); + // ... incrementing result on overflow + let is_overflow = lo_order.checked_add(new_hi_order as $sample_ty).is_none(); + result += is_overflow as $sample_ty; + } + + Ok(low.wrapping_add(result as $ty)) + } + + /// Sample single value, Canon's method, unbiased + #[cfg(feature = "unbiased")] + #[inline] + fn sample_single_inclusive( + low_b: B1, high_b: B2, rng: &mut R, + ) -> Result + where + B1: SampleBorrow<$ty> + Sized, + B2: SampleBorrow<$ty> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; + if range == 0 { + // Range is MAX+1 (unrepresentable), so we need a special case + return Ok(rng.gen()); + } + + let (mut result, mut lo) = rng.gen::<$sample_ty>().wmul(range); + + // In contrast to the biased sampler, we use a loop: + while lo > range.wrapping_neg() { + let (new_hi, new_lo) = (rng.gen::<$sample_ty>()).wmul(range); + match lo.checked_add(new_hi) { + Some(x) if x < $sample_ty::MAX => { + // Anything less than MAX: last term is 0 + break; + } + None => { + // Overflow: last term is 1 + result += 1; + break; + } + _ => { + // Unlikely case: must check next sample + lo = new_lo; + continue; + } } } + + Ok(low.wrapping_add(result as $ty)) } } }; @@ -668,22 +711,22 @@ macro_rules! uniform_simd_int_impl { // with bitwise OR let modulo = not_full_range.select(range, unsigned_max); // wrapping addition - let ints_to_reject = (unsigned_max - range + Simd::splat(1)) % modulo; + // TODO: replace with `range.wrapping_neg() % modulo` when Simd supports this. + let ints_to_reject = (Simd::splat(0) - range) % modulo; // When `range` is 0, `lo` of `v.wmul(range)` will always be // zero which means only one sample is needed. 
- let zone = unsigned_max - ints_to_reject; Ok(UniformInt { low, // These are really $unsigned values, but store as $ty: range: range.cast(), - z: zone.cast(), + thresh: ints_to_reject.cast(), }) } fn sample(&self, rng: &mut R) -> Self::X { let range: Simd<$unsigned, LANES> = self.range.cast(); - let zone: Simd<$unsigned, LANES> = self.z.cast(); + let thresh: Simd<$unsigned, LANES> = self.thresh.cast(); // This might seem very slow, generating a whole new // SIMD vector for every sample rejection. For most uses @@ -697,7 +740,7 @@ macro_rules! uniform_simd_int_impl { let mut v: Simd<$unsigned, LANES> = rng.gen(); loop { let (hi, lo) = v.wmul(range); - let mask = lo.simd_le(zone); + let mask = lo.simd_ge(thresh); if mask.all() { let hi: Simd<$ty, LANES> = hi.cast(); // wrapping addition @@ -740,7 +783,7 @@ impl SampleUniform for char { /// are used for surrogate pairs in UCS and UTF-16, and consequently are not /// valid Unicode code points. We must therefore avoid sampling values in this /// range. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct UniformChar { sampler: UniformInt, @@ -1023,14 +1066,14 @@ uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 } /// /// Unless you are implementing [`UniformSampler`] for your own types, this type /// should not be used directly, use [`Uniform`] instead. 
-#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct UniformDuration { mode: UniformDurationMode, offset: u32, } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] enum UniformDurationMode { Small { @@ -1162,32 +1205,7 @@ mod tests { fn test_serialization_uniform_duration() { let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap(); let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); - assert_eq!( - distr.offset, de_distr.offset - ); - match (distr.mode, de_distr.mode) { - (UniformDurationMode::Small {secs: a_secs, nanos: a_nanos}, UniformDurationMode::Small {secs, nanos}) => { - assert_eq!(a_secs, secs); - - assert_eq!(a_nanos.0.low, nanos.0.low); - assert_eq!(a_nanos.0.range, nanos.0.range); - assert_eq!(a_nanos.0.z, nanos.0.z); - } - (UniformDurationMode::Medium {nanos: a_nanos} , UniformDurationMode::Medium {nanos}) => { - assert_eq!(a_nanos.0.low, nanos.0.low); - assert_eq!(a_nanos.0.range, nanos.0.range); - assert_eq!(a_nanos.0.z, nanos.0.z); - } - (UniformDurationMode::Large {max_secs:a_max_secs, max_nanos:a_max_nanos, secs:a_secs}, UniformDurationMode::Large {max_secs, max_nanos, secs} ) => { - assert_eq!(a_max_secs, max_secs); - assert_eq!(a_max_nanos, max_nanos); - - assert_eq!(a_secs.0.low, secs.0.low); - assert_eq!(a_secs.0.range, secs.0.range); - assert_eq!(a_secs.0.z, secs.0.z); - } - _ => panic!("`UniformDurationMode` was not serialized/deserialized correctly") - } + assert_eq!(distr, de_distr); } #[test] @@ -1195,16 +1213,11 @@ mod tests { fn test_uniform_serialization() { let unit_box: Uniform = Uniform::new(-1, 1).unwrap(); let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); - - assert_eq!(unit_box.0.low, de_unit_box.0.low); - 
assert_eq!(unit_box.0.range, de_unit_box.0.range); - assert_eq!(unit_box.0.z, de_unit_box.0.z); + assert_eq!(unit_box.0, de_unit_box.0); let unit_box: Uniform = Uniform::new(-1., 1.).unwrap(); let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); - - assert_eq!(unit_box.0.low, de_unit_box.0.low); - assert_eq!(unit_box.0.scale, de_unit_box.0.scale); + assert_eq!(unit_box.0, de_unit_box.0); } #[test] diff --git a/src/seq/index.rs b/src/seq/index.rs index 50523cc47c..f29f72e172 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -613,11 +613,11 @@ mod test { ); }; - do_test(10, 6, &[8, 3, 5, 9, 0, 6]); // floyd - do_test(25, 10, &[18, 14, 9, 15, 0, 13, 5, 24]); // floyd - do_test(300, 8, &[30, 283, 150, 1, 73, 13, 285, 35]); // floyd - do_test(300, 80, &[31, 289, 248, 154, 5, 78, 19, 286]); // inplace - do_test(300, 180, &[31, 289, 248, 154, 5, 78, 19, 286]); // inplace + do_test(10, 6, &[0, 9, 5, 4, 6, 8]); // floyd + do_test(25, 10, &[24, 20, 19, 9, 22, 16, 0, 14]); // floyd + do_test(300, 8, &[30, 283, 243, 150, 218, 240, 1, 189]); // floyd + do_test(300, 80, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace + do_test(300, 180, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace do_test(1_000_000, 8, &[ 103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573, diff --git a/src/seq/mod.rs b/src/seq/mod.rs index d9b38e920d..f4605b5775 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -762,7 +762,7 @@ mod test { let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; assert_eq!(chars.choose(&mut r), Some(&'l')); - assert_eq!(nums.choose_mut(&mut r), Some(&mut 10)); + assert_eq!(nums.choose_mut(&mut r), Some(&mut 3)); #[cfg(feature = "alloc")] assert_eq!( @@ -770,13 +770,13 @@ mod test { .choose_multiple(&mut r, 8) .cloned() .collect::>(), - &['d', 'm', 'n', 'k', 'h', 'e', 'b', 'c'] + &['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k'] ); #[cfg(feature = "alloc")] - assert_eq!(chars.choose_weighted(&mut r, |_| 1), 
Ok(&'f')); + assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'l')); #[cfg(feature = "alloc")] - assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 5)); + assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 8)); let mut r = crate::test::rng(414); nums.shuffle(&mut r); @@ -1221,7 +1221,7 @@ mod test { chunk_remaining: 32, hint_total_size: false, }), - Some(39) + Some(91) ); assert_eq!( choose(ChunkHintedIterator { @@ -1230,7 +1230,7 @@ mod test { chunk_remaining: 32, hint_total_size: true, }), - Some(39) + Some(91) ); assert_eq!( choose(WindowHintedIterator { @@ -1238,7 +1238,7 @@ mod test { window_size: 32, hint_total_size: false, }), - Some(90) + Some(34) ); assert_eq!( choose(WindowHintedIterator { @@ -1246,7 +1246,7 @@ mod test { window_size: 32, hint_total_size: true, }), - Some(90) + Some(34) ); } @@ -1298,28 +1298,22 @@ mod test { #[test] fn value_stability_choose_multiple() { - fn do_test>(iter: I, v: &[u32]) { + fn do_test>(iter: I, v: &[u32]) { let mut rng = crate::test::rng(412); let mut buf = [0u32; 8]; - assert_eq!(iter.choose_multiple_fill(&mut rng, &mut buf), v.len()); + assert_eq!(iter.clone().choose_multiple_fill(&mut rng, &mut buf), v.len()); assert_eq!(&buf[0..v.len()], v); - } - do_test(0..4, &[0, 1, 2, 3]); - do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); - do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); - - #[cfg(feature = "alloc")] - { - fn do_test>(iter: I, v: &[u32]) { + #[cfg(feature = "alloc")] + { let mut rng = crate::test::rng(412); assert_eq!(iter.choose_multiple(&mut rng, v.len()), v); } - - do_test(0..4, &[0, 1, 2, 3]); - do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); - do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); } + + do_test(0..4, &[0, 1, 2, 3]); + do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); + do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]); } #[test]