Skip to content

Commit

Permalink
Add mean et al. for truncated log normal
Browse files Browse the repository at this point in the history
Fixes #709
  • Loading branch information
ararslan committed Jun 28, 2024
1 parent 65f056c commit 036a24d
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ include(joinpath("truncated", "exponential.jl"))
include(joinpath("truncated", "uniform.jl"))
include(joinpath("truncated", "loguniform.jl"))
include(joinpath("truncated", "discrete_uniform.jl"))
include(joinpath("truncated", "lognormal.jl"))

#### Utilities

Expand Down
50 changes: 50 additions & 0 deletions src/truncated/lognormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Moments of the truncated log-normal can be computed directly from the moment generating
# function of the truncated normal:
# Let Y ~ LogNormal(μ, σ) truncated to (a, b). Then log(Y) ~ Normal(μ, σ) truncated
# to (log(a), log(b)), and E[Y^n] = E[(e^log(Y))^n] = E[e^(nlog(Y))] = mgf(log(Y), n).

# Given `truncated(LogNormal(μ, σ), a, b)`, return `truncated(Normal(μ, σ), log(a), log(b))`.
#
# A bound that is `nothing` (i.e. not specified) is passed through as `nothing`.
# The support of a log-normal is (0, ∞), so a lower bound at or below zero does not
# restrict the distribution. Such a bound is clamped to zero before taking the logarithm,
# which maps it to `-Inf` instead of letting `log` throw a `DomainError` for a negative
# argument. A `NaN` bound still propagates as `NaN`.
function _truncnorm(d::Truncated{<:LogNormal})
    μ, σ = params(d.untruncated)
    T = partype(d)
    a = d.lower === nothing ? nothing : log(max(T(d.lower), zero(T)))
    b = d.upper === nothing ? nothing : log(T(d.upper))
    return truncated(Normal(μ, σ), a, b)
end

# Mean of a truncated log-normal: the first raw moment E[Y], obtained by evaluating the
# MGF of the corresponding truncated normal (see `_truncnorm`) at 1.
function mean(d::Truncated{<:LogNormal})
    return mgf(_truncnorm(d), 1)
end

# Variance of a truncated log-normal, computed as E[Y^2] - E[Y]^2 from the first two raw
# moments, which are the MGF of the corresponding truncated normal at 1 and 2.
function var(d::Truncated{<:LogNormal})
    logd = _truncnorm(d)
    raw2 = mgf(logd, 2)
    raw1 = mgf(logd, 1)
    # Rounding error in the subtraction can push the result slightly below zero, so
    # clamp it to keep the variance nonnegative
    return max(raw2 - raw1^2, 0)
end

# Skewness of a truncated log-normal: the third central moment divided by the variance
# raised to the power 3/2, with all raw moments taken from the MGF of the corresponding
# truncated normal.
function skewness(d::Truncated{<:LogNormal})
    logd = _truncnorm(d)
    raw1 = mgf(logd, 1)
    raw2 = mgf(logd, 2)
    raw3 = mgf(logd, 3)
    raw1sq = raw1^2
    variance = raw2 - raw1sq
    # Third central moment, raw3 - 3 * raw1 * raw2 + 2 * raw1^3, in factored form
    central3 = raw3 + raw1 * (-3 * raw2 + 2 * raw1sq)
    return central3 / (variance * sqrt(variance))
end

# Excess kurtosis of a truncated log-normal: the fourth central moment divided by the
# squared variance, minus 3. Raw moments come from the MGF of the corresponding truncated
# normal.
function kurtosis(d::Truncated{<:LogNormal})
    logd = _truncnorm(d)
    raw1 = mgf(logd, 1)
    raw2 = mgf(logd, 2)
    raw3 = mgf(logd, 3)
    raw4 = mgf(logd, 4)
    variance = raw2 - raw1^2
    # Fourth central moment as a polynomial in raw1:
    # raw4 - 4 * raw3 * raw1 + 6 * raw2 * raw1^2 + 0 * raw1^3 - 3 * raw1^4
    central4 = @horner(raw1, raw4, -4raw3, 6raw2, 0, -3)
    return central4 / variance^2 - 3
end

# TODO: The entropy can be written "directly" as well, according to Mathematica, but
# the expression for it fills me with regret. There are some recognizable components,
# so a sufficiently motivated person could try to manually simplify it into something
# comprehensible. For reference, you can obtain the entropy with Mathematica like so:
#
# d = TruncatedDistribution[{a, b}, LogNormalDistribution[m, s]];
# Expectation[-LogLikelihood[d, {x}], Distributed[x, d],
# Assumptions -> Element[x | m | s | a | b, Reals] && s > 0 && 0 < a < x < b]
30 changes: 30 additions & 0 deletions src/truncated/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,36 @@ function entropy(d::Truncated{<:Normal{<:Real},Continuous})
0.5 * (log2π + 1.) + log* z) + (aφa - bφb) / (2.0 * z)
end

# Moment generating function of a truncated normal distribution, evaluated at `t`.
#
# For X ~ Normal(μ, σ) truncated to [a, b], with standardized bounds a′ = (a - μ) / σ and
# b′ = (b - μ) / σ, completing the square in the defining integral gives
#
#     E[exp(t X)] = exp(μ t + σ² t² / 2) * (Φ(b′ - σ t) - Φ(a′ - σ t)) / (Φ(b′) - Φ(a′))
#
# where Φ is the standard normal CDF. Note that the standardized bounds are shifted by
# `σ t`, not `σ² t`; the two only coincide when σ = 1.
#
# The result has type `T`, the floating point promotion of the parameter type of `d` and
# the type of `t`. Special cases: a `NaN` bound yields `NaN`, infinite bounds on both
# sides defer to the MGF of the untruncated normal, and equal bounds give `exp(a * t)`.
function mgf(d::Truncated{<:Normal{<:Real},Continuous}, t::Real)
    T = float(promote_type(partype(d), typeof(t)))
    a = T(minimum(d))
    b = T(maximum(d))
    if isnan(a) || isnan(b)  # TODO: Disallow constructing `Truncated` with a `NaN` bound?
        return T(NaN)
    elseif isinf(a) && isinf(b) && a != b
        # Distribution is `Truncated`-wrapped but not actually truncated
        return T(mgf(d.untruncated, t))
    elseif a == b
        # Truncated to a Dirac distribution; this is `mgf(Dirac(a), t)`
        return exp(a * t)
    end
    d0 = d.untruncated
    μ = mean(d0)
    σ = std(d0)
    σt = σ * t      # shift applied to the standardized bounds
    σ²t = σ^2 * t   # used in the exponent μ t + σ² t² / 2 = t * (μ + σ²t / 2)
    a′ = (a - μ) / σ
    b′ = (b - μ) / σ
    stdnorm = Normal{T}(zero(T), one(T))
    # log((Φ(b′ - σt) - Φ(a′ - σt)) / (Φ(b′) - Φ(a′))), computed in log space so that
    # bounds far out in the tails do not underflow
    logratio = if isfinite(a) && isfinite(b)  # doubly truncated
        logdiffcdf(stdnorm, b′ - σt, a′ - σt) - logdiffcdf(stdnorm, b′, a′)
    elseif isfinite(a)  # left truncated: b = ∞, Φ(b′) = Φ(b′ - σt) = 1
        logccdf(stdnorm, a′ - σt) - logccdf(stdnorm, a′)
    else  # isfinite(b), right truncated: a = -∞, Φ(a′) = Φ(a′ - σt) = 0
        logcdf(stdnorm, b′ - σt) - logcdf(stdnorm, b′)
    end
    return exp(t * (μ + σ²t / 2) + logratio)
end

### sampling

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const tests = [
"truncated/exponential",
"truncated/uniform",
"truncated/discrete_uniform",
"truncated/lognormal",
"censored",
"univariate/continuous/normal",
"univariate/continuous/laplace",
Expand Down
11 changes: 11 additions & 0 deletions test/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ function _linspace(a::Float64, b::Float64, n::Int)
return r
end

# Enables testing against values computed at high precision by transforming an expression
# that uses numeric literals and constants to wrap those in `big()`, similar to how the
# high-precision values for irrational constants are defined with `Base.@irrational` and
# in IrrationalConstants.jl. See e.g. `test/truncated/normal.jl` for example use.
# Fallback: anything that is not a symbol, real literal, or expression is left untouched
bigly(x) = x
# Names of numeric constants are wrapped in `big(...)`; all other symbols (function names,
# variables) are returned unchanged
bigly(x::Symbol) = x in (:π, :ℯ, :Inf, :NaN) ? Expr(:call, :big, x) : x
# Real-valued literals are wrapped in `big(...)`
bigly(x::Real) = Expr(:call, :big, x)
# Expressions are rewritten recursively; note that this mutates `x.args` in place
bigly(x::Expr) = (map!(bigly, x.args, x.args); x)
# Rewrite `ex` with `bigly` and evaluate the result in the caller's scope
macro bigly(ex)
    return esc(bigly(ex))
end

#################################################
#
Expand Down
36 changes: 36 additions & 0 deletions test/truncated/lognormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using Distributions, Test
using Distributions: expectation

# Reference value for the `n`-th standardized moment of `d`: the expectation of
# ((x - μ) / σ)^n under `d`, evaluated with `Distributions.expectation`, where `μ` is the
# mean and `σ²` the variance supplied by the caller.
function naive_moment(d, n, μ, σ²)
    σ = sqrt(σ²)
    return expectation(x -> ((x - μ) / σ)^n, d)
end

@testset "Truncated log normal" begin
    @testset "truncated(LogNormal{$T}(0, 1), ℯ⁻², ℯ²)" for T in (Float32, Float64, BigFloat)
        d = truncated(LogNormal{T}(zero(T), one(T)), exp(T(-2)), exp(T(2)))
        # High-precision reference moments from the corresponding truncated normal
        tn = truncated(Normal{BigFloat}(big(0.0), big(1.0)), -2, 2)
        bigmean = mgf(tn, 1)
        bigvar = mgf(tn, 2) - bigmean^2
        @test @inferred(mean(d)) ≈ bigmean
        @test @inferred(var(d)) ≈ bigvar
        @test @inferred(median(d)) ≈ one(T)
        @test @inferred(skewness(d)) ≈ naive_moment(d, 3, bigmean, bigvar)
        @test @inferred(kurtosis(d)) ≈ naive_moment(d, 4, bigmean, bigvar) - big(3)
        @test mean(d) isa T
    end
    @testset "Bound with no effect" begin
        # Uses the example distribution from issue #709, though what's tested here is
        # mostly unrelated to that issue (aside from `mean` not erroring).
        # The specified left truncation at 0 has no effect for `LogNormal`
        d1 = truncated(LogNormal(1, 5), 0, 1e5)
        m1 = mean(d1)
        v1 = var(d1)
        # The mean of a distribution supported on (0, 1e5) must lie inside that interval
        @test 0 < m1 < 1e5
        # The variance must be nonnegative (rounding error is clamped with `max(_, 0)`) and,
        # for a distribution supported on an interval of width 1e5, can be at most
        # (1e5)^2 / 4 by Popoviciu's inequality
        @test 0 <= v1 <= 1e10 / 4
        # Compare results with not specifying a lower bound at all
        d2 = truncated(LogNormal(1, 5); upper=1e5)
        @test mean(d1) == mean(d2)
        @test var(d1) == var(d2)
    end
end
34 changes: 34 additions & 0 deletions test/truncated/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,37 @@ end
@test isfinite(pdf(trunc, x))
end
end

@testset "Truncated normal MGF" begin
    # High-precision building blocks for the closed-form reference values below
    two = big(2)
    sqrt2 = sqrt(two)
    invsqrt2 = inv(sqrt2)
    inv2sqrt2 = inv(two * sqrt2)
    twoerfsqrt2 = two * erf(sqrt2)

    for T in (Float32, Float64, BigFloat)
        d = truncated(Normal{T}(zero(T), one(T)), -2, 2)
        @test @inferred(mgf(d, 0)) == 1
        @test @inferred(mgf(d, 1)) ≈ @bigly sqrt(ℯ) * (erf(invsqrt2) + erf(3 * invsqrt2)) / twoerfsqrt2
        @test @inferred(mgf(d, 2.5)) ≈ @bigly exp(25//8) * (erf(9 * inv2sqrt2) - erf(inv2sqrt2)) / twoerfsqrt2
    end

    # Non-unit standard deviation with both bounds finite
    d = truncated(Normal(3, 10), 7, 8)
    @test mgf(d, 0) == 1
    # `exp` is increasing, so E[exp(X)] for X supported on [7, 8] lies in [exp(7), exp(8)]
    @test exp(7) <= mgf(d, 1) <= exp(8)
    # exp(μ + σ²/2) * (Φ(b′ - σ) - Φ(a′ - σ)) / (Φ(b′) - Φ(a′)) with a′ = 0.4, b′ = 0.5
    expected = @bigly exp(53) * (erf((96//10) * invsqrt2) - erf((95//10) * invsqrt2)) /
                      (erf((5//10) * invsqrt2) - erf((4//10) * invsqrt2))
    @test mgf(d, 1) ≈ expected

    d = truncated(Normal(27, 3); lower=0)
    @test mgf(d, 0) == 1
    @test mgf(d, 1) ≈ @bigly 2 * exp(63//2) / (1 + erf(9 * invsqrt2))
    @test mgf(d, 2.5) ≈ @bigly 2 * exp(765//8) / (1 + erf(9 * invsqrt2))

    d = truncated(Normal(-5, 1); upper=-10)
    @test mgf(d, 0) == 1
    @test mgf(d, 1) ≈ @bigly erfc(3 * sqrt2) / (exp(9//2) * erfc(5 * invsqrt2))

    @test isnan(mgf(truncated(Normal(); upper=NaN), 0))

    @test mgf(truncated(Normal(), -Inf, Inf), 1) == mgf(Normal(), 1)

    @test mgf(truncated(Normal(), 2, 2), 1) == exp(2)
end

0 comments on commit 036a24d

Please sign in to comment.