-
Notifications
You must be signed in to change notification settings - Fork 426
Fix Dirichlet rand overflows #1702 #1886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
5332f3a
1b39c0c
d1baaf4
735324b
1ae6210
06d8172
09842e3
f50e8c8
0bd5b5c
fbc6763
c77e35c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -156,18 +156,61 @@ end | |||||||||||||
|
||||||||||||||
function _rand!(rng::AbstractRNG, | ||||||||||||||
d::Union{Dirichlet,DirichletCanon}, | ||||||||||||||
x::AbstractVector{<:Real}) | ||||||||||||||
for (i, αi) in zip(eachindex(x), d.alpha) | ||||||||||||||
@inbounds x[i] = rand(rng, Gamma(αi)) | ||||||||||||||
x::AbstractVector{E}) where {E<:Real} | ||||||||||||||
|
||||||||||||||
if any(a -> a >= .5, d.alpha) | ||||||||||||||
# 0.5 is a placeholder; optimal value unknown | ||||||||||||||
# 1 is known to be too high. | ||||||||||||||
for (i, αi) in zip(eachindex(x), d.alpha) | ||||||||||||||
@inbounds x[i] = rand(rng, Gamma(αi)) | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
return lmul!(inv(sum(x)), x) | ||||||||||||||
else | ||||||||||||||
# Sample in log-space to lower underflow risk | ||||||||||||||
for (i, αi) in zip(eachindex(x), d.alpha) | ||||||||||||||
@inbounds x[i] = _logrand(rng, Gamma(αi)) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Conceptually, I think this should be
Suggested change
|
||||||||||||||
end | ||||||||||||||
|
||||||||||||||
if all(isinf, x) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems somewhat incorrect - you're only looking for |
||||||||||||||
# Final fallback, parameters likely deeply subnormal | ||||||||||||||
# Distribution behavior approaches categorical as Σα -> 0 | ||||||||||||||
p = copy(d.alpha) | ||||||||||||||
p .*= floatmax(eltype(p)) # rescale to non-subnormal | ||||||||||||||
x .= zero(E) | ||||||||||||||
x[rand(rng, Categorical(inv(sum(p)) .* p))] = one(E) | ||||||||||||||
Comment on lines
+178
to
+181
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This causes many allocations and won't work in general - |
||||||||||||||
return x | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
return softmax!(x) | ||||||||||||||
end | ||||||||||||||
lmul!(inv(sum(x)), x) # this returns x | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
function _rand!(rng::AbstractRNG, | ||||||||||||||
d::Dirichlet{T,<:FillArrays.AbstractFill{T}}, | ||||||||||||||
x::AbstractVector{<:Real}) where {T<:Real} | ||||||||||||||
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) | ||||||||||||||
lmul!(inv(sum(x)), x) # this returns x | ||||||||||||||
x::AbstractVector{E}) where {T<:Real, E<:Real} | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to change the function signature here (see comment below). |
||||||||||||||
|
||||||||||||||
if FillArrays.getindex_value(d.alpha) >= 0.5 | ||||||||||||||
# 0.5 is a placeholder; optimal value unknown | ||||||||||||||
# 1 is known to be too high. | ||||||||||||||
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) | ||||||||||||||
return lmul!(inv(sum(x)), x) | ||||||||||||||
else | ||||||||||||||
# Sample in log-space to lower underflow risk | ||||||||||||||
_logrand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, conceptually, this should be
Suggested change
|
||||||||||||||
|
||||||||||||||
if all(isinf, x) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same questions as above. |
||||||||||||||
# Final fallback, parameters likely deeply subnormal | ||||||||||||||
# Distribution behavior approaches categorical as Σα -> 0 | ||||||||||||||
n = length(d.alpha) | ||||||||||||||
p = Fill(inv(n), n) | ||||||||||||||
x .= zero(E) | ||||||||||||||
x[rand(rng, Categorical(p))] = one(E) | ||||||||||||||
Comment on lines
+205
to
+208
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This won't work in general -
Suggested change
(or alternatively zero(eltype(x)) and oneunit(eltype(x))). |
||||||||||||||
return x | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
return softmax!(x) | ||||||||||||||
end | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
####################################### | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# These are used to bypass subnormals when sampling from | ||
|
||
# Inverse Power sampler | ||
# uses the x*u^(1/a) trick from Marsaglia and Tsang (2000) for when shape < 1 | ||
struct ExpGammaIPSampler{S<:Sampleable{Univariate,Continuous},T<:Real} <: Sampleable{Univariate,Continuous} | ||
s::S #sampler for Gamma(1+shape,scale) | ||
nia::T #-1/scale | ||
end | ||
|
||
ExpGammaIPSampler(d::Gamma) = ExpGammaIPSampler(d, GammaMTSampler) | ||
function ExpGammaIPSampler(d::Gamma, ::Type{S}) where {S<:Sampleable} | ||
shape_d = shape(d) | ||
sampler = S(Gamma{partype(d)}(1 + shape_d, scale(d))) | ||
return GammaIPSampler(sampler, -inv(shape_d)) | ||
end | ||
|
||
function rand(rng::AbstractRNG, s::ExpGammaIPSampler) | ||
x = log(rand(rng, s.s)) | ||
e = randexp(rng, typeof(x)) | ||
return muladd(s.nia, e, x) | ||
end | ||
|
||
|
||
# Small Shape sampler | ||
# From Liu, C., Martin, R., and Syring, N. (2015) for when shape < 0.3 | ||
struct ExpGammaSSSampler{T<:Real} <: Sampleable{Univariate,Continuous} | ||
α::T | ||
θ::T | ||
λ::T | ||
ω::T | ||
ωω::T | ||
end | ||
|
||
function ExpGammaSSSampler(d::Gamma) | ||
α = shape(d) | ||
ω = α / MathConstants.e / (1 - α) | ||
return ExpGammaSSSampler(promote( | ||
α, | ||
scale(d), | ||
inv(α) - 1, | ||
ω, | ||
inv(ω + 1) | ||
)...) | ||
end | ||
|
||
function rand(rng::AbstractRNG, s::ExpGammaSSSampler{T})::Float64 where T | ||
flT = float(T) | ||
while true | ||
U = rand(rng, flT) | ||
z = (U <= s.ωω) ? -log(U / s.ωω) : log(rand(rng, flT)) / s.λ | ||
h = exp(-z - exp(-z / s.α)) | ||
η = z >= zero(T) ? exp(-z) : s.ω * s.λ * exp(s.λ * z) | ||
if h / η > rand(rng, flT) | ||
return s.θ - z / s.α | ||
end | ||
end | ||
end | ||
|
||
|
||
function _logsampler(d::Gamma) | ||
if shape(d) < 0.3 | ||
# Liu, Martin, and Syring recommend 0.3 as a cutoff to switch | ||
# to Kundu-Gupta, but we have not implemented Kundu-Gupta yet. | ||
return ExpGammaSSSampler(d) | ||
else | ||
# TODO: Kundu-Gupta algo. #3 for performance reasons? | ||
return ExpGammaIPSampler(d) | ||
end | ||
end | ||
|
||
function _logrand(rng::AbstractRNG, d::Gamma) | ||
if shape(d) < 0.3 | ||
return rand(rng, ExpGammaSSSampler(d)) | ||
else | ||
return rand(rng, ExpGammaIPSampler(d)) | ||
end | ||
end | ||
|
||
function _logrand!(rng::AbstractRNG, d::Gamma, A::AbstractArray{<:Real}) | ||
if shape(d) < 0.3 | ||
@inbounds for i in eachindex(A) | ||
A[i] = rand(rng, ExpGammaSSSampler(d)) | ||
end | ||
else | ||
@inbounds for i in eachindex(A) | ||
A[i] = rand(rng, ExpGammaIPSampler(d)) | ||
end | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why 0.5?