Skip to content

Commit

Permalink
fix type stability of sampling from Chisq, TDist, Gamma (#1885)
Browse files Browse the repository at this point in the history
* fix type stability of sampling from `Chisq`, `TDist`, `Gamma`

* fix remove type specification in `rand(Exponential)`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* fix type specificaton in `rand(TDist)`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* fix remove type test for `rand(Chisq)`

* fix make `Exponential` use the `Normal` sampling type policy

* fix missing type signature

* fix type signature for `rand(Exponential)`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* fix use `@inferred` in tests for `Gamma`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* fix use `@inferred` in tests for `TDist`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* add type stability tests for `rand(Exponential)`

* add type stability test for `rand(Chisq)`

* fix remove type stability test for `entropy(TDist)` (not stable)

---------

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
Red-Portal and devmotion authored Aug 23, 2024
1 parent 13029c0 commit 3946acc
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/samplers/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,6 @@ end

function rand(rng::AbstractRNG, s::GammaIPSampler)
x = rand(rng, s.s)
e = randexp(rng)
e = randexp(rng, typeof(x))
x*exp(s.nia*e)
end
2 changes: 1 addition & 1 deletion src/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ cf(d::Exponential, t::Real) = 1/(1 - t * im * scale(d))


#### Sampling
rand(rng::AbstractRNG, d::Exponential) = xval(d, randexp(rng))
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, float(T)))


#### Fit model
Expand Down
2 changes: 1 addition & 1 deletion src/univariate/continuous/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
function rand(rng::AbstractRNG, d::TDist)
ν = d.ν
z = sqrt(rand(rng, Chisq{typeof(ν)}(ν)) / ν)
return randn(rng) / (isinf(ν) ? one(z) : z)
return randn(rng, typeof(z)) / (isinf(ν) ? one(z) : z)
end

function cf(d::TDist{T}, t::Real) where T <: Real
Expand Down
11 changes: 9 additions & 2 deletions test/univariate/continuous/chisq.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
test_cgf(Chisq(1), (0.49, -1, -100, -1f6))
test_cgf(Chisq(3), (0.49, -1, -100, -1f6))

@testset "Chisq" begin
test_cgf(Chisq(1), (0.49, -1, -100, -1.0f6))
test_cgf(Chisq(3), (0.49, -1, -100, -1.0f6))

for T in (Float32, Float64)
@test @inferred(rand(Chisq(T(1)))) isa T
end
end
12 changes: 9 additions & 3 deletions test/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@

test_cgf(Exponential(1), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(0.91), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(10 ), (0.08, -1, -100f0, -1e6))
@testset "Exponential" begin
test_cgf(Exponential(1), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(0.91), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(10 ), (0.08, -1, -100f0, -1e6))

for T in (Float32, Float64)
@test @inferred(rand(Exponential(T(1)))) isa T
end
end
38 changes: 23 additions & 15 deletions test/univariate/continuous/gamma.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
using Test, Distributions, OffsetArrays

test_cgf(Gamma(1 ,1 ), (0.9, -1, -100f0, -1e6))
test_cgf(Gamma(10 ,1 ), (0.9, -1, -100f0, -1e6))
test_cgf(Gamma(0.2, 10), (0.08, -1, -100f0, -1e6))
@testset "Gamma" begin
test_cgf(Gamma(1, 1), (0.9, -1, -100.0f0, -1e6))
test_cgf(Gamma(10, 1), (0.9, -1, -100.0f0, -1e6))
test_cgf(Gamma(0.2, 10), (0.08, -1, -100.0f0, -1e6))

@testset "Gamma suffstats and OffsetArrays" begin
a = rand(Gamma(), 11)
wa = 1.0:11.0
@testset "Gamma suffstats and OffsetArrays" begin
a = rand(Gamma(), 11)
wa = 1.0:11.0

resulta = @inferred(suffstats(Gamma, a))
resulta = @inferred(suffstats(Gamma, a))

resultwa = @inferred(suffstats(Gamma, a, wa))
resultwa = @inferred(suffstats(Gamma, a, wa))

b = OffsetArray(a, -5:5)
wb = OffsetArray(wa, -5:5)
b = OffsetArray(a, -5:5)
wb = OffsetArray(wa, -5:5)

resultb = @inferred(suffstats(Gamma, b))
@test resulta == resultb
resultb = @inferred(suffstats(Gamma, b))
@test resulta == resultb

resultwb = @inferred(suffstats(Gamma, b, wb))
@test resultwa == resultwb
resultwb = @inferred(suffstats(Gamma, b, wb))
@test resultwa == resultwb

@test_throws DimensionMismatch suffstats(Gamma, a, wb)
@test_throws DimensionMismatch suffstats(Gamma, a, wb)
end

for T in (Float32, Float64)
@test @inferred(rand(Gamma(T(1), T(1)))) isa T
@test @inferred(rand(Gamma(1/T(2), T(1)))) isa T
@test @inferred(rand(Gamma(T(2), T(1)))) isa T
end
end
17 changes: 12 additions & 5 deletions test/univariate/continuous/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@ using ForwardDiff

using Test

@testset "Type stability of `rand` (#1614)" begin
if VERSION >= v"1.9.0-DEV.348"
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
@inferred(rand(TDist(big"1.0")))
@testset "TDist" begin
@testset "Type stability of `rand` (#1614)" begin
if VERSION >= v"1.9.0-DEV.348"
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
@inferred(rand(TDist(big"1.0")))
end
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))

end

for T in (Float32, Float64)
@test @inferred(rand(TDist(T(1)))) isa T
end
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))
end

0 comments on commit 3946acc

Please sign in to comment.