Skip to content

Commit

Permalink
Address reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Sep 25, 2024
1 parent 0746991 commit 38faf5c
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/univariate/continuous/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ end

xval(d::LogNormal, z::Real) = exp(muladd(d.σ, z, d.μ))

rand(rng::AbstractRNG, d::LogNormal) = xval(d, randn(rng))
rand(rng::AbstractRNG, d::LogNormal) = xval(d, randn(rng))
function rand!(rng::AbstractRNG, d::LogNormal, A::AbstractArray{<:Real})
randn!(rng, A)
map!(Base.Fix1(xval, d), A, A)
Expand Down
2 changes: 1 addition & 1 deletion src/univariate/continuous/pgeneralizedgaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,4 @@ function rand(rng::AbstractRNG, d::PGeneralizedGaussian)
else
return d.μ + z
end
end
end
2 changes: 1 addition & 1 deletion test/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ end
for func in funcs, dist in (Laplace, Laplace{Float64})
d = fit(dist, func[2](dist(5.0, 3.0), N + 1))
@test isa(d, dist)
@test isapprox(location(d), 5.0, atol=0.021)
@test isapprox(location(d), 5.0, atol=0.03)
@test isapprox(scale(d) , 3.0, atol=0.03)
end
end
Expand Down
41 changes: 20 additions & 21 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ end
# testing the implementation of a discrete univariate distribution
#
function test_distr(distr::DiscreteUnivariateDistribution, n::Int;
testquan::Bool=true, rng::AbstractRNG=MersenneTwister(123))
testquan::Bool=true, rng::AbstractRNG = Random.default_rng())

test_range(distr)
vs = get_evalsamples(distr, 0.00001)

Expand All @@ -39,8 +40,7 @@ function test_distr(distr::DiscreteUnivariateDistribution, n::Int;

test_stats(distr, vs)
test_samples(distr, n)
test_samples(distr, n, rng=rng)

test_samples(distr, n; rng=rng)
test_params(distr)
end

Expand Down Expand Up @@ -82,7 +82,6 @@ function test_distr(distr::ContinuousUnivariateDistribution, n::Int;
allow_test_stats(distr) && test_stats(distr, xs)
xs = test_samples(distr, n, rng=rng)
allow_test_stats(distr) && test_stats(distr, xs)

test_params(distr)
end

Expand All @@ -103,6 +102,7 @@ function test_samples(s::Sampleable{Univariate, Discrete}, # the sampleable
q::Float64=1.0e-7, # confidence interval, 1 - q as confidence
verbose::Bool=false, # show intermediate info (for debugging)
rng::Union{AbstractRNG, Missing}=missing) # add an rng?

# The basic idea
# ------------------
# Generate n samples, and count the occurrences of each value within a reasonable range.
Expand Down Expand Up @@ -151,17 +151,17 @@ function test_samples(s::Sampleable{Univariate, Discrete}, # the sampleable
Random.seed!(1234)
samples2 = rand(s, n)
Random.seed!(1234)
samples3 = map((_) -> rand(s), 1:n)
samples3 = [rand(s) for _ in 1:n]
Random.seed!(1234)
samples4 = map((_) -> rand(s), 1:n)
samples4 = [rand(s) for _ in 1:n]
else
rng2 = deepcopy(rng)
rng3 = deepcopy(rng)
rng4 = deepcopy(rng)
samples = rand(rng, s, n)
samples2 = rand(rng2, s, n)
samples3 = map((_) -> rand(rng3, s), 1:n)
samples4 = map((_) -> rand(rng4, s), 1:n)
samples3 = [rand(rng3, s) for _ in 1:n]
samples4 = [rand(rng4, s) for _ in 1:n]
end
@test length(samples) == n
@test samples2 == samples
Expand All @@ -176,27 +176,27 @@ function test_samples(s::Sampleable{Univariate, Discrete}, # the sampleable
cnts[si - rmin + 1] += 1
else
vmin <= si <= vmax ||
error("Sample value out of valid range. (Vector Method)")
throw(DomainError(si, "sample generated by `rand(s, n)` is out of valid range [$vmin, $vmax]."))
end

@inbounds si_sc = samples3[i]
if rmin <= si_sc <= rmax
cnts_sc[si_sc - rmin + 1] += 1
else
vmin <= si_sc <= vmax ||
error("Sample value out of valid range. (Scalar Method)")
throw(DomainError(si, "sample generated by `[rand(s) for _ in 1:n]` is out of valid range [$vmin, $vmax]."))
end
end

# check the counts
for i = 1:m
verbose && println("v = $(rmin+i-1) ==> ($(clb[i]), $(cub[i])): $(cnts[i])")
clb[i] <= cnts[i] <= cub[i] ||
error("The counts are out of the confidence interval. (Vector Method)")
error("The counts of samples generated by `rand(s, n)` are out of the confidence interval.")

verbose && println("v = $(rmin+i-1) ==> ($(clb[i]), $(cub[i])): $(cnts_sc[i])")
clb[i] <= cnts_sc[i] <= cub[i] ||
error("The counts are out of the confidence interval. (Scalar Method)")
error("The counts of samples generated by `[rand(s) for _ in 1:n]` are out of the confidence interval.")
end
return samples
end
Expand Down Expand Up @@ -273,17 +273,17 @@ function test_samples(s::Sampleable{Univariate, Continuous}, # the sampleable
Random.seed!(1234)
samples2 = rand(s, n)
Random.seed!(1234)
samples3 = map((_) -> rand(s), 1:n)
samples3 = [rand(s) for _ in 1:n]
Random.seed!(1234)
samples4 = map((_) -> rand(s), 1:n)
samples4 = [rand(s) for _ in 1:n]
else
rng2 = deepcopy(rng)
rng3 = deepcopy(rng)
rng4 = deepcopy(rng)
samples = rand(rng, s, n)
samples2 = rand(rng2, s, n)
samples3 = map((_) -> rand(rng3, s), 1:n)
samples4 = map((_) -> rand(rng4, s), 1:n)
samples3 = [rand(rng3, s) for _ in 1:n]
samples4 = [rand(rng4, s) for _ in 1:n]
end
@test length(samples) == n
@test samples2 == samples
Expand All @@ -297,10 +297,10 @@ function test_samples(s::Sampleable{Univariate, Continuous}, # the sampleable
for i = 1:n
@inbounds si = samples[i]
vmin <= si <= vmax ||
error("Sample value out of valid range. (Vector Method)")
throw(DomainError(si, "sample generated by `rand(s, n)` is out of valid range [$vmin, $vmax]."))
@inbounds si_sc = samples3[i]
vmin <= si_sc <= vmax ||
error("Sample value out of valid range. (Scalar Method)")
throw(DomainError(si, "sample generated by `[rand(s) for _ in 1:n]` is out of valid range [$vmin, $vmax]."))
end

# get counts
Expand All @@ -317,9 +317,9 @@ function test_samples(s::Sampleable{Univariate, Continuous}, # the sampleable
@printf("[%.4f, %.4f) ==> (%d, %d): %d\n", edges[i], edges[i+1], clb[i], cub[i], cnts_sc[i])
end
clb[i] <= cnts[i] <= cub[i] ||
error("The counts are out of the confidence interval. (Vector Method)")
error("The counts of samples generated by `rand(s, n)` are out of the confidence interval.")
clb[i] <= cnts_sc[i] <= cub[i] ||
error("The counts are out of the confidence interval. (Scalar Method)")
error("The counts of samples generated by `[rand(s) for _ in 1:n]` are out of the confidence interval.")
end
return samples
end
Expand Down Expand Up @@ -623,7 +623,6 @@ end
allow_test_stats(d::UnivariateDistribution) = true
allow_test_stats(d::NoncentralBeta) = false
allow_test_stats(::StudentizedRange) = false
allow_test_stats(::LogitNormal) = false

function test_stats(d::ContinuousUnivariateDistribution, xs::AbstractVector{Float64})
# using Monte Carlo methods
Expand Down
2 changes: 1 addition & 1 deletion test/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end

test_cgf(Exponential(1), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(0.91), (0.9, -1, -100f0, -1e6))
test_cgf(Exponential(10 ), (0.08, -1, -100f0, -1e6))
test_cgf(Exponential(10), (0.08, -1, -100f0, -1e6))

# Sampling Tests
@testset "Exponential sampling tests" begin
Expand Down
2 changes: 1 addition & 1 deletion test/univariate/continuous/logitnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ end
]
test_distr(d, 10^6)
end
end
end
2 changes: 1 addition & 1 deletion test/univariate/continuous/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,4 +327,4 @@ end
]
test_distr(d, 10^6)
end
end
end
2 changes: 1 addition & 1 deletion test/univariate/continuous/normalcanon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
]
test_distr(d, 10^6)
end
end
end
2 changes: 1 addition & 1 deletion test/univariate/continuous/pareto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
]
test_distr(d, 10^6)
end
end
end

0 comments on commit 38faf5c

Please sign in to comment.