
Fix Dirichlet rand overflows #1702 #1886


Draft · wants to merge 11 commits into master
57 changes: 50 additions & 7 deletions src/multivariate/dirichlet.jl
@@ -156,18 +156,61 @@ end

function _rand!(rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real})
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = rand(rng, Gamma(αi))
x::AbstractVector{E}) where {E<:Real}

if any(a -> a >= .5, d.alpha)
Review comment (Member): Why 0.5?

# 0.5 is a placeholder; optimal value unknown
# 1 is known to be too high.
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = rand(rng, Gamma(αi))
end

return lmul!(inv(sum(x)), x)
else
# Sample in log-space to lower underflow risk
for (i, αi) in zip(eachindex(x), d.alpha)
@inbounds x[i] = _logrand(rng, Gamma(αi))
Review comment (Member): Conceptually, I think this should be

Suggested change:
- @inbounds x[i] = _logrand(rng, Gamma(αi))
+ @inbounds x[i] = rand(rng, ExpGamma(αi))

(A hypothetical sketch of such an ExpGamma wrapper appears at the end of src/samplers/expgamma.jl below.)

end

if all(isinf, x)
Review comment (Member): This seems somewhat incorrect - you're only looking for -Inf but checking for Inf and -Inf. For Inf, already a single value is problematic. An additional question is why isinf instead of !isfinite, which would e.g. also guard against NaN.

# Final fallback, parameters likely deeply subnormal
# Distribution behavior approaches categorical as Σα -> 0
p = copy(d.alpha)
p .*= floatmax(eltype(p)) # rescale to non-subnormal
x .= zero(E)
x[rand(rng, Categorical(inv(sum(p)) .* p))] = one(E)
Review comment (Member) on lines +178 to +181: This causes many allocations and won't work in general - p might not be mutable. (A hedged allocation-free alternative is sketched after this file's diff.)

return x
end

return softmax!(x)
end
lmul!(inv(sum(x)), x) # this returns x
end
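Aside (editor's sketch, not part of the diff): the two branches above exist because, for tiny shape parameters, every direct Gamma draw underflows to 0.0 and the normalization produces NaN, whereas softmax of log-space draws is unaffected by the common scale that underflows. The seed, the explicit max-subtraction standing in for softmax!, and the use of Random/Distributions here are assumptions for illustration only.

```julia
# Minimal sketch: naive Gamma normalization vs. log-space draws for tiny α.
using Distributions, Random

rng = MersenneTwister(1)
α = fill(1e-5, 3)

g = [rand(rng, Gamma(a)) for a in α]        # very likely all exactly 0.0
naive = g ./ sum(g)                          # NaN whenever sum(g) == 0

# Log-space draw via the Marsaglia–Tsang boost: if X ~ Gamma(1 + a) and E ~ Exp(1),
# then log(X) - E/a is distributed as the log of a Gamma(a) draw.
lg = [log(rand(rng, Gamma(1 + a))) - randexp(rng) / a for a in α]
m = maximum(lg)
safe = exp.(lg .- m) ./ sum(exp.(lg .- m))   # softmax; finite and sums to 1
```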

function _rand!(rng::AbstractRNG,
d::Dirichlet{T,<:FillArrays.AbstractFill{T}},
x::AbstractVector{<:Real}) where {T<:Real}
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
lmul!(inv(sum(x)), x) # this returns x
x::AbstractVector{E}) where {T<:Real, E<:Real}
Review comment (Member): No need to change the function signature here (see comment below).

if FillArrays.getindex_value(d.alpha) >= 0.5
# 0.5 is a placeholder; optimal value unknown
# 1 is known to be too high.
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
return lmul!(inv(sum(x)), x)
else
# Sample in log-space to lower underflow risk
_logrand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
Review comment (Member): Same here, conceptually, this should be

Suggested change:
- _logrand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
+ _rand!(rng, ExpGamma(FillArrays.getindex_value(d.alpha)), x)

if all(isinf, x)
Review comment (Member): Same questions as above.

# Final fallback, parameters likely deeply subnormal
# Distribution behavior approaches categorical as Σα -> 0
n = length(d.alpha)
p = Fill(inv(n), n)
x .= zero(E)
x[rand(rng, Categorical(p))] = one(E)
Review comment (Member) on lines +205 to +208: This won't work in general - x might not use one-based indexing. Moreover, there's no need to construct a Categorical distribution here:

Suggested change:
- n = length(d.alpha)
- p = Fill(inv(n), n)
- x .= zero(E)
- x[rand(rng, Categorical(p))] = one(E)
+ fill!(x, false)
+ x[rand(rng, firstindex(x):lastindex(x))] = true

(or alternatively zero(eltype(x)) and oneunit(eltype(x)))

return x
end

return softmax!(x)
end
end

#######################################
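Regarding the review comments on the categorical fallback above (allocations, mutability of alpha, and one-based indexing), one allocation-free alternative is sketched below. This is an editor's sketch, not PR code; the helper name `_categorical_fallback!` and the rescaling constant are assumptions, and it presumes all entries of alpha are tiny (the only case that reaches the fallback), so the floatmax rescale cannot overflow.

```julia
# Hedged sketch: select a single index with probability proportional to alpha,
# rescaled out of the subnormal range, without copying alpha or constructing a
# Categorical distribution.
using Random

function _categorical_fallback!(rng::AbstractRNG, x::AbstractVector, alpha)
    c = floatmax(eltype(alpha))              # lift subnormals into the normal range
    total = sum(a -> a * c, alpha)
    u = rand(rng) * total
    fill!(x, zero(eltype(x)))
    acc = zero(total)
    @inbounds for (i, a) in zip(eachindex(x), alpha)
        acc += a * c
        if u <= acc || i == lastindex(x)     # lastindex guard absorbs rounding differences
            x[i] = oneunit(eltype(x))
            break
        end
    end
    return x
end
```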
1 change: 1 addition & 0 deletions src/samplers.jl
@@ -20,6 +20,7 @@ for fname in ["aliastable.jl",
"poisson.jl",
"exponential.jl",
"gamma.jl",
"expgamma.jl",
"multinomial.jl",
"vonmises.jl",
"vonmisesfisher.jl",
89 changes: 89 additions & 0 deletions src/samplers/expgamma.jl
@@ -0,0 +1,89 @@
# These are used to bypass subnormals when sampling from Gamma distributions with very small shape parameters.

# Inverse Power sampler
# uses the x*u^(1/a) trick from Marsaglia and Tsang (2000), in log-space, for shape < 1
struct ExpGammaIPSampler{S<:Sampleable{Univariate,Continuous},T<:Real} <: Sampleable{Univariate,Continuous}
s::S #sampler for Gamma(1+shape,scale)
nia::T # -1/shape
end

ExpGammaIPSampler(d::Gamma) = ExpGammaIPSampler(d, GammaMTSampler)
function ExpGammaIPSampler(d::Gamma, ::Type{S}) where {S<:Sampleable}
shape_d = shape(d)
sampler = S(Gamma{partype(d)}(1 + shape_d, scale(d)))
return ExpGammaIPSampler(sampler, -inv(shape_d))
end

function rand(rng::AbstractRNG, s::ExpGammaIPSampler)
x = log(rand(rng, s.s))
e = randexp(rng, typeof(x))
return muladd(s.nia, e, x)
end
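A quick way to sanity-check the inverse-power sampler (editor's sketch, not part of the diff): since it returns log-space values, its sample mean should match the mean of log of ordinary Gamma draws. The seed, sample size, and tolerance below are arbitrary assumptions, and the internal sampler type is assumed to be in scope.

```julia
# Sketch: compare the log-space IP sampler against log.(rand(Gamma(...))) draws.
using Distributions, Random, Statistics

rng = MersenneTwister(42)
d = Gamma(0.4)                                 # shape >= 0.3, the IP path in _logrand
s = ExpGammaIPSampler(d)                       # internal sampler defined above
logx = [rand(rng, s) for _ in 1:10^6]
direct = log.(rand(rng, d, 10^6))
isapprox(mean(logx), mean(direct); atol=0.05)  # expected to hold with high probability
```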


# Small Shape sampler
# From Liu, C., Martin, R., and Syring, N. (2015), for shape < 0.3
struct ExpGammaSSSampler{T<:Real} <: Sampleable{Univariate,Continuous}
α::T
θ::T
λ::T
ω::T
ωω::T
end

function ExpGammaSSSampler(d::Gamma)
α = shape(d)
ω = α / MathConstants.e / (1 - α)
return ExpGammaSSSampler(promote(
α,
scale(d),
inv(α) - 1,
ω,
inv(ω + 1)
)...)
end

function rand(rng::AbstractRNG, s::ExpGammaSSSampler{T})::Float64 where T
flT = float(T)
while true
U = rand(rng, flT)
z = (U <= s.ωω) ? -log(U / s.ωω) : log(rand(rng, flT)) / s.λ
h = exp(-z - exp(-z / s.α))
η = z >= zero(T) ? exp(-z) : s.ω * s.λ * exp(s.λ * z)
if h / η > rand(rng, flT)
return log(s.θ) - z / s.α # log(scale) plus the log-space Gamma(α, 1) draw -z/α
end
end
end
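Similarly, a hedged check for the small-shape sampler (editor's sketch, not part of the diff): for X ~ Gamma(α, 1), E[log X] = digamma(α), which gives a direct target for the sampler's mean even when exp of the draws would underflow. The seed, sample size, and tolerance are arbitrary assumptions.

```julia
# Sketch: the small-shape sampler's mean should be near digamma(α) for Gamma(α, 1).
using Distributions, Random, SpecialFunctions, Statistics

rng = MersenneTwister(7)
s = ExpGammaSSSampler(Gamma(0.1))              # shape < 0.3 path
draws = [rand(rng, s) for _ in 1:10^6]
abs(mean(draws) - digamma(0.1)) < 0.05         # expected to hold with high probability
```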


function _logsampler(d::Gamma)
if shape(d) < 0.3
# Liu, Martin, and Syring recommend 0.3 as a cutoff to switch
# to Kundu-Gupta, but we have not implemented Kundu-Gupta yet.
return ExpGammaSSSampler(d)
else
# TODO: Kundu-Gupta algo. #3 for performance reasons?
return ExpGammaIPSampler(d)
end
end
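For repeated draws with the same shape, `_logsampler` mirrors the package's `sampler(d)` convention of constructing the sampler once and reusing it. A small usage sketch (editor's illustration; the helper is internal):

```julia
# Sketch: build the log-space sampler once and reuse it for many draws.
using Distributions, Random

rng = MersenneTwister(0)
s = _logsampler(Gamma(0.05))                   # shape < 0.3, so an ExpGammaSSSampler
logdraws = [rand(rng, s) for _ in 1:1_000]
```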

function _logrand(rng::AbstractRNG, d::Gamma)
if shape(d) < 0.3
return rand(rng, ExpGammaSSSampler(d))
else
return rand(rng, ExpGammaIPSampler(d))
end
end

function _logrand!(rng::AbstractRNG, d::Gamma, A::AbstractArray{<:Real})
if shape(d) < 0.3
@inbounds for i in eachindex(A)
A[i] = rand(rng, ExpGammaSSSampler(d))
end
else
@inbounds for i in eachindex(A)
A[i] = rand(rng, ExpGammaIPSampler(d))
end
end
end
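The review comments above suggest exposing this via `rand(rng, ExpGamma(α))`. No `ExpGamma` distribution exists in the package; purely as a hypothetical illustration of that suggestion, a wrapper might look like the sketch below (the name, fields, and the 0.3 dispatch threshold are assumptions carried over from `_logrand`).

```julia
# Hypothetical sketch of the reviewer's ExpGamma idea: the distribution of log(X)
# for X ~ Gamma(α, θ), delegating to the log-space samplers defined above.
using Distributions, Random

struct ExpGamma{T<:Real} <: ContinuousUnivariateDistribution
    α::T   # shape of the underlying Gamma
    θ::T   # scale of the underlying Gamma
end
ExpGamma(α::Real) = ExpGamma(promote(α, one(α))...)

function Base.rand(rng::AbstractRNG, d::ExpGamma)
    g = Gamma(d.α, d.θ)
    return d.α < 0.3 ? rand(rng, ExpGammaSSSampler(g)) : rand(rng, ExpGammaIPSampler(g))
end
```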
26 changes: 26 additions & 0 deletions test/multivariate/dirichlet.jl
@@ -158,3 +158,29 @@ end
end
end
end

@testset "Dirichlet rand Inf and NaN (#1702)" begin
for d in [
Dirichlet([8e-5, 1e-5, 2e-5]),
Dirichlet([8e-4, 1e-4, 2e-4]),
Dirichlet([4.5e-5, 8e-5]),
Dirichlet([6e-5, 2e-5, 3e-5, 4e-5, 5e-5]),
Dirichlet(FillArrays.Fill(1e-5, 5))
]
x = rand(d, 10^6)
@test mean(x, dims = 2) ≈ mean(d) atol=0.01
@test var(x, dims = 2) ≈ var(d) atol=0.01
end

for (d, μ) in [ # Subnormal params cause mean(d) to error

(Dirichlet([5e-310, 5e-310, 5e-310]), [1/3, 1/3, 1/3]),
(Dirichlet(FillArrays.Fill(5e-310, 3)), [1/3, 1/3, 1/3]),
(Dirichlet([5e-321, 1e-321, 4e-321]), [.5, .1, .4]),
(Dirichlet([1e-321, 2e-321, 3e-321, 4e-321]), [.1, .2, .3, .4]),
(Dirichlet(FillArrays.Fill(1e-321, 4)), [.25, .25, .25, .25])
]
x = rand(d, 10^6)
@test mean(x, dims = 2) ≈ μ atol=0.01
end
end
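For reference, a minimal reproducer of the failure mode from #1702 that these tests target (editor's sketch, not the exact example from the issue): on the pre-fix code, concentrations this small typically yield NaN entries because every Gamma draw underflows to zero.

```julia
# Sketch: tiny concentration parameters that trigger the underflow this PR fixes.
using Distributions, Random

rng = MersenneTwister(1)
x = rand(rng, Dirichlet([1e-5, 1e-5, 1e-5]))
all(isfinite, x) && sum(x) ≈ 1                 # should hold with this PR applied
```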