Add mean et al. for truncated log normal #1874

Status: Open · wants to merge 3 commits into base: master
1 change: 1 addition & 0 deletions src/truncate.jl
@@ -261,6 +261,7 @@ include(joinpath("truncated", "exponential.jl"))
include(joinpath("truncated", "uniform.jl"))
include(joinpath("truncated", "loguniform.jl"))
include(joinpath("truncated", "discrete_uniform.jl"))
include(joinpath("truncated", "lognormal.jl"))

#### Utilities

50 changes: 50 additions & 0 deletions src/truncated/lognormal.jl
@@ -0,0 +1,50 @@
# Moments of the truncated log-normal can be computed directly from the moment generating
# function of the truncated normal:
# Let Y ~ LogNormal(μ, σ) truncated to (a, b). Then log(Y) ~ Normal(μ, σ) truncated
# to (log(a), log(b)), and E[Y^n] = E[(e^log(Y))^n] = E[e^(n log(Y))] = mgf(log(Y), n).

# Given `truncated(LogNormal(μ, σ), a, b)`, return `truncated(Normal(μ, σ), log(a), log(b))`
function _truncnorm(d::Truncated{<:LogNormal})
    μ, σ = params(d.untruncated)
    T = float(partype(d))
    a = d.lower === nothing || d.lower <= 0 ? nothing : log(T(d.lower))
    b = d.upper === nothing || isinf(d.upper) ? nothing : log(T(d.upper))
[Review comment from a Member on lines +10 to +11]
Tests on Julia 1.3 pass locally when changing this to

Suggested change:
-    a = d.lower === nothing || d.lower <= 0 ? nothing : log(T(d.lower))
-    b = d.upper === nothing || isinf(d.upper) ? nothing : log(T(d.upper))
+    a = d.lower === nothing ? nothing : log(T(max(d.lower, 0)))
+    b = d.upper === nothing ? nothing : log(T(d.upper))

[Reply from the author]
That isn't functionally equivalent, though; IIRC, I was relying on a and/or b being nothing in those cases so that truncated would handle it a particular way.

[Reply from the reviewer]
But arguably this optimization (d.lower === nothing or d.upper === nothing) should already have happened, either by a user or internally, when constructing d = truncated(LogNormal(...))? Maybe one shouldn't expect that unoptimized inputs lead to optimized algorithms. AFAICT the optimization of truncated(Normal(...), ...) is also only exploited in the case where this returns a Normal (a = b = nothing); the code for Truncated{<:Normal} does not seem to use the fact that a bound might be nothing. In the Normal case, arguably the LogNormal shouldn't be truncated in the first place.

[Reply from the author]
I'm currently sick and can only barely comprehend that message, but if you'd prefer to go with your suggested change then feel free to apply it; I trust your judgement.

[Reply from the author]
Did we have this conversation a few months ago? My brain similarly can't comprehend whether this is the same discussion: #1874 (comment)

I should probably just log off and go sleep

    return truncated(Normal{T}(T(μ), T(σ)), a, b)
end

mean(d::Truncated{<:LogNormal}) = mgf(_truncnorm(d), 1)

function var(d::Truncated{<:LogNormal})
    tn = _truncnorm(d)
    # Ensure the variance doesn't end up negative, which can occur due to numerical issues
    return max(mgf(tn, 2) - mgf(tn, 1)^2, 0)
end

[Review comment from a Member on the `return` line]
Repeated evaluation of mgf involves repeated calculations of the same (intermediate) quantities. But my fear is that optimizing this further will lead to less readable code...

[Reply from the author]
Yeah, I had the same thought and wasn't sure what to do about it so I just... didn't address it, haha

function skewness(d::Truncated{<:LogNormal})
    tn = _truncnorm(d)
    m1 = mgf(tn, 1)
    m2 = mgf(tn, 2)
    m3 = mgf(tn, 3)
    sqm1 = m1^2
    v = m2 - sqm1
    return (m3 + m1 * (-3 * m2 + 2 * sqm1)) / (v * sqrt(v))
end

function kurtosis(d::Truncated{<:LogNormal})
    tn = _truncnorm(d)
    m1 = mgf(tn, 1)
    m2 = mgf(tn, 2)
    m3 = mgf(tn, 3)
    m4 = mgf(tn, 4)
    v = m2 - m1^2
    return @horner(m1, m4, -4m3, 6m2, 0, -3) / v^2 - 3
end

# TODO: The entropy can be written "directly" as well, according to Mathematica, but
# the expression for it fills me with regret. There are some recognizable components,
# so a sufficiently motivated person could try to manually simplify it into something
# comprehensible. For reference, you can obtain the entropy with Mathematica like so:
#
# d = TruncatedDistribution[{a, b}, LogNormalDistribution[m, s]];
# Expectation[-LogLikelihood[d, {x}], Distributed[x, d],
# Assumptions -> Element[x | m | s | a | b, Reals] && s > 0 && 0 < a < x < b]
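The identity this file relies on, E[Y^n] = mgf(log(Y), n) for the log of the truncated variable, can be cross-checked numerically outside Julia. Below is a small standalone Python sketch (not part of the PR; the helper names `phi_cdf`, `truncnorm_mgf`, and `truncated_lognormal_moment` are mine, and the parameters are arbitrary illustrative values) comparing the textbook truncated-normal MGF against direct quadrature for the truncated log-normal moments:

```python
# Independent cross-check: for Y ~ LogNormal(mu, sigma) truncated to (a, b),
# E[Y^n] should equal the MGF of Normal(mu, sigma) truncated to (log a, log b),
# evaluated at t = n.
import math

def phi_cdf(x):
    # Standard normal CDF, via erfc for accuracy in the tails
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def truncnorm_mgf(mu, sigma, lo, hi, t):
    # Textbook closed form:
    # exp(mu*t + (sigma*t)^2/2) * (Phi(b' - sigma*t) - Phi(a' - sigma*t)) / (Phi(b') - Phi(a'))
    ap, bp, st = (lo - mu) / sigma, (hi - mu) / sigma, sigma * t
    return (math.exp(t * mu + st * st / 2.0)
            * (phi_cdf(bp - st) - phi_cdf(ap - st)) / (phi_cdf(bp) - phi_cdf(ap)))

def truncated_lognormal_moment(mu, sigma, a, b, n, steps=100_000):
    # E[Y^n] by midpoint quadrature in x = log(y); the density of log(Y) is
    # proportional to the Normal(mu, sigma) pdf on (log a, log b)
    lo, hi = math.log(a), math.log(b)
    h = (hi - lo) / steps
    num = den = 0.0
    for i in range(steps):
        x = lo + (i + 0.5) * h
        w = math.exp(-0.5 * ((x - mu) / sigma) ** 2)
        num += w * math.exp(n * x)
        den += w
    return num / den

mu, sigma, a, b = 0.3, 0.8, 0.5, 4.0
for n in (1, 2, 3, 4):
    closed = truncnorm_mgf(mu, sigma, math.log(a), math.log(b), n)
    numeric = truncated_lognormal_moment(mu, sigma, a, b, n)
    print(n, closed, numeric)
```

The first four moments obtained both ways agree to quadrature precision, which is exactly what `mean`, `var`, `skewness`, and `kurtosis` above exploit.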
30 changes: 30 additions & 0 deletions src/truncated/normal.jl
@@ -118,6 +118,36 @@ function entropy(d::Truncated{<:Normal{<:Real},Continuous})
    0.5 * (log2π + 1.) + log(σ * z) + (aφa - bφb) / (2.0 * z)
end

function mgf(d::Truncated{<:Normal{<:Real},Continuous}, t::Real)
    T = float(promote_type(partype(d), typeof(t)))
    a = T(minimum(d))
    b = T(maximum(d))
    if isnan(a) || isnan(b)  # TODO: Disallow constructing `Truncated` with a `NaN` bound?
        return T(NaN)
    elseif isinf(a) && isinf(b) && a != b
        # Distribution is `Truncated`-wrapped but not actually truncated
        return T(mgf(d.untruncated, t))
    elseif a == b
        # Truncated to a Dirac distribution; this is `mgf(Dirac(a), t)`
        return exp(a * t)
    end
    d0 = d.untruncated
    μ = mean(d0)
    σ = std(d0)
    σt = σ * t
    a′ = (a - μ) / σ
    b′ = (b - μ) / σ
    stdnorm = Normal{T}(zero(T), one(T))
    # log((Φ(b′ - σt) - Φ(a′ - σt)) / (Φ(b′) - Φ(a′)))
    logratio = if isfinite(a) && isfinite(b)  # doubly truncated
        logdiffcdf(stdnorm, b′ - σt, a′ - σt) - logdiffcdf(stdnorm, b′, a′)
    elseif isfinite(a)  # left truncated: b = ∞, Φ(b′) = Φ(b′ - σt) = 1
        logccdf(stdnorm, a′ - σt) - logccdf(stdnorm, a′)
    else  # isfinite(b), right truncated: a = -∞, Φ(a′) = Φ(a′ - σt) = 0
        logcdf(stdnorm, b′ - σt) - logcdf(stdnorm, b′)
    end
    return exp(t * μ + σt^2 / 2 + logratio)
end

### sampling

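The closed form behind the `mgf` method, exp(μt + σ²t²/2) · (Φ(b′ − σt) − Φ(a′ − σt)) / (Φ(b′) − Φ(a′)), can be verified against brute-force integration. A standalone Python sketch (not part of the PR; the helper names `phi_cdf`, `truncnorm_mgf`, and `truncnorm_mgf_numeric` are mine):

```python
# Independent cross-check of the truncated-normal MGF closed form against
# direct quadrature of E[exp(t*X)] for X ~ Normal(mu, sigma) confined to (a, b).
import math

def phi_cdf(x):
    # Standard normal CDF via erfc
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def truncnorm_mgf(mu, sigma, a, b, t):
    # exp(mu*t + (sigma*t)^2/2) * (Phi(b'-sigma*t) - Phi(a'-sigma*t)) / (Phi(b') - Phi(a'))
    ap, bp, st = (a - mu) / sigma, (b - mu) / sigma, sigma * t
    return (math.exp(t * mu + st * st / 2.0)
            * (phi_cdf(bp - st) - phi_cdf(ap - st)) / (phi_cdf(bp) - phi_cdf(ap)))

def truncnorm_mgf_numeric(mu, sigma, a, b, t, steps=200_000):
    # Midpoint-rule estimate of E[exp(t*X)] restricted and renormalized to (a, b)
    h = (b - a) / steps
    num = den = 0.0
    for i in range(steps):
        x = a + (i + 0.5) * h
        w = math.exp(-0.5 * ((x - mu) / sigma) ** 2)
        num += w * math.exp(t * x)
        den += w
    return num / den

# Example from the tests below: Normal(3, 10) truncated to (7, 8). The MGF at
# t = 1 is E[e^X] for X confined to (7, 8), so it must lie strictly inside (e^7, e^8).
closed = truncnorm_mgf(3.0, 10.0, 7.0, 8.0, 1.0)
numeric = truncnorm_mgf_numeric(3.0, 10.0, 7.0, 8.0, 1.0)
print(closed, numeric)
```

Both evaluations agree to roughly quadrature precision and land between e⁷ ≈ 1096.6 and e⁸ ≈ 2981.0, as any MGF at t = 1 of a distribution supported on (7, 8) must.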
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -20,6 +20,7 @@ const tests = [
    "truncated/exponential",
    "truncated/uniform",
    "truncated/discrete_uniform",
    "truncated/lognormal",
    "censored",
    "univariate/continuous/normal",
    "univariate/continuous/laplace",
11 changes: 11 additions & 0 deletions test/testutils.jl
@@ -18,6 +18,17 @@ function _linspace(a::Float64, b::Float64, n::Int)
    return r
end

# Enables testing against values computed at high precision by transforming an expression
# that uses numeric literals and constants to wrap those in `big()`, similar to how the
# high-precision values for irrational constants are defined with `Base.@irrational` and
# in IrrationalConstants.jl. See e.g. `test/truncated/normal.jl` for example use.
bigly(x) = x
bigly(x::Symbol) = x in (:π, :ℯ, :Inf, :NaN) ? Expr(:call, :big, x) : x
bigly(x::Real) = Expr(:call, :big, x)
bigly(x::Expr) = (map!(bigly, x.args, x.args); x)
macro bigly(ex)
    return esc(bigly(ex))
end

#################################################
#
40 changes: 40 additions & 0 deletions test/truncated/lognormal.jl
@@ -0,0 +1,40 @@
using Distributions, Test
using Distributions: expectation

naive_moment(d, n, μ, σ²) = (σ = sqrt(σ²); expectation(x -> ((x - μ) / σ)^n, d))

@testset "Truncated log normal" begin
    @testset "truncated(LogNormal{$T}(0, 1), ℯ⁻², ℯ²)" for T in (Float32, Float64, BigFloat)
        d = truncated(LogNormal{T}(zero(T), one(T)), exp(T(-2)), exp(T(2)))
        tn = truncated(Normal{BigFloat}(big(0.0), big(1.0)), -2, 2)
        bigmean = mgf(tn, 1)
        bigvar = mgf(tn, 2) - bigmean^2
        @test @inferred(mean(d)) ≈ bigmean
        @test @inferred(var(d)) ≈ bigvar
        @test @inferred(median(d)) ≈ one(T)
        @test @inferred(skewness(d)) ≈ naive_moment(d, 3, bigmean, bigvar)
        @test @inferred(kurtosis(d)) ≈ naive_moment(d, 4, bigmean, bigvar) - big(3)
        @test mean(d) isa T
    end
    @testset "Bound with no effect" begin
        # Uses the example distribution from issue #709, though what's tested here is
        # mostly unrelated to that issue (aside from `mean` not erroring).
        # The specified left truncation at 0 has no effect for `LogNormal`
        d1 = truncated(LogNormal(1, 5), 0, 1e5)
        @test isfinite(mean(d1)) && mean(d1) > 0
        # Without the `max(_, 0)` in `var`, cancellation between the two `mgf` terms
        # could yield a variance that is numerically negative, which could cause
        # downstream issues that assume a nonnegative variance
        @test var(d1) >= 0
        # Compare results with not specifying a lower bound at all
        d2 = truncated(LogNormal(1, 5); upper=1e5)
        @test mean(d1) == mean(d2)
        @test var(d1) == var(d2)

        # Truncated outside of support where taking a log would error
        d3 = truncated(LogNormal(); lower=-1)
        @test mean(d3) == mean(d3.untruncated)
    end
end
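The reference values in the first testset can be reproduced outside Julia as well. A rough standalone Python check (not part of the PR; `phi_cdf`, `std_truncnorm_mgf`, and `truncated_lognormal_mean` are hypothetical helper names): for LogNormal(0, 1) truncated to (e⁻², e²), the mean should equal the MGF at t = 1 of Normal(0, 1) truncated to (−2, 2).

```python
# Independent check of the BigFloat reference relationship used above:
# mean of truncated(LogNormal(0, 1), e^-2, e^2) == mgf(truncated(Normal(0, 1), -2, 2), 1)
import math

def phi_cdf(x):
    # Standard normal CDF
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def std_truncnorm_mgf(t):
    # MGF of the standard normal truncated to (-2, 2):
    # exp(t^2/2) * (Phi(2 - t) - Phi(-2 - t)) / (Phi(2) - Phi(-2))
    return (math.exp(t * t / 2.0)
            * (phi_cdf(2.0 - t) - phi_cdf(-2.0 - t)) / (phi_cdf(2.0) - phi_cdf(-2.0)))

def truncated_lognormal_mean(steps=200_000):
    # E[Y] by midpoint quadrature in x = log(y) over (-2, 2) with a Normal(0, 1) weight
    h = 4.0 / steps
    num = den = 0.0
    for i in range(steps):
        x = -2.0 + (i + 0.5) * h
        w = math.exp(-0.5 * x * x)
        num += w * math.exp(x)
        den += w
    return num / den

mean_mgf = std_truncnorm_mgf(1.0)
mean_quad = truncated_lognormal_mean()
print(mean_mgf, mean_quad)  # both ~1.451
```

The two agree to quadrature precision, matching the `mean(d) ≈ bigmean` assertion (the testset computes the same reference in BigFloat instead).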
34 changes: 34 additions & 0 deletions test/truncated/normal.jl
@@ -69,3 +69,37 @@
        @test isfinite(pdf(trunc, x))
    end
end

@testset "Truncated normal MGF" begin
    two = big(2)
    sqrt2 = sqrt(two)
    invsqrt2 = inv(sqrt2)
    inv2sqrt2 = inv(two * sqrt2)
    twoerfsqrt2 = two * erf(sqrt2)

    for T in (Float32, Float64, BigFloat)
        d = truncated(Normal{T}(zero(T), one(T)), -2, 2)
        @test @inferred(mgf(d, 0)) == 1
        @test @inferred(mgf(d, 1)) ≈ @bigly sqrt(ℯ) * (erf(invsqrt2) + erf(3 * invsqrt2)) / twoerfsqrt2
        @test @inferred(mgf(d, 2.5)) ≈ @bigly exp(25//8) * (erf(9 * inv2sqrt2) - erf(inv2sqrt2)) / twoerfsqrt2
    end

    d = truncated(Normal(3, 10), 7, 8)
    @test mgf(d, 0) == 1
    # mgf(d, 1) is E[e^X] for X confined to (7, 8), so it must lie strictly between e^7 and e^8
    @test exp(7) < mgf(d, 1) < exp(8)

    d = truncated(Normal(27, 3); lower=0)
    @test mgf(d, 0) == 1
    @test mgf(d, 1) ≈ @bigly 2 * exp(63//2) / (1 + erf(9 * invsqrt2))
    @test mgf(d, 2.5) ≈ @bigly 2 * exp(765//8) / (1 + erf(9 * invsqrt2))

    d = truncated(Normal(-5, 1); upper=-10)
    @test mgf(d, 0) == 1
    @test mgf(d, 1) ≈ @bigly erfc(3 * sqrt2) / (exp(9//2) * erfc(5 * invsqrt2))

    @test isnan(mgf(truncated(Normal(); upper=NaN), 0))

    @test mgf(truncated(Normal(), -Inf, Inf), 1) == mgf(Normal(), 1)

    @test mgf(truncated(Normal(), 2, 2), 1) == exp(2)
end
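The exact symbolic reference values in the loop above can also be checked arithmetically. A quick Python sketch (not part of the PR; `phi_cdf` is a helper name I introduce, with `invsqrt2` = 1/√2 and `inv2sqrt2` = 1/(2√2) as in the testset) for the t = 2.5 case:

```python
# Independent check that the symbolic reference value
#   exp(25/8) * (erf(9/(2*sqrt(2))) - erf(1/(2*sqrt(2)))) / (2 * erf(sqrt(2)))
# equals the closed-form MGF of Normal(0, 1) truncated to (-2, 2) at t = 2.5:
#   exp(t^2/2) * (Phi(2 - t) - Phi(-2 - t)) / (Phi(2) - Phi(-2))
import math

def phi_cdf(x):
    # Standard normal CDF
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

symbolic = (math.exp(25 / 8)
            * (math.erf(9 / (2 * math.sqrt(2))) - math.erf(1 / (2 * math.sqrt(2))))
            / (2 * math.erf(math.sqrt(2))))

t = 2.5  # note exp(t^2/2) = exp(25/8)
closed = (math.exp(t * t / 2)
          * (phi_cdf(2 - t) - phi_cdf(-2 - t)) / (phi_cdf(2.0) - phi_cdf(-2.0)))
print(symbolic, closed)
```

The two expressions are identical up to floating-point rounding, since Φ(x) − Φ(y) = (erf(x/√2) − erf(y/√2)) / 2.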