Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rand_distr: Fix dirichlet sample method for small alpha. #1209

Merged
merged 9 commits into from
May 1, 2023

Conversation

WarrenWeckesser
Copy link
Collaborator

@WarrenWeckesser WarrenWeckesser commented Jan 2, 2022

Generating Dirichlet samples using the method based on samples from the gamma distribution can result in samples being nan if all the values in alpha are sufficiently small. The fix is to instead use the method based on the marginal distributions being the beta distribution (i.e. the "stick breaking" method) when all values in alpha are small.

More details:

Here's an example where the current method produces nans:

use rand::distributions::Distribution;
use rand_distr::Dirichlet;

fn main() {
    let n = 1000;
    println!("Checking {} Dirichlet samples", n);

    let dirichlet = Dirichlet::new(&[0.001, 0.001, 0.001]).unwrap();
    let mut r = rand::thread_rng();
    let mut nancount = 0;
    for _ in 0..n {
        let sample: Vec<f64> = dirichlet.sample(&mut r);
        if sample.iter().any(|x| x.is_nan()) {
            nancount += 1;
        }
    }
    println!("nancount: {}", nancount);
}

Typical output:

Checking 1000 Dirichlet samples
nancount: 128

The method based on samples from the gamma distribution is described in the wikipedia page on the Dirichlet distribution. The problem is that when alpha is small, there is a high probability that the gamma random variate will be 0 (given that we are limited to finite precision floating point). In fact, we can predict the above result: with alpha=0.001, the probability that a gamma variate G(alpha, 1) will be less than 2.225e-308 (approximately the smallest normal 64 bit floating point value) is approximately 0.4927. For the Dirichlet distribution with vector parameter [0.001, 0.001, 0.001], the problem occurs when all three gamma variates are 0, which has probabilty (0.4927)**3 = 0.1196. So when generating 1000 samples, we expect about 120 to contain nan.

A way to avoid this problem is to switch to the less efficient method based on the marginal beta distributions of the Dirichlet distribution. This method is also described on the wikipedia page. In this PR, this method is used when all the alpha values are less than 0.1. This threshold was discussed in the NumPy PR, where it seemed like a reasonable compromise between (i) using the gamma variate method for as wide a range as possible and (ii) ensuring that the probability of generating nans is negligibly small.

@WarrenWeckesser
Copy link
Collaborator Author

FYI: A similar change was made to numpy in numpy/numpy#14924

Generating Dirichlet samples using the method based on samples from
the gamma distribution can result in samples being nan if all the
values in alpha are sufficiently small.  The fix is to instead use
the method based on the marginal distributions being the beta
distribution (i.e. the "stick breaking" method) when all values in
alpha are small.
@dhardy
Copy link
Member

dhardy commented Jan 10, 2022

Thanks for the PR. At a glance this looks good but I'd prefer it have a proper review from someone besides the author (doesn't have to be myself or an existing contributor, if anyone is interested); I didn't find the time yet myself.

@dhardy dhardy added the D-review Do: needs review label Jan 10, 2022
@WarrenWeckesser
Copy link
Collaborator Author

For anyone who might review this: I added some more details to the description above.

rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
Copy link
Collaborator

@vks vks left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to move the initialization of the new algorithm for small alpha from sample to new. This probably requires making Dirichlet an enum similar to this:

struct DirichletFromGamma<F> { alpha: Box<[F]> }
struct DirichletFromBeta<F> { alpha: Box<[F]>, alpha_sum_r1: Box<[F]> }

pub enum Dirichlet<F> where ... {
    FromGamma(DirichletFromGamma<F>), FromBeta(DirichletFromBeta<F>)
}

Alternatively, we could use this opportunity to refactor that even more initialization can be moved out of sample by switching to something akin to the following:

struct DirichletFromGamma<F> { samplers: Box<[Gamma<F>]> }
struct DirichletFromBeta<F> { samplers: Box<[Beta<F>]> }

@vks
Copy link
Collaborator

vks commented Mar 30, 2022

The last suggestion however would increase storage by a factor of ca. 4, so it would have to be justified with benchmarks.

@WarrenWeckesser
Copy link
Collaborator Author

@vks, thanks for the review. I haven't forgotten about this! A couple other projects moved to the top of my to-do list, but I will get back to it.

@dhardy dhardy added the D-changes Do: changes requested label Nov 9, 2022
@dhardy
Copy link
Member

dhardy commented Feb 20, 2023

@WarrenWeckesser can I remind you of this PR?

@WarrenWeckesser
Copy link
Collaborator Author

@dhardy, yes you can! It has been slowly bubbling back up my "to do" list. I'll get back to it this week.

* Create a struct for each random variate generation method:
  DirichletFromGamma and DirichletFromBeta.
* Move the initialization of the data required by the generators
  into the new() method.
* Make the main Dirichlet object an enum.
@WarrenWeckesser
Copy link
Collaborator Author

I pushed an update that implements my interpretation of the comments made by @vks last year. There are now two structs that implement the two different methods, and Dirichlet is an enum with one variant for each type of struct. Dirichlet::new() returns the variant that is appropriate for the values in the input alpha. The initialization of the underlying data that is needed in the sample() method of each struct has been moved to the new() method of the struct.

This might not be what was intended, and I suspect my Rust code is not idiomatic in a lot of places, so I still view this as a rough draft that will need more iteration.

@WarrenWeckesser
Copy link
Collaborator Author

I implemented Dirichlet as an enum, and implemented new(), new_with_size() and sample() for the enum. Over in gamma.rs, I see that Gamma and ChiSquared have a pattern similar to Dirichlet, where the sampling algorithm depends on the values of the parameters. Each of those distributions is implemented as a struct containing a single field called repr, and it is repr that is the enum, rather than the distribution itself. Would that be preferable for Dirichlet?

@WarrenWeckesser
Copy link
Collaborator Author

... and it is repr that is the enum, rather than the distribution itself. Would that be preferable for Dirichlet?

Answering my own question: yes, I think making Dirichlet a struct with a field that is an enum is preferable. I'll push an update soon.

@WarrenWeckesser
Copy link
Collaborator Author

WarrenWeckesser commented Mar 1, 2023

I pushed an update to make the implementation follow the style (more or less) of Gamma and ChiSquared.

@WarrenWeckesser
Copy link
Collaborator Author

I see that #1292 is also updating the Dirichlet distribution. This PR will almost certainly have conflicts with #1292, but I can deal with that if/when #1292 is merged.

rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
/// Dirichlet distribution that generates samples using the gamma distribution.
FromGamma(DirichletFromGamma<F>),

/// Dirichlet distribution that generates samples using the beta distribution.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:

Suggested change
/// Dirichlet distribution that generates samples using the beta distribution.
/// Dirichlet distribution that generates samples using the Beta distribution.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
rand_distr/src/dirichlet.rs Outdated Show resolved Hide resolved
Copy link
Collaborator

@vks vks left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks for the update! Besides a few nits, I think we need to change the error handling to propagate Results instead of unwrapping them.

I did not review the algorithms yet.

@WarrenWeckesser
Copy link
Collaborator Author

Thanks @vks. I'm working on updating this PR with the recent changes from #1292. I'll probably incorporate the small changes that you have suggested (e.g. use of in-place operators) while doing that, and then I'll fix the error handling.

@WarrenWeckesser
Copy link
Collaborator Author

WarrenWeckesser commented Mar 23, 2023

Design question: how to deal with a limitation of const generics?

My update to this PR will have the same basic design, but the use of, for example,samplers: Box<[Gamma<F>]> in the DirichletFromGamma struct becomes samplers: [Gamma<F>; N]. The problem is that the same field in the DirichletFromBeta struct requires only N - 1 instances of Beta. I can't write samplers: [Beta<F>; N - 1] because the expression N - 1 is not currently supported. Some work-arounds:

  • The compiler tells me

     = help: const parameters may only be used as standalone arguments, i.e. `N`
     = help: use `#![feature(generic_const_exprs)]` to allow generic const expressions
    

    Would it be OK to use that feature?

  • Use a length N array, i.e. samplers: [Beta<F>; N], and fill in the last element of the array with an unused Beta instance. (This is not a very appealing option!)

  • Go back to Box<[Beta<F>]> for this struct.

Any recommendations? Any other ideas?

@WarrenWeckesser
Copy link
Collaborator Author

I'm going to go with Box<[Beta<F>]> for now.

@dhardy
Copy link
Member

dhardy commented Mar 23, 2023

Would it be OK to use that feature?

Absolutely not by default. It's an unstable feature and only available in nightly Rust. (In theory you could, behind a #[cfg(feature = "nightly")], but then you need to write and test two implementations. Once the Rust feature is stabilised the next rand_distr release can depend on it.)

Rust features can take a long time to make it into stable. I guess we'll have generic_const_exprs within a couple of years, but don't know.

Use a length N array

A few dozen wasted bytes, but it's fine.

Go back to Box<[Beta<F>]> for this struct.

Also fine. A little less memory usage, but also less data locality.

The main change is that Dirichlet now has a const trait N for
the dimension.  The other significant change is to propagate
errors that could occur when Beta:new or Gamma::new is called.
@WarrenWeckesser
Copy link
Collaborator Author

I have updated the pull request with significant changes, both to handle Dirichlet now having a const generic parameter, and to propagate errors from calling Gamma::new and Beta::new.

In DirichletFromGamma::new(alpha), my rust-fu was not strong enough to figure out how to loop through alpha to generate the Gamma instances as an array while also handling the possibility of an error from Gamma::new(). So the samplers field of DirichletFromGamma is now Box<[Gamma<F>]> instead of [Gamma<F>; N].

There are now more Error variants returned by Dirichlet::new, because of more careful validation of alpha (e.g. reject subnormal values) and to handle possible errors from Gamma::new or Beta::new.

@dhardy
Copy link
Member

dhardy commented Mar 26, 2023

my rust-fu was not strong enough to figure out how to loop through alpha to generate the Gamma instances as an array while also handling the possibility of an error from Gamma::new().

The only options I can think of are panic with catch_unwind or unsafe code (e.g. use MaybeUninit and cast). Rust doesn't have a good option here that I am aware of so keeping your boxed array is probably the best choice.

@vks
Copy link
Collaborator

vks commented Mar 28, 2023

In such cases, the easiest way is to initialize the array with some default values. We could use [Option<Gamma<F>>; N] and initialize it with [None; N], which can then be overwritten iteratively with Gamma::new().

Avoiding the indirection requires unsafe code, or a safe abstraction like arrayvec.

@WarrenWeckesser
Copy link
Collaborator Author

We could use [Option<Gamma>; N] ...

Then when samplers is used in sample(), instead of

*s = g.sample(rng);

we would need something like

*s = g.expect("This should not happen!").sample(rng)

correct? new() must never return without completely initializing samplers, so having samplers be an array of Option smells like an implementation detail of new() that shouldn't "leak" out to other code. But maybe that's fine--it is only a couple lines of code in a private struct.

Another less-than-ideal alternative is to initialize the array as [Gamma::new(1, 1); N] to keep the compiler happy, and then overwrite all the values in the array. Actually, it would look more like

[Gamma::new(F::one(), F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?; N]

Yet another alternative for using an array is to assemble the Gamma instances in a vector, and then cast the vector to an array with .try_into(). To make this work, I had to add the bound core::fmt::Debug to F. Would anyone be interested in seeing this implemented in the pull request? I don't mind trying out alternatives, if y'all have the patience to review them.

Or, I'm fine with "Just leave it as it is" or "Just do X".

@dhardy
Copy link
Member

dhardy commented Mar 29, 2023

Yet another alternative for using an array is to assemble the Gamma instances in a vector, and then cast the vector to an array with .try_into(). To make this work, I had to add the bound core::fmt::Debug to F.

I'd forgotten (or missed) this method of constructing an array: https://doc.rust-lang.org/stable/std/vec/struct.Vec.html#impl-TryFrom%3CVec%3CT%2C%20A%3E%3E-for-%5BT%3B%20N%5D
I don't see any Debug requirement. Looks viable to me?

@WarrenWeckesser
Copy link
Collaborator Author

I don't see any Debug requirement. Looks viable to me?

If I make these changes:

diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs
index 244ff1897b..17637df36e 100644
--- a/rand_distr/src/dirichlet.rs
+++ b/rand_distr/src/dirichlet.rs
@@ -26,7 +26,7 @@ where
     Exp1: Distribution<F>,
     Open01: Distribution<F>,
 {
-    samplers: Box<[Gamma<F>]>,
+    samplers: [Gamma<F>; N],
 }
 
 /// Error type returned from `DirchletFromGamma::new`.
@@ -57,7 +57,8 @@ where
             gamma_dists.push(dist);
         }
         Ok(DirichletFromGamma {
-            samplers: gamma_dists.into_boxed_slice(),
+            // By construction, the call of `.try_into()` should not return an error.
+            samplers: gamma_dists.try_into().unwrap(),
         })
     }
 }

and I run cargo test, I get this:

$ cargo test
   Compiling rand_distr v0.5.0 (/home/warren/repos/git/forks/rand/rand_distr)
error[E0277]: `F` doesn't implement `Debug`
    --> rand_distr/src/dirichlet.rs:61:46
     |
61   |             samplers: gamma_dists.try_into().unwrap(),
     |                                              ^^^^^^ `F` cannot be formatted using `{:?}` because it doesn't implement `Debug`
     |
note: required for `Gamma<F>` to implement `Debug`
    --> rand_distr/src/gamma.rs:57:23
     |
57   | #[derive(Clone, Copy, Debug, PartialEq)]
     |                       ^^^^^ unsatisfied trait bound introduced in this `derive` macro
     = note: 1 redundant requirement hidden
     = note: required for `Vec<Gamma<F>>` to implement `Debug`
note: required by a bound in `Result::<T, E>::unwrap`
    --> /home/warren/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/result.rs:1086:12
     |
1086 |         E: fmt::Debug,
     |            ^^^^^^^^^^ required by this bound in `Result::<T, E>::unwrap`
     = note: this error originates in the derive macro `Debug` (in Nightly builds, run with -Z macro-backtrace for more info)
help: consider further restricting this bound
     |
42   |     F: Float + core::fmt::Debug,
     |              ++++++++++++++++++

For more information about this error, try `rustc --explain E0277`.
error: could not compile `rand_distr` (lib) due to previous error
warning: build failed, waiting for other jobs to finish...
error[E0277]: `F` doesn't implement `Debug`
    --> rand_distr/src/dirichlet.rs:61:46
     |
61   |             samplers: gamma_dists.try_into().unwrap(),
     |                                              ^^^^^^ `F` cannot be formatted using `{:?}` because it doesn't implement `Debug`
     |
note: required for `gamma::Gamma<F>` to implement `Debug`
    --> rand_distr/src/gamma.rs:57:23
     |
57   | #[derive(Clone, Copy, Debug, PartialEq)]
     |                       ^^^^^ unsatisfied trait bound introduced in this `derive` macro
     = note: 1 redundant requirement hidden
     = note: required for `alloc::vec::Vec<gamma::Gamma<F>>` to implement `Debug`
note: required by a bound in `Result::<T, E>::unwrap`
    --> /home/warren/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/core/src/result.rs:1086:12
     |
1086 |         E: fmt::Debug,
     |            ^^^^^^^^^^ required by this bound in `Result::<T, E>::unwrap`
     = note: this error originates in the derive macro `Debug` (in Nightly builds, run with -Z macro-backtrace for more info)
help: consider further restricting this bound
     |
42   |     F: Float + core::fmt::Debug,
     |              ++++++++++++++++++

error: could not compile `rand_distr` (lib test) due to previous error

I have to add the bound core::fmt::Debug at line 42 and 260 to get cargo test to compile and run.

@WarrenWeckesser
Copy link
Collaborator Author

Of course, just minutes after I comment, I dig a bit more and realize the cause of the issue. It is .unwrap() that has the Debug bound. If I deal with the Result returned by .try_into() differently, then Debug for F is not necessary.

@WarrenWeckesser
Copy link
Collaborator Author

Yet another alternative for using an array is to assemble the Gamma instances in a vector, and then cast the vector to an array with .try_into().

I pushed an update with this change.


/// gamma_dists.try_into() failed (in theory, this should not happen).
GammaArrayCreationFailed,
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure about the error types here. Ideally, they should refer to invalid parameters supplied by the caller.

However, in this case they are only used internally, so I think this is fine.

Copy link
Collaborator

@vks vks left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for taking a while to come back to this. It looks great, thanks!

@vks vks merged commit 1464b88 into rust-random:master May 1, 2023
@WarrenWeckesser WarrenWeckesser deleted the dirichlet-small-alpha branch May 1, 2023 18:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
D-changes Do: changes requested D-review Do: needs review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants