Skip to content

Commit

Permalink
Apply other recommendations to testutils
Browse files Browse the repository at this point in the history
  • Loading branch information
quildtide committed Sep 4, 2024
1 parent d13c8fb commit 1f7e28a
Showing 1 changed file with 55 additions and 58 deletions.
113 changes: 55 additions & 58 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ end
# testing the implementation of a discrete univariate distribution
#
function test_distr(distr::DiscreteUnivariateDistribution, n::Int;
testquan::Bool=true, rng::AbstractRNG=MersenneTwister(),
test_scalar_rand::Bool=false)
testquan::Bool=true, rng::AbstractRNG=MersenneTwister(123))
test_range(distr)
vs = get_evalsamples(distr, 0.00001)

Expand All @@ -41,10 +40,6 @@ function test_distr(distr::DiscreteUnivariateDistribution, n::Int;
test_stats(distr, vs)
test_samples(distr, n)
test_samples(distr, n, rng=rng)
if test_scalar_rand
xs = test_samples(distr, n; call_scalar = true)
xs = test_samples(distr, n, rng=rng, call_scalar = true)
end

test_params(distr)
end
Expand Down Expand Up @@ -87,12 +82,6 @@ function test_distr(distr::ContinuousUnivariateDistribution, n::Int;
allow_test_stats(distr) && test_stats(distr, xs)
xs = test_samples(distr, n, rng=rng)
allow_test_stats(distr) && test_stats(distr, xs)
if test_scalar_rand
xs = test_samples(distr, n; call_scalar = true)
allow_test_stats(distr) && test_stats(distr, xs)
xs = test_samples(distr, n, rng=rng, call_scalar = true)
allow_test_stats(distr) && test_stats(distr, xs)
end

test_params(distr)
end
Expand Down Expand Up @@ -157,51 +146,57 @@ function test_samples(s::Sampleable{Univariate, Discrete}, # the sampleable
# generate samples using RNG passed or default RNG
# we also check reproducibility
if rng === missing
Random.seed!(1234)
samples = if !call_scalar
rand(s, n)
else
map((_) -> rand(s), 1:n)
end
Random.seed!(1234)
samples2 = if !call_scalar
rand(s, n)
else
map((_) -> rand(s), 1:n)
end
samples = rand(s, n)
Random.seed!(1234)
samples2 = rand(s, n)
Random.seed!(1234)
samples3 = map!((_) -> rand(s), 1:n)
Random.seed!(1234)
samples4 = map((_) -> rand(s), 1:n)
else
rng2 = deepcopy(rng)
samples = if !call_scalar
rand(rng, s, n)
else
map((_) -> rand(rng, s), 1:n)
end
samples2 = if !call_scalar
rand(rng2, s, n)
else
map((_) -> rand(rng2, s), 1:n)
end
rng3 = deepcopy(rng)
rng4 = deepcopy(rng)
samples = rand(rng, s, n)
samples2 = rand(rng2, s, n)
samples3 = map((_) -> rand(rng3, s), 1:n)
samples4 = map((_) -> rand(rng4, s), 1:n)
end
@test length(samples) == n
@test samples2 == samples
@test samples3 == samples4

# scan samples and get counts
cnts = zeros(Int, m)
cnts_sc = zeros(Int, m)
for i = 1:n
@inbounds si = samples[i]
if rmin <= si <= rmax
cnts[si - rmin + 1] += 1
else
vmin <= si <= vmax ||
error("Sample value out of valid range.")
error("Sample value out of valid range. (Vector Method)")
end

@inbounds si_sc = samples3[i]
if rmin <= si_sc <= rmax
cnts_sc[si_sc - rmin + 1] += 1
else
vmin <= si_sc <= vmax ||
error("Sample value out of valid range. (Scalar Method)")
end
end

# check the counts
for i = 1:m
verbose && println("v = $(rmin+i-1) ==> ($(clb[i]), $(cub[i])): $(cnts[i])")
clb[i] <= cnts[i] <= cub[i] ||
error("The counts are out of the confidence interval.")
error("The counts are out of the confidence interval. (Vector Method)")

verbose && println("v = $(rmin+i-1) ==> ($(clb[i]), $(cub[i])): $(cnts_sc[i])")
clb[i] <= cnts_sc[i] <= cub[i] ||
error("The counts are out of the confidence interval. (Scalar Method)")
end
return samples
end
Expand Down Expand Up @@ -273,33 +268,26 @@ function test_samples(s::Sampleable{Univariate, Continuous}, # the sampleable
# generate samples using RNG passed or default RNG
# we also check reproducibility
if rng === missing
Random.seed!(1234)
samples = if !call_scalar
rand(s, n)
else
map((_) -> rand(s), 1:n)
end
Random.seed!(1234)
samples2 = if !call_scalar
rand(s, n)
else
map((_) -> rand(s), 1:n)
end
samples = rand(s, n)
Random.seed!(1234)
samples2 = rand(s, n)
Random.seed!(1234)
samples3 = map!((_) -> rand(s), 1:n)
Random.seed!(1234)
samples4 = map((_) -> rand(s), 1:n)
else
rng2 = deepcopy(rng)
samples = if !call_scalar
rand(rng, s, n)
else
map((_) -> rand(rng, s), 1:n)
end
samples2 = if !call_scalar
rand(rng2, s, n)
else
map((_) -> rand(rng2, s), 1:n)
end
rng3 = deepcopy(rng)
rng4 = deepcopy(rng)
samples = rand(rng, s, n)
samples2 = rand(rng2, s, n)
samples3 = map((_) -> rand(rng3, s), 1:n)
samples4 = map((_) -> rand(rng4, s), 1:n)
end
@test length(samples) == n
@test samples2 == samples
@test samples3 == samples4

if isa(distr, StudentizedRange)
samples[isnan.(samples)] .= 0.0 # Underlying implementation in Rmath can't handle very low values.
Expand All @@ -309,20 +297,29 @@ function test_samples(s::Sampleable{Univariate, Continuous}, # the sampleable
for i = 1:n
@inbounds si = samples[i]
vmin <= si <= vmax ||
error("Sample value out of valid range.")
error("Sample value out of valid range. (Vector Method)")
@inbounds si_sc = samples3[i]
vmin <= si_sc <= vmax ||
error("Sample value out of valid range. (Scalar Method)")
end

# get counts
cnts = fit(Histogram, samples, edges; closed=:right).weights
@assert length(cnts) == nbins

cnts_sc = fit(Histogram, samples3, edges; closed=:right).weights
@assert length(cnts_sc) == nbins

# check the counts
for i = 1:nbins
if verbose
@printf("[%.4f, %.4f) ==> (%d, %d): %d\n", edges[i], edges[i+1], clb[i], cub[i], cnts[i])
@printf("[%.4f, %.4f) ==> (%d, %d): %d\n", edges[i], edges[i+1], clb[i], cub[i], cnts_sc[i])
end
clb[i] <= cnts[i] <= cub[i] ||
error("The counts are out of the confidence interval.")
error("The counts are out of the confidence interval. (Vector Method)")
clb[i] <= cnts_sc[i] <= cub[i] ||
error("The counts are out of the confidence interval. (Scalar Method)")
end
return samples
end
Expand Down

0 comments on commit 1f7e28a

Please sign in to comment.