Skip to content

Commit 706a081

Browse files
committed
Make corr_cholesky_factor deal with large inputs correctly.
1 parent 1771aab commit 706a081

File tree

4 files changed

+63
-26
lines changed

4 files changed

+63
-26
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TransformVariables"
22
uuid = "84d833dd-6860-57f9-a1a7-6da5db126cff"
33
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
4-
version = "0.8.9"
4+
version = "0.8.10"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/special_arrays.jl

+37-20
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,53 @@ export UnitVector, UnitSimplex, CorrCholeskyFactor, corr_cholesky_factor
77
"""
88
$(SIGNATURES)
99
10-
`log(abs(…))` of the derivative of `tanh`, calculated accurately.
10+
Return a `NamedTuple` of
11+
12+
- `log_l2_rem`, for `log(1 - tanh(x)^2)`,
13+
14+
- `logjac`, for `log(abs(∂tanh(x) / ∂x))`
15+
16+
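Any real `x` is accepted (both quantities are even functions of `x`). `x == 0` is handled correctly and gives zero for both; for large `|x|` the results behave like `log(4) - 2|x|` and remain finite, assuming `log1pexp` is numerically stable.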
1117
"""
12-
function _tanh_logabsderiv(x)
18+
function tanh_helpers(x)
1319
d = 2*x
14-
log(4) + d - 2 * log1pexp(d)
20+
log_denom = log1pexp(d) # log(exp(2x) + 1)
21+
logjac = log(4) + d - 2 * log_denom # log(abs(∂tanh(x)/∂x)) = log(4exp(2x) / (exp(2x) + 1)^2)
22+
log_l2_rem = 2*(log(2) + x - log_denom) # 2 * log(2exp(x) / (exp(2x) + 1)) = log(sech(x)^2)
23+
(; logjac, log_l2_rem)
1524
end
1625

1726
"""
18-
(y, r, ℓ) = $SIGNATURES
27+
(y, log_r, ℓ) = $SIGNATURES
1928
20-
Given ``x ∈ ℝ`` and ``0 ≤ r ≤ 1``, return `(y, r′)` such that
29+
Given ``x ∈ ℝ`` and ``0 ≤ r ≤ 1``, we define `(y, r′)` such that
2130
2231
1. ``y² + r′ = r`` (that is, `r` and `r′` are the *squared* ℓ₂ remainders),
2332
24-
2. ``y: |y| ≤ r`` is mapped with a bijection from `x`.
33+
2. ``y: |y| ≤ √r`` is mapped with a bijection from `x`, with the sign depending on `x`,
34+
35+
but use `log(r)` for actual calculations so that large `x`s still give nonsingular results.
2536
2637
`ℓ` is the log Jacobian (whether it is evaluated depends on `flag`).
2738
"""
28-
@inline function l2_remainder_transform(flag::LogJacFlag, x, r)
39+
@inline function l2_remainder_transform(flag::LogJacFlag, x, log_r)
40+
(; logjac, log_l2_rem) = tanh_helpers(x)
2941
# note that 1-tanh(x)^2 = sech(x)^2
30-
(tanh(x) * r, r*sech(x)^2,
31-
flag isa NoLogJac ? flag : _tanh_logabsderiv(x) + 0.5*log(r))
42+
(tanh(x) * exp(log_r / 2),
43+
log_r + log_l2_rem,
44+
flag isa NoLogJac ? flag : logjac + 0.5*log_r)
3245
end
3346

3447
"""
3548
(x, log_r′) = $SIGNATURES
3649
3750
Inverse of [`l2_remainder_transform`](@ref) in `x` and `y`.
3851
"""
39-
@inline l2_remainder_inverse(y, r) = atanh(y/√r), r-y^2
52+
@inline function l2_remainder_inverse(y, log_r)
53+
x = atanh(y / exp(log_r / 2))
54+
log_r′ = logsubexp(log_r, 2 * log(abs(y)))
55+
x, log_r′
56+
end
4057

4158
####
4259
#### UnitVector
@@ -65,16 +82,16 @@ end
6582
function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index)
6683
@unpack n = t
6784
T = robust_eltype(x)
68-
r = one(T)
85+
log_r = zero(T)
6986
y = Vector{T}(undef, n)
7087
ℓ = logjac_zero(flag, T)
7188
@inbounds for i in 1:(n - 1)
7289
xi = x[index]
7390
index += 1
74-
y[i], r, ℓi = l2_remainder_transform(flag, xi, r)
91+
y[i], log_r, ℓi = l2_remainder_transform(flag, xi, log_r)
7592
ℓ += ℓi
7693
end
77-
y[end] = r
94+
y[end] = exp(log_r / 2)
7895
y, ℓ, index
7996
end
8097

@@ -83,9 +100,9 @@ inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y)
83100
function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
84101
@unpack n = t
85102
@argcheck length(y) == n
86-
r = one(eltype(y))
103+
log_r = zero(eltype(y))
87104
@inbounds for yi in axes(y, 1)[1:(end-1)]
88-
x[index], r = l2_remainder_inverse(y[yi], r)
105+
x[index], log_r = l2_remainder_inverse(y[yi], log_r)
89106
index += 1
90107
end
91108
index
@@ -244,14 +261,14 @@ function calculate_corr_cholesky_factor!(U::AbstractMatrix{T}, flag::LogJacFlag,
244261
n = size(U, 1)
245262
ℓ = logjac_zero(flag, T)
246263
@inbounds for col_index in 1:n
247-
r = one(T)
264+
log_r = zero(T)
248265
for row_index in 1:(col_index-1)
249266
xi = x[index]
250-
U[row_index, col_index], r, ℓi = l2_remainder_transform(flag, xi, r)
267+
U[row_index, col_index], log_r, ℓi = l2_remainder_transform(flag, xi, log_r)
251268
ℓ += ℓi
252269
index += 1
253270
end
254-
U[col_index, col_index] = r
271+
U[col_index, col_index] = exp(log_r / 2)
255272
end
256273
U, ℓ, index
257274
end
@@ -285,9 +302,9 @@ function inverse_at!(x::AbstractVector, index,
285302
n = result_size(t)
286303
@argcheck size(U, 1) == n
287304
@inbounds for col in 1:n
288-
r = one(eltype(U))
305+
log_r = zero(eltype(U))
289306
for row in 1:(col-1)
290-
x[index], r = l2_remainder_inverse(U[row, col], r)
307+
x[index], log_r = l2_remainder_inverse(U[row, col], log_r)
291308
index += 1
292309
end
293310
end

test/runtests.jl

+22-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using LogDensityProblems: logdensity, logdensity_and_gradient
66
using LogDensityProblemsAD
77
using TransformVariables:
88
AbstractTransform, ScalarTransform, VectorTransform, ArrayTransformation,
9-
unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, NOLOGJAC, transform_with
9+
unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac,
10+
NOLOGJAC, transform_with
1011
import ChangesOfVariables, InverseFunctions
1112
using Enzyme: autodiff, ReverseWithPrimal, Active, Const
1213

@@ -136,9 +137,18 @@ end
136137
end
137138
end
138139

140+
@testset "tanh helpers" begin
141+
for _ in 1:10000
142+
x = (rand() - 0.5) * 100
143+
@unpack log_l2_rem, logjac = TransformVariables.tanh_helpers(x)
144+
@test Float64(AD_logjac(tanh, BigFloat(x))) ≈ logjac atol = 1e-4
145+
@test Float64(log(sech(BigFloat(x))^2)) ≈ log_l2_rem atol = 1e-4
146+
end
147+
end
148+
139149
@testset "to correlation cholesky factor" begin
140150
@testset "dimension checks" begin
141-
C = CorrCholeskyFactor(3)
151+
C = corr_cholesky_factor(3)
142152
wrong_x = zeros(dimension(C) + 1)
143153

144154
@test_throws ArgumentError transform(C, wrong_x)
@@ -147,7 +157,7 @@ end
147157

148158
@testset "consistency checks" begin
149159
for K in 1:8
150-
t = CorrCholeskyFactor(K)
160+
t = corr_cholesky_factor(K)
151161
@test dimension(t) == (K - 1)*K/2
152162
CIENV && @info "testing correlation cholesky K = $(K)"
153163
if K > 1
@@ -615,6 +625,15 @@ end
615625
end
616626
end
617627

628+
@testset "corr cholesky factor large inputs" begin
629+
t = corr_cholesky_factor(7)
630+
d = dimension(t)
631+
for _ in 1:100
632+
x = sign.(rand(d) .- 0.5) .* 100
633+
@test isfinite(logdet(transform(t, x)) )
634+
end
635+
end
636+
618637
@testset "pretty printing" begin
619638
t = as((a = asℝ₊,
620639
b = as(Array, asℝ₋, 3, 3),

test/utilities.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ $(SIGNATURES)
33
44
Log jacobian abs determinant via automatic differentiation. For testing.
55
"""
6+
AD_logjac(f, x) = log(abs(ForwardDiff.derivative(f, x)))
7+
68
AD_logjac(t::VectorTransform, x, vec_y) =
79
logabsdet(ForwardDiff.jacobian(x -> vec_y(transform(t, x)), x))[1]
810

9-
AD_logjac(t::ScalarTransform, x) =
10-
log(abs(ForwardDiff.derivative(x -> transform(t, x), x)))
11+
AD_logjac(t::ScalarTransform, x) = AD_logjac(x -> transform(t, x), x)
1112

1213
"""
1314
$(SIGNATURES)

0 commit comments

Comments
 (0)