-
Notifications
You must be signed in to change notification settings - Fork 415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Dirichlet rand overflows #1702 #1886
base: master
Are you sure you want to change the base?
Fix Dirichlet rand overflows #1702 #1886
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1886 +/- ##
==========================================
- Coverage 85.99% 85.96% -0.03%
==========================================
Files 144 145 +1
Lines 8666 8726 +60
==========================================
+ Hits 7452 7501 +49
- Misses 1214 1225 +11 ☔ View full report in Codecov by Sentry. |
src/multivariate/dirichlet.jl
Outdated
function _rand_handle_overflow!( | ||
rng::AbstractRNG, | ||
d::Union{Dirichlet,DirichletCanon}, | ||
x::AbstractVector{<:Real} | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes the style consistent with the surrounding code:
function _rand_handle_overflow!( | |
rng::AbstractRNG, | |
d::Union{Dirichlet,DirichletCanon}, | |
x::AbstractVector{<:Real} | |
) | |
function _rand_handle_overflow!(rng::AbstractRNG, | |
d::Union{Dirichlet,DirichletCanon}, | |
x::AbstractVector{<:Real}) |
Instead of dealing with subnormals, at least for the example here sampling in log space would be sufficient (see also #1003 (comment), #1003 (comment), and #1810). For instance, with an julia> using Distributions, LogExpFunctions, Random
julia> using Distributions: GammaMTSampler
julia> # Inverse Power sampler in log-space (exp-gamma distribution)
# uses the x*u^(1/a) trick from Marsaglia and Tsang (2000) for when shape < 1
struct ExpGammaIPSampler{S<:Sampleable{Univariate,Continuous},T<:Real} <: Sampleable{Univariate,Continuous}
s::S #sampler for Gamma(1+shape,scale)
nia::T #-1/scale
end
julia> ExpGammaIPSampler(d::Gamma) = ExpGammaIPSampler(d, GammaMTSampler)
julia> function ExpGammaIPSampler(d::Gamma, ::Type{S}) where {S<:Sampleable}
shape_d = shape(d)
sampler = S(Gamma{partype(d)}(1 + shape_d, scale(d)))
return ExpGammaIPSampler(sampler, -inv(shape_d))
end
julia> function rand(rng::AbstractRNG, s::ExpGammaIPSampler)
x = log(rand(rng, s.s))
e = randexp(rng)
return muladd(s.nia, e, x)
end
julia> function myrand!(rng::AbstractRNG, d::Dirichlet, x::AbstractVector{<:Real})
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = rand(rng, ExpGammaIPSampler(Gamma(αi)))
end
return softmax!(x)
end
julia> myrand!(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]), zeros(3))
3-element Vector{Float64}:
0.6250610991638559
0.37493890083615117
0.0 |
Okay, after doing some testing, this implementation seems to be superior to what I was doing until sum(alpha) itself is subnormal enough. With your example implementation: julia> myrand!(Random.default_rng(), Dirichlet([6e-309, 5e-309, 5e-309]), zeros(3))
3-element Vector{Float64}:
1.0
0.0
0.0
julia> myrand!(Random.default_rng(), Dirichlet([5e-309, 5e-309, 5e-309]), zeros(3))
3-element Vector{Float64}:
NaN
NaN
NaN I brought in the code snippet from #1810 and that worked for a bit longer: julia> function myrand2!(rng::AbstractRNG, d::Dirichlet, x::AbstractVector{<:Real})
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = randlogGamma(αi)
end
return softmax!(x)
end
julia> myrand2!(Random.default_rng(), Dirichlet([5e-310, 5e-310, 5e-310]), zeros(3))
3-element Vector{Float64}:
0.0
1.0
0.0
julia> myrand2!(Random.default_rng(), Dirichlet([5e-311, 5e-311, 5e-311]), zeros(3))
3-element Vector{Float64}:
NaN
NaN
NaN The good news though is that there's only 1 failure mode now: when |
Co-Authored-By: David Widmann <devmotion@users.noreply.github.com>
Co-Authored-By: chelate <42802644+chelate@users.noreply.github.com>
@devmotion So this pull request's scope has gotten larger in a strange way. New Summary of changes:
What this doesn't do:
This may seem a bit backwards, but I think that can be saved for another pull request later. The goal here is to close #1702. |
Closes #1702
Core Issues
The
rand(d::Dirichlet)
callsGamma(d.α[i])
i times and writes tox
.It then rescales this result by
inv(sum(x))
. When this overflows toInf
, we run into our 2 failure modes:When all x_i == 0, we get Inf * 0 = NaN
When some x_i != 0, but are all deeply subnormal enough that
inv(sum(x))
still overflows. We get some Inf values as a result.For case 2, on Julia 1.11.0-rc1 on Windows, for example:
Fixing Case 1
If case 1 is happening, the best thing possible from a runtime perspective is probably to just choose a random x from a categorical distribution with the same mean. This is the limit behavior of the Dirichlet distribution, and my logic on why it's "safe enough" is:
There is another option where we could try rejecting all-0 samples until a certain maximum amount of samples before failing, but I think this is probably a waste of time for little gain in accuracy.
Fixing Case 2
We rescale all values by multiplying them by floatmax(), so
inv
doesn't overflow. This should work consistently for all float types wherefloatmax() * nextfloat() > floatmin()
by at least ~1 magnitudes, which I think should be true for any non-exotic float types. I originally thought it would be enough to just set the largest value to 1, but it's actually possible to currently pull multiple subnormal values pre-normalization, and the method I adopted maintains the ratio between them.Currently:
After this patch:
Subnormal Parameters
While testing, I realized that my original fix for case 1 would break when all of the parameters themselves were deeply subnormal, e.g.
Dirichlet([5e-321, 1e-321, 4e-321])
. Given that the Dirichlet distribution is decently common in things like Bayesian inference, I thought it would be worth attempting to support these cases too.Note that
mean
,var
, etc. currently break on these deeply subnormally-parameterized distributions, but fixing that felt out of scope to this pull request. Fixingmean
would be simple, but it could potentially be rather chunky. I am less sure aboutvar
and others.