Skip to content

Commit

Permalink
Specialized vector rand! for many distributions (#1879)
Browse files Browse the repository at this point in the history
* Test scalar rand separately from vector rand

* Add specialized rand! for many distributions

* Restore location of old NormalInverseGaussian tests

* Remove duplication of inversegaussian in runtests.jl

* Apply many suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Apply other suggestions

* Remove redundant new tests

* Clean up more

* Partially undo previous undo to changes to tests

* Use xval for NormalCanon rand

* Apply suggestions from code review

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Apply other recommendations to testutils

* Fix erroneous !

* Address reviewer comments

* `mean` not defined for `LogitNormal`

* Copy RNG with `copy`, not `deepcopy`

---------

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
quildtide and devmotion authored Sep 25, 2024
1 parent b219803 commit 08c56ea
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 21 deletions.
5 changes: 5 additions & 0 deletions src/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ cf(d::Exponential, t::Real) = 1/(1 - t * im * scale(d))
#### Sampling
rand(rng::AbstractRNG, d::Exponential{T}) where {T} = xval(d, randexp(rng, float(T)))

function rand!(rng::AbstractRNG, d::Exponential, A::AbstractArray{<:Real})
randexp!(rng, A)
map!(Base.Fix1(xval, d), A, A)
return A
end

#### Fit model

Expand Down
9 changes: 8 additions & 1 deletion src/univariate/continuous/logitnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,14 @@ end

#### Sampling

rand(rng::AbstractRNG, d::LogitNormal) = logistic(randn(rng) * d.σ + d.μ)
xval(d::LogitNormal, z::Real) = logistic(muladd(d.σ, z, d.μ))

rand(rng::AbstractRNG, d::LogitNormal) = xval(d, randn(rng))
function rand!(rng::AbstractRNG, d::LogitNormal, A::AbstractArray{<:Real})
randn!(rng, A)
map!(Base.Fix1(xval, d), A, A)
return A
end

## Fitting

Expand Down
9 changes: 8 additions & 1 deletion src/univariate/continuous/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,14 @@ end

#### Sampling

rand(rng::AbstractRNG, d::LogNormal) = exp(randn(rng) * d.σ + d.μ)
xval(d::LogNormal, z::Real) = exp(muladd(d.σ, z, d.μ))

rand(rng::AbstractRNG, d::LogNormal) = xval(d, randn(rng))
function rand!(rng::AbstractRNG, d::LogNormal, A::AbstractArray{<:Real})
randn!(rng, A)
map!(Base.Fix1(xval, d), A, A)
return A
end

## Fitting

Expand Down
9 changes: 7 additions & 2 deletions src/univariate/continuous/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,14 @@ Base.:*(c::Real, d::Normal) = Normal(c * d.μ, abs(c) * d.σ)

#### Sampling

rand(rng::AbstractRNG, d::Normal{T}) where {T} = d.μ + d.σ * randn(rng, float(T))
xval(d::Normal, z::Real) = muladd(d.σ, z, d.μ)

rand!(rng::AbstractRNG, d::Normal, A::AbstractArray{<:Real}) = A .= muladd.(d.σ, randn!(rng, A), d.μ)
rand(rng::AbstractRNG, d::Normal{T}) where {T} = xval(d, randn(rng, float(T)))
function rand!(rng::AbstractRNG, d::Normal, A::AbstractArray{<:Real})
randn!(rng, A)
map!(Base.Fix1(xval, d), A, A)
return A
end

#### Fitting

Expand Down
8 changes: 7 additions & 1 deletion src/univariate/continuous/normalcanon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ invlogccdf(d::NormalCanon, lp::Real) = xval(d, norminvlogccdf(lp))

#### Sampling

rand(rng::AbstractRNG, cf::NormalCanon) = cf.μ + randn(rng) / sqrt(cf.λ)
rand(rng::AbstractRNG, cf::NormalCanon) = xval(cf, randn(rng))

function rand!(rng::AbstractRNG, cf::NormalCanon, A::AbstractArray{<:Real})
randn!(rng, A)
map!(Base.Fix1(xval, cf), A, A)
return A
end

#### Affine transformations

Expand Down
9 changes: 8 additions & 1 deletion src/univariate/continuous/pareto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,14 @@ quantile(d::Pareto, p::Real) = cquantile(d, 1 - p)

#### Sampling

rand(rng::AbstractRNG, d::Pareto) = d.θ * exp(randexp(rng) / d.α)
xval(d::Pareto, z::Real) = d.θ * exp(z / d.α)

rand(rng::AbstractRNG, d::Pareto) = xval(d, randexp(rng))
function rand!(rng::AbstractRNG, d::Pareto, A::AbstractArray{<:Real})
randexp!(rng, A)
map!(Base.Fix1(xval, d), A, A)
return A
end

## Fitting

Expand Down
2 changes: 1 addition & 1 deletion src/univariate/continuous/pgeneralizedgaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ function rand(rng::AbstractRNG, d::PGeneralizedGaussian)
inv_p = inv(d.p)
g = Gamma(inv_p, 1)
z = d.α * rand(rng, g)^inv_p
if rand(rng) < 0.5
if rand(rng, Bool)
return d.μ - z
else
return d.μ + z
Expand Down
2 changes: 1 addition & 1 deletion test/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ end
for func in funcs, dist in (Laplace, Laplace{Float64})
d = fit(dist, func[2](dist(5.0, 3.0), N + 1))
@test isa(d, dist)
@test isapprox(location(d), 5.0, atol=0.02)
@test isapprox(location(d), 5.0, atol=0.03)
@test isapprox(scale(d) , 3.0, atol=0.03)
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/multivariate/mvlognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ end
@test entropy(l1) entropy(l2)
@test logpdf(l1,5.0) logpdf(l2,[5.0])
@test pdf(l1,5.0) pdf(l2,[5.0])
@test (Random.seed!(78393) ; [rand(l1)]) == (Random.seed!(78393) ; rand(l2))
@test [rand(MersenneTwister(78393), l1)] == rand(MersenneTwister(78393), l2)
@test (Random.seed!(78393) ; [rand(l1)]) (Random.seed!(78393) ; rand(l2))
@test [rand(MersenneTwister(78393), l1)] rand(MersenneTwister(78393), l2)
end

###### General Testing
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const tests = [
"truncated/discrete_uniform",
"censored",
"univariate/continuous/normal",
"univariate/continuous/normalcanon",
"univariate/continuous/laplace",
"univariate/continuous/cauchy",
"univariate/continuous/uniform",
Expand Down Expand Up @@ -83,6 +84,7 @@ const tests = [
"univariate/continuous/noncentralchisq",
"univariate/continuous/weibull",
"pdfnorm",
"univariate/continuous/pareto",
"univariate/continuous/rician",
"functionals",
"density_interface",
Expand Down Expand Up @@ -143,9 +145,7 @@ const tests = [
# "univariate/continuous/levy",
# "univariate/continuous/noncentralbeta",
# "univariate/continuous/noncentralf",
# "univariate/continuous/normalcanon",
# "univariate/continuous/normalinversegaussian",
# "univariate/continuous/pareto",
# "univariate/continuous/rayleigh",
# "univariate/continuous/studentizedrange",
# "univariate/continuous/symtriangular",
Expand Down
61 changes: 53 additions & 8 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end
# testing the implementation of a discrete univariate distribution
#
function test_distr(distr::DiscreteUnivariateDistribution, n::Int;
testquan::Bool=true)
testquan::Bool=true, rng::AbstractRNG = Random.default_rng())

test_range(distr)
vs = get_evalsamples(distr, 0.00001)
Expand All @@ -40,7 +40,7 @@ function test_distr(distr::DiscreteUnivariateDistribution, n::Int;

test_stats(distr, vs)
test_samples(distr, n)
test_samples(distr, n, rng=MersenneTwister())
test_samples(distr, n; rng=rng)
test_params(distr)
end

Expand Down Expand Up @@ -150,31 +150,55 @@ function test_samples(s::Sampleable{Univariate, Discrete}, # the sampleable
samples = rand(s, n)
Random.seed!(1234)
samples2 = rand(s, n)
Random.seed!(1234)
samples3 = [rand(s) for _ in 1:n]
Random.seed!(1234)
samples4 = [rand(s) for _ in 1:n]
else
rng2 = deepcopy(rng)
# RNGs have to be copied with `copy`, not `deepcopy`
# Ref https://github.com/JuliaLang/julia/issues/42899
rng2 = copy(rng)
rng3 = copy(rng)
rng4 = copy(rng)
samples = rand(rng, s, n)
samples2 = rand(rng2, s, n)
samples3 = [rand(rng3, s) for _ in 1:n]
samples4 = [rand(rng4, s) for _ in 1:n]
end
@test length(samples) == n
@test samples2 == samples
@test samples3 == samples4

# scan samples and get counts
cnts = zeros(Int, m)
cnts_sc = zeros(Int, m)
for i = 1:n
@inbounds si = samples[i]
if rmin <= si <= rmax
cnts[si - rmin + 1] += 1
else
vmin <= si <= vmax ||
error("Sample value out of valid range.")
throw(DomainError(si, "sample generated by `rand(s, n)` is out of valid range [$vmin, $vmax]."))
end

@inbounds si_sc = samples3[i]
if rmin <= si_sc <= rmax
cnts_sc[si_sc - rmin + 1] += 1
else
vmin <= si_sc <= vmax ||
throw(DomainError(si, "sample generated by `[rand(s) for _ in 1:n]` is out of valid range [$vmin, $vmax]."))
end
end

# check the counts
for i = 1:m
verbose && println("v = $(rmin+i-1) ==> ($(clb[i]), $(cub[i])): $(cnts[i])")
clb[i] <= cnts[i] <= cub[i] ||
error("The counts are out of the confidence interval.")
error("The counts of samples generated by `rand(s, n)` are out of the confidence interval.")

verbose && println("v = $(rmin+i-1) ==> ($(clb[i]), $(cub[i])): $(cnts_sc[i])")
clb[i] <= cnts_sc[i] <= cub[i] ||
error("The counts of samples generated by `[rand(s) for _ in 1:n]` are out of the confidence interval.")
end
return samples
end
Expand Down Expand Up @@ -250,13 +274,24 @@ function test_samples(s::Sampleable{Univariate, Continuous}, # the sampleable
samples = rand(s, n)
Random.seed!(1234)
samples2 = rand(s, n)
Random.seed!(1234)
samples3 = [rand(s) for _ in 1:n]
Random.seed!(1234)
samples4 = [rand(s) for _ in 1:n]
else
rng2 = deepcopy(rng)
# RNGs have to be copied with `copy`, not `deepcopy`
# Ref https://github.com/JuliaLang/julia/issues/42899
rng2 = copy(rng)
rng3 = copy(rng)
rng4 = copy(rng)
samples = rand(rng, s, n)
samples2 = rand(rng2, s, n)
samples3 = [rand(rng3, s) for _ in 1:n]
samples4 = [rand(rng4, s) for _ in 1:n]
end
@test length(samples) == n
@test samples2 == samples
@test samples3 == samples4

if isa(distr, StudentizedRange)
samples[isnan.(samples)] .= 0.0 # Underlying implementation in Rmath can't handle very low values.
Expand All @@ -266,20 +301,29 @@ function test_samples(s::Sampleable{Univariate, Continuous}, # the sampleable
for i = 1:n
@inbounds si = samples[i]
vmin <= si <= vmax ||
error("Sample value out of valid range.")
throw(DomainError(si, "sample generated by `rand(s, n)` is out of valid range [$vmin, $vmax]."))
@inbounds si_sc = samples3[i]
vmin <= si_sc <= vmax ||
throw(DomainError(si, "sample generated by `[rand(s) for _ in 1:n]` is out of valid range [$vmin, $vmax]."))
end

# get counts
cnts = fit(Histogram, samples, edges; closed=:right).weights
@assert length(cnts) == nbins

cnts_sc = fit(Histogram, samples3, edges; closed=:right).weights
@assert length(cnts_sc) == nbins

# check the counts
for i = 1:nbins
if verbose
@printf("[%.4f, %.4f) ==> (%d, %d): %d\n", edges[i], edges[i+1], clb[i], cub[i], cnts[i])
@printf("[%.4f, %.4f) ==> (%d, %d): %d\n", edges[i], edges[i+1], clb[i], cub[i], cnts_sc[i])
end
clb[i] <= cnts[i] <= cub[i] ||
error("The counts are out of the confidence interval.")
error("The counts of samples generated by `rand(s, n)` are out of the confidence interval.")
clb[i] <= cnts_sc[i] <= cub[i] ||
error("The counts of samples generated by `[rand(s) for _ in 1:n]` are out of the confidence interval.")
end
return samples
end
Expand Down Expand Up @@ -583,6 +627,7 @@ end
allow_test_stats(d::UnivariateDistribution) = true
allow_test_stats(d::NoncentralBeta) = false
allow_test_stats(::StudentizedRange) = false
allow_test_stats(::LogitNormal) = false # `mean` is not defined since it has no analytical solution

function test_stats(d::ContinuousUnivariateDistribution, xs::AbstractVector{Float64})
# using Monte Carlo methods
Expand Down
16 changes: 15 additions & 1 deletion test/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

@testset "Exponential" begin
test_cgf(Exponential(1), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(0.91), (0.9, -1, -100f0, -1e6))
Expand All @@ -8,3 +7,18 @@
@test @inferred(rand(Exponential(T(1)))) isa T
end
end

test_cgf(Exponential(1), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(0.91), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(10), (0.08, -1, -100f0, -1e6))

# Sampling Tests
@testset "Exponential sampling tests" begin
for d in [
Exponential(1),
Exponential(0.91),
Exponential(10)
]
test_distr(d, 10^6)
end
end
9 changes: 9 additions & 0 deletions test/univariate/continuous/logitnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,12 @@ end
@test convert(LogitNormal{Float32}, d) === d
@test typeof(convert(LogitNormal{Float64}, d)) == typeof(LogitNormal(2,1))
end

@testset "Logitnormal Sampling Tests" begin
for d in [
LogitNormal(-2, 3),
LogitNormal(0, 0.2)
]
test_distr(d, 10^6)
end
end
14 changes: 14 additions & 0 deletions test/univariate/continuous/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,17 @@ end
@test @inferred(gradlogpdf(LogNormal(0.0, 1.0), BigFloat(-1))) == big(0.0)
@test isnan_type(BigFloat, @inferred(gradlogpdf(LogNormal(0.0, 1.0), BigFloat(NaN))))
end

@testset "LogNormal Sampling Tests" begin
for d in [
LogNormal()
LogNormal(1.0)
LogNormal(0.0, 2.0)
LogNormal(1.0, 2.0)
LogNormal(3.0, 0.5)
LogNormal(3.0, 1.0)
LogNormal(3.0, 2.0)
]
test_distr(d, 10^6)
end
end
10 changes: 10 additions & 0 deletions test/univariate/continuous/normalcanon.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Sampling Tests
@testset "NormalCanon sampling tests" begin
for d in [
NormalCanon()
NormalCanon(-1.0, 2.5)
NormalCanon(2.0, 0.8)
]
test_distr(d, 10^6)
end
end
10 changes: 10 additions & 0 deletions test/univariate/continuous/pareto.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@testset "Pareto Sampling Tests" begin
for d in [
Pareto()
Pareto(2.0)
Pareto(2.0, 1.5)
Pareto(3.0, 2.0)
]
test_distr(d, 10^6)
end
end

0 comments on commit 08c56ea

Please sign in to comment.