Fix Dirichlet rand overflows #1702 #1886

Draft: wants to merge 11 commits into base: master


quildtide
Contributor

Closes #1702

Core Issues

rand(d::Dirichlet) draws one Gamma(d.α[i]) sample for each index i and writes the draws to x.

It then rescales x by inv(sum(x)). When inv(sum(x)) overflows to Inf, we hit one of two failure modes:

  1. When all x_i == 0, we get Inf * 0 = NaN

  2. When some x_i != 0 but all of them are so deeply subnormal that inv(sum(x)) still overflows, we get Inf for the nonzero entries (and NaN for any entries that are exactly 0).

For case 2, on Julia 1.11.0-rc1 on Windows, for example:

julia> rand(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]))
3-element Vector{Float64}:
  Inf
  Inf
 NaN
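
The arithmetic behind both failure modes can be reproduced with plain Float64 values, without sampling anything. The vectors below are hand-picked stand-ins for the Gamma draws (an assumption for illustration, not values taken from the sampler):

```julia
# Case 1: every Gamma draw underflowed to exactly zero.
x = [0.0, 0.0, 0.0]
case1 = x .* inv(sum(x))     # inv(0.0) is Inf, and 0.0 * Inf is NaN

# Case 2: some draws are nonzero but deeply subnormal.
y = [5e-324, 1e-323, 0.0]
case2 = y .* inv(sum(y))     # inv overflows to Inf: nonzero entries become Inf, the zero becomes NaN

println(case1)
println(case2)
```

The second vector reproduces the Inf, Inf, NaN pattern from the example above.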

Fixing Case 1

If case 1 is happening, the best option from a runtime perspective is probably to return a random one-hot x drawn from a categorical distribution with the same mean. This is the limit behavior of the Dirichlet distribution, and my reasoning for why it's "safe enough" is:

  • If all-zeros are a rare occurrence, this has little impact on the end sample
  • If all-zeros are common, rejecting samples and pulling another will probably yield a near-infinite reject loop. On the other hand, we're close enough to the limit behavior that floating point arithmetic errors are probably hurting us more than adopting the limit behavior.
  • While this should theoretically result in incorrect variance, testing shows that variance is within reasonable tolerance (0.01) of the real value.

There is another option where we could try rejecting all-0 samples up to a certain maximum number of attempts before failing, but I think this is probably a waste of time for little gain in accuracy.
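
A minimal sketch of the categorical fallback described above, assuming a hypothetical helper name (onehot_fallback!) that is not part of the PR. It picks index i with probability alpha[i] / sum(alpha) by walking the cumulative sum, so it never divides by sum(alpha) and does not allocate:

```julia
using Random

# Hypothetical helper: write a one-hot vector into x, choosing index i with
# probability alpha[i] / sum(alpha) (the mean of the Dirichlet distribution).
function onehot_fallback!(rng::AbstractRNG, x::AbstractVector, alpha::AbstractVector{<:Real})
    u = rand(rng) * sum(alpha)          # uniform on [0, sum(alpha))
    fill!(x, zero(eltype(x)))
    acc = zero(u)
    for (i, a) in zip(eachindex(x), alpha)
        acc += a
        if u < acc
            x[i] = oneunit(eltype(x))
            return x
        end
    end
    # Rounding can leave u == sum(alpha); fall back to the last index.
    x[lastindex(x)] = oneunit(eltype(x))
    return x
end

println(onehot_fallback!(Xoshiro(1), zeros(3), [0.0, 1.0, 0.0]))
```

Because the comparison is done on unnormalized weights, the same sketch also runs when the parameters are themselves subnormal.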

Fixing Case 2

We rescale all values by multiplying them by floatmax(), so that inv no longer overflows. This should work consistently for any float type T where floatmax(T) * nextfloat(zero(T)) exceeds floatmin(T) by at least about one order of magnitude, which I think should be true for any non-exotic float type. I originally thought it would be enough to just set the largest value to 1, but it is currently possible to pull multiple subnormal values before normalization, and the method I adopted preserves the ratios between them.

Currently:

julia> rand(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]))
3-element Vector{Float64}:
  Inf
  Inf
 NaN

After this patch:

julia> rand(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]))
3-element Vector{Float64}:
  0.625061099164708
  0.37493890083529186
  0.0
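
The rescaling idea for case 2 can be sketched as follows. The function name rescale_normalize is hypothetical and this is not the PR's implementation; it only handles case 2 (an all-zero input still yields NaN):

```julia
# Hypothetical sketch: normalize x, multiplying by floatmax first when
# inv(sum(x)) would overflow because every entry is deeply subnormal.
function rescale_normalize(x::AbstractVector{T}) where {T<:AbstractFloat}
    s = sum(x)
    if s > 0 && isinf(inv(s))
        y = x .* floatmax(T)      # lifts subnormals into the normal range, ratios preserved
        return y ./ sum(y)
    end
    return x ./ s
end

r = rescale_normalize([5e-324, 1e-323, 0.0])
println(r)
```

Multiplying every entry by the same constant leaves the ratios between entries unchanged, which is why this works where pinning only the largest value to 1 would not.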

Subnormal Parameters

While testing, I realized that my original fix for case 1 would break when all of the parameters themselves were deeply subnormal, e.g. Dirichlet([5e-321, 1e-321, 4e-321]). Given that the Dirichlet distribution is decently common in things like Bayesian inference, I thought it would be worth attempting to support these cases too.

Note that mean, var, etc. currently break on these deeply subnormally-parameterized distributions, but fixing that felt out of scope for this pull request. Fixing mean would be simple, but it could potentially be rather chunky. I am less sure about var and others.

@codecov-commenter

codecov-commenter commented Aug 16, 2024

Codecov Report

Attention: Patch coverage is 83.33333% with 11 lines in your changes missing coverage. Please review.

Project coverage is 86.17%. Comparing base (b348b5b) to head (c77e35c).

Files with missing lines    Patch %    Lines
src/samplers/expgamma.jl    72.50%     11 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1886      +/-   ##
==========================================
- Coverage   86.20%   86.17%   -0.04%     
==========================================
  Files         146      147       +1     
  Lines        8769     8829      +60     
==========================================
+ Hits         7559     7608      +49     
- Misses       1210     1221      +11     

Comment on lines 157 to 161
function _rand_handle_overflow!(
rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real}
)
Member

This makes the style consistent with the surrounding code:

Suggested change
function _rand_handle_overflow!(
rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real}
)
function _rand_handle_overflow!(rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real})

@devmotion
Member

devmotion commented Aug 17, 2024

Instead of dealing with subnormals, at least for the example here sampling in log space would be sufficient (see also #1003 (comment), #1003 (comment), and #1810). For instance, with an ExpGamma version of the Marsaglia sampler I get:

julia> using Distributions, LogExpFunctions, Random

julia> using Distributions: GammaMTSampler

julia> # Inverse Power sampler in log-space (exp-gamma distribution)
       # uses the x*u^(1/a) trick from Marsaglia and Tsang (2000) for when shape < 1
       struct ExpGammaIPSampler{S<:Sampleable{Univariate,Continuous},T<:Real} <: Sampleable{Univariate,Continuous}
           s::S #sampler for Gamma(1+shape,scale)
           nia::T # -1/shape (negative inverse of the Gamma shape)
       end

julia> ExpGammaIPSampler(d::Gamma) = ExpGammaIPSampler(d, GammaMTSampler)

julia> function ExpGammaIPSampler(d::Gamma, ::Type{S}) where {S<:Sampleable}
           shape_d = shape(d)
           sampler = S(Gamma{partype(d)}(1 + shape_d, scale(d)))
           return ExpGammaIPSampler(sampler, -inv(shape_d))
       end

julia> function Random.rand(rng::AbstractRNG, s::ExpGammaIPSampler)
           x = log(rand(rng, s.s))
           e = randexp(rng)
           return muladd(s.nia, e, x)
       end

julia> function myrand!(rng::AbstractRNG, d::Dirichlet, x::AbstractVector{<:Real})
           for (i, αi) in zip(eachindex(x), d.alpha)
               @inbounds x[i] = rand(rng, ExpGammaIPSampler(Gamma(αi)))
           end
           return softmax!(x)
       end

julia> myrand!(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]), zeros(3))
3-element Vector{Float64}:
 0.6250610991638559
 0.37493890083615117
 0.0
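
Why log space helps can be seen without any sampler. The sketch below uses fixed stand-in values (u = 0.5 and hand-picked log-samples are assumptions for illustration) and a hand-written stable_softmax in place of softmax! from LogExpFunctions:

```julia
# The u^(1/a) factor underflows in linear space for small shapes,
# while its logarithm, log(u) / a, stays comfortably finite.
a = 4.5e-5                  # shape taken from the example above
u = 0.5                     # stand-in for a uniform draw
linear = u^(1 / a)          # underflows to 0.0
logspace = log(u) / a       # large negative, but finite

# Stand-in for softmax!: subtract the maximum before exponentiating.
function stable_softmax(logx::AbstractVector{<:Real})
    m = maximum(logx)
    w = exp.(logx .- m)
    return w ./ sum(w)
end

logx = [-800.0, -801.0, -1200.0]          # hand-picked log-samples
naive = exp.(logx) ./ sum(exp.(logx))     # every exp underflows, so this is 0/0
stable = stable_softmax(logx)

println((linear, logspace))
println(naive)
println(stable)
```

The naive normalization loses everything once exp underflows, whereas the shifted version only needs the differences between log-samples to be representable.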

@quildtide
Contributor Author

For instance, with an ExpGamma version of the Marsaglia sampler I get:

Okay, after doing some testing, this implementation seems to be superior to what I was doing, up to the point where sum(alpha) itself becomes deeply subnormal.

With your example implementation:

julia> myrand!(Random.default_rng(), Dirichlet([6e-309, 5e-309, 5e-309]), zeros(3))
3-element Vector{Float64}:
 1.0
 0.0
 0.0

julia> myrand!(Random.default_rng(), Dirichlet([5e-309, 5e-309, 5e-309]), zeros(3))
3-element Vector{Float64}:
 NaN
 NaN
 NaN

I brought in the code snippet from #1810 and that worked for a bit longer:

julia> function myrand2!(rng::AbstractRNG, d::Dirichlet, x::AbstractVector{<:Real})
           for (i, αi) in zip(eachindex(x), d.alpha)
               @inbounds x[i] = randlogGamma(αi)
           end
           return softmax!(x)
       end

julia> myrand2!(Random.default_rng(), Dirichlet([5e-310, 5e-310, 5e-310]), zeros(3))
3-element Vector{Float64}:
 0.0
 1.0
 0.0

julia> myrand2!(Random.default_rng(), Dirichlet([5e-311, 5e-311, 5e-311]), zeros(3))
3-element Vector{Float64}:
 NaN
 NaN
 NaN

The good news though is that there's only one failure mode now: when every rand(ExpGamma) draw is -Inf. I'll keep an edge-case check that falls back to the Categorical sampler in that case.
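
One plausible reading of the 6e-309 versus 5e-309 results with myrand!, offered as an assumption rather than something established in this thread: the sampler stores -inv(shape), and inv itself overflows once the shape drops below roughly 5.6e-309. The values e = 0.5 and x = 0.0 below are stand-ins for the randexp and log-Gamma draws:

```julia
nia_ok  = -inv(6e-309)      # still finite, about -1.67e308
nia_bad = -inv(5e-309)      # inv overflows, so this is -Inf

e = 0.5                     # stand-in for randexp(rng)
x = 0.0                     # stand-in for log(rand(rng, s.s))

finite_draw = muladd(nia_ok, e, x)     # finite
late_over   = muladd(nia_ok, 2.0, x)   # overflows to -Inf for a larger e
always_inf  = muladd(nia_bad, e, x)    # -Inf regardless of e > 0

# With every log-sample at -Inf, subtracting the maximum gives -Inf - (-Inf) = NaN.
all_lost = [-Inf, -Inf, -Inf] .- (-Inf)

println((finite_draw, late_over, always_inf))
println(all_lost)
```

This only speaks to the first pair of results; the later failure of myrand2! at 5e-311 goes through different code and is not explained by this sketch.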

@quildtide quildtide marked this pull request as draft August 17, 2024 05:34
quildtide and others added 5 commits September 4, 2024 00:41
Co-Authored-By: David Widmann <devmotion@users.noreply.github.com>
Co-Authored-By: chelate <42802644+chelate@users.noreply.github.com>
@quildtide quildtide marked this pull request as ready for review September 4, 2024 06:45
@quildtide
Contributor Author

quildtide commented Sep 4, 2024

@devmotion So this pull request's scope has gotten larger in a strange way.

New Summary of changes:

  • Implement ExpGammaIPSampler (based off of your code above)
  • Implement ExpGammaSSSampler (based off of random log-gamma for taming underflow issues #1810, with some improvements)
  • Implement _logsampler, _logrand, and _logrand! on Gamma for these
  • Dirichlet rand now has the following cases:
    • If any alpha is >= 0.5, do what we were doing before
      • I also tried to set this cutoff at 1, but this caused multiple DirichletMultinomial tests to error for reasons I do not yet have an explanation for.
    • Else, try to sample via _logrand
      • This dispatches to ExpGammaIPSampler for alpha > 0.3
      • Else dispatches to ExpGammaSSSampler
    • If even these fail (all -Inf), use Categorical limit behavior fallback
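
The branching in the list above can be sketched as plain decision functions. All names here (dirichlet_strategy, log_sampler_choice, needs_fallback) and the returned symbols are hypothetical labels for illustration; the PR's actual functions are _logsampler, _logrand, and _logrand!:

```julia
# Which overall path a Dirichlet draw takes, given its parameters.
function dirichlet_strategy(alpha::AbstractVector{<:Real})
    # Any sufficiently large alpha keeps the existing linear-space Gamma path.
    return any(a -> a >= 0.5, alpha) ? :linear_gamma : :log_space
end

# Which log-space sampler a single shape parameter would use.
log_sampler_choice(alpha::Real) = alpha > 0.3 ? :ExpGammaIPSampler : :ExpGammaSSSampler

# Last resort: every log-sample came back as -Inf.
needs_fallback(logx::AbstractVector{<:Real}) = all(==(-Inf), logx)

println(dirichlet_strategy([0.6, 0.1]))
println(dirichlet_strategy([0.4, 0.1]))
println(log_sampler_choice.([0.4, 0.1]))
println(needs_fallback([-Inf, -Inf]))
```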

What this doesn't do:

  • Document or export ExpGammaIPSampler, ExpGammaSSSampler, or any of the _log sampling methods

This may seem a bit backwards, but I think that can be saved for another pull request later. The goal here is to close #1702.

@chelate

chelate commented Dec 18, 2024

I started writing a PR for the ExpGamma distribution and documentation. But this PR gets the Dirichlet sampling right, which is really a harder problem and much more important. I will wait for it to merge and then promise to build on it, moving the undocumented methods to an expgamma.jl univariate distribution page.

@quildtide
Contributor Author

@devmotion Could this be looked at again? Thanks.

@chelate

chelate commented Feb 5, 2025

Just wondering if there is an objection. Do we need to make expgamma.jl before this can be merged?

@ararslan
Member

I wouldn't take the lack of response as objection so much as lack of maintainer bandwidth to review and respond (certainly speaking for myself, at least). I appreciate the contribution and your patience, @quildtide.

Though I'm not currently able to provide a thoughtful review, I can say that something that will make a future reviewer's job easier would be to include comments in the code that justify the choices of 0.5 and 0.3 as cutoffs where applicable.

@ararslan ararslan requested a review from devmotion March 12, 2025 18:16
@quildtide
Contributor Author

Though I'm not currently able to provide a thoughtful review, I can say that something that will make a future reviewer's job easier would be to include comments in the code that justify the choices of 0.5 and 0.3 as cutoffs where applicable.

The 0.3 was based on the note in Liu, Martin, and Syring that their algorithm's acceptance rate is higher than that of algorithm 3 in Kundu and Gupta for values up to 0.3. I neglected to notice, however, that we do not currently have Kundu and Gupta's algorithm 3 implemented.

The 0.5 was mostly arbitrary; it was originally 1, but a test failed when it was that high.

It's possible that these cutoffs are not optimal for performance reasons; I did not have time when I made this PR to do proper performance testing. I think I may do some of that in the near future.

I am also tempted to try implementing the Kundu-Gupta sampler now, but I reckon that would only make the PR harder to review.

@quildtide
Contributor Author

I have pushed comments for now. I will do some performance testing to find potential better thresholds if I wind up having time to do so before this can be reviewed.

Member

@devmotion devmotion left a comment

My general feeling is that subnormal numbers are not a major concern in Distributions - even if there are some improvements here, I assume there are many other problems both in upstream and downstream code. Floating point numbers are inherently limited, we can only operate within their restrictions. Alternatively, you might have to switch to number types with higher or arbitrary precision.

On the other hand, I think alternative samplers and distributions such as ExpGamma that operate in log-space would be quite useful in different places (as evidenced by a few old issues I had opened a few years ago IIRC). So I think we should

  1. separate the ExpGamma part, ie, add an ExpGamma distribution + the samplers in a separate PR and make sure they are properly tested using the existing test infrastructure for distributions and samplers
  2. change this PR to use ExpGamma in Dirichlet when it's beneficial (requires numerical experiments + benchmarks)?

@inbounds x[i] = rand(rng, Gamma(αi))
x::AbstractVector{E}) where {E<:Real}

if any(a -> a >= .5, d.alpha)
Member

Why 0.5?

else
# Sample in log-space to lower underflow risk
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = _logrand(rng, Gamma(αi))
Member

Conceptually, I think this should be

Suggested change
@inbounds x[i] = _logrand(rng, Gamma(αi))
@inbounds x[i] = rand(rng, ExpGamma(αi))

@inbounds x[i] = _logrand(rng, Gamma(αi))
end

if all(isinf, x)
Member

This seems somewhat incorrect - you're only looking for -Inf but checking for Inf and -Inf. For Inf already a single value is problematic. An additional question is why isinf instead of !isfinite which would e.g. also guard against NaN.
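
The difference between the three checks mentioned here shows up on a few hand-written vectors (illustrative inputs, not sampler output):

```julia
mixed_inf = [-Inf, Inf, -Inf]    # contains a +Inf, which is a separate problem
with_nan  = [NaN, -Inf, -Inf]    # contains a NaN

# isinf accepts both signs, so a stray +Inf is treated like the -Inf case.
println(all(isinf, mixed_inf))        # true
# Comparing against -Inf only matches the underflow case.
println(all(==(-Inf), mixed_inf))     # false
# isinf lets the NaN through, so the all-check fails and no fallback triggers.
println(all(isinf, with_nan))         # false
# !isfinite catches Inf, -Inf, and NaN alike.
println(all(!isfinite, with_nan))     # true
```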

Comment on lines +178 to +181
p = copy(d.alpha)
p .*= floatmax(eltype(p)) # rescale to non-subnormal
x .= zero(E)
x[rand(rng, Categorical(inv(sum(p)) .* p))] = one(E)
Member

This causes many allocations and won't work in general - p might not be mutable.

return lmul!(inv(sum(x)), x)
else
# Sample in log-space to lower underflow risk
_logrand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
Member

Same here, conceptually, this should be

Suggested change
_logrand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
_rand!(rng, ExpGamma(FillArrays.getindex_value(d.alpha)), x)

# Sample in log-space to lower underflow risk
_logrand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)

if all(isinf, x)
Member

Same questions as above.

x::AbstractVector{<:Real}) where {T<:Real}
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
lmul!(inv(sum(x)), x) # this returns x
x::AbstractVector{E}) where {T<:Real, E<:Real}
Member

No need to change the function signature here (see comment below).

Comment on lines +205 to +208
n = length(d.alpha)
p = Fill(inv(n), n)
x .= zero(E)
x[rand(rng, Categorical(p))] = one(E)
Member

This won't work in general - x might not use one-based indexing. Moreover, there's no need to construct a Categorical distribution here:

Suggested change
n = length(d.alpha)
p = Fill(inv(n), n)
x .= zero(E)
x[rand(rng, Categorical(p))] = one(E)
fill!(x, false)
x[rand(rng, firstindex(x):lastindex(x))] = true

(or alternatively zero(eltype(x)) and oneunit(eltype(x))).
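
Wrapped in a function, the suggested change could look like the sketch below. The name uniform_onehot! is hypothetical; the body follows the suggestion, relying on false and true converting to the zero and one of x's element type:

```julia
using Random

# Hypothetical wrapper around the suggested change: overwrite x with a one-hot
# vector whose hot index is uniform over x's own index range.
function uniform_onehot!(rng::AbstractRNG, x::AbstractVector)
    fill!(x, false)
    x[rand(rng, firstindex(x):lastindex(x))] = true
    return x
end

x = uniform_onehot!(Xoshiro(7), fill(NaN, 4))
println(x)
```

Drawing from firstindex(x):lastindex(x) rather than 1:length(x) is what keeps this valid for vectors that are not one-based, and no Categorical distribution or probability vector is allocated.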

@quildtide
Contributor Author

quildtide commented Mar 28, 2025

My general feeling is that subnormal numbers are not a major concern in Distributions - even if there are some improvements here, I assume there are many other problems both in upstream and downstream code. Floating point numbers are inherently limited, we can only operate within their restrictions. Alternatively, you might have to switch to number types with higher or arbitrary precision.

I think this is a reasonable position, especially since the subnormal edge case only emerges when all alphas themselves were already deeply subnormal (after implementing log-space sampling).

On the other hand, I think alternative samplers and distributions such as ExpGamma that operate in log-space would be quite useful in different places (as evidenced by a few old issues I had opened a few years ago IIRC). So I think we should

  1. separate the ExpGamma part, ie, add an ExpGamma distribution + the samplers in a separate PR and make sure they are properly tested using the existing test infrastructure for distributions and samplers

I think @chelate was working on this. I can fork this branch to a branch with only the ExpGamma sampling so chelate can do a PR with that and their own work (testing, documentation, etc.).

  2. change this PR to use ExpGamma in Dirichlet when it's beneficial (requires numerical experiments + benchmarks)?

There are two types of testing that can be done.

We already know that there's a cutoff for alphas around 4e-8 where the current method (no log-space sampling) breaks completely, so log-space sampling is clearly beneficial around and below that point.

But determining an upper bound for the cutoff would indeed require some benchmarking.

And then there's benchmarking for when to switch between Liu-Martin-Syring and the Inverse Power sampler (the one currently at 0.3). That one might actually be the one that requires more benchmarking, since performance between the log-space and current sampler should be similar.
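
A rough back-of-the-envelope check on the low end of that cutoff, under stated assumptions: it considers only the u^(1/a) factor from the inverse-power trick quoted earlier in this thread, ignores the Gamma(1 + a) factor, and treats the smallest subnormal as the underflow boundary. The name p_underflow is hypothetical:

```julia
# Log of the smallest positive Float64 (a subnormal), about -744.44.
const LOG_TINY = log(nextfloat(0.0))

# For u ~ Uniform(0, 1): P(u^(1/a) < tiny) = P(u < tiny^a) = exp(a * log(tiny)).
p_underflow(a) = exp(a * LOG_TINY)

for a in (0.5, 0.01, 4.5e-5, 4e-8)
    println("alpha = ", a, "  P(power term underflows) ≈ ", p_underflow(a))
end
```

Under these assumptions, nearly every linear-space draw underflows at alpha around 4e-8, while the probability is negligible by alpha = 0.5, so the interesting benchmarking range lies in between.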

@quildtide quildtide marked this pull request as draft March 28, 2025 21:18

Successfully merging this pull request may close these issues.

sampling from Dirichlet produces NaN and Inf at extreme alpha
5 participants