@@ -7,36 +7,53 @@ export UnitVector, UnitSimplex, CorrCholeskyFactor, corr_cholesky_factor
7
7
"""
8
8
$(SIGNATURES)
9
9
10
- `log(abs(…))` of the derivative of `tanh`, calculated accurately.
10
+ Return a `NamedTuple` of
11
+
12
+ - `log_l2_rem`, for `log(1 - tanh(x)^2)`,
13
+
14
+ - `logjac`, for `log(abs(∂tanh(x) / ∂x))`
15
+
16
+ Accepts any real `x` (`l2_remainder_transform` passes its `x ∈ ℝ` through unchanged). At `x == 0` both values are `0`, up to rounding.
11
17
"""
12
- function _tanh_logabsderiv (x)
18
+ function tanh_helpers (x)
13
19
d = 2 * x
14
- log (4 ) + d - 2 * log1pexp (d)
20
+ log_denom = log1pexp (d) # log(exp(2x) + 1)
21
+ logjac = log (4 ) + d - 2 * log_denom # log(abs(∂tanh(x) / ∂x)) = log(4exp(2x) / (exp(2x) + 1)^2)
22
+ log_l2_rem = 2 * (log (2 ) + x - log_denom) # log((2exp(x) / (exp(2x) + 1))^2) = log(1 - tanh(x)^2)
23
+ (; logjac, log_l2_rem)
15
24
end
16
25
17
26
"""
18
- (y, r , ℓ) = $SIGNATURES
27
+ (y, log_r , ℓ) = $SIGNATURES
19
28
20
- Given ``x ∈ ℝ`` and ``0 ≤ r ≤ 1``, return `(y, r′)` such that
29
+ Given ``x ∈ ℝ`` and ``0 ≤ r ≤ 1``, we define `(y, r′)` such that
21
30
22
31
1. ``y² + (r′)² = r²``,
23
32
24
- 2. ``y: |y| ≤ r`` is mapped with a bijection from `x`.
33
+ 2. ``y: |y| ≤ r`` is mapped with a bijection from `x`, with the sign depending on `x`,
34
+
35
+ but use `log(r)` (and return `log(r′)`) for actual calculations so that large `|x|`s still give nonsingular results.
25
36
26
37
`ℓ` is the log Jacobian (whether it is evaluated depends on `flag`).
27
38
"""
28
- @inline function l2_remainder_transform (flag:: LogJacFlag , x, r)
39
+ @inline function l2_remainder_transform (flag:: LogJacFlag , x, log_r)
40
+ (; logjac, log_l2_rem) = tanh_helpers (x)
29
41
# note that 1-tanh(x)^2 = sech(x)^2
30
- (tanh (x) * √ r, r* sech (x)^ 2 ,
31
- flag isa NoLogJac ? flag : _tanh_logabsderiv (x) + 0.5 * log (r))
42
+ (tanh (x) * exp (log_r / 2 ),
43
+ log_r + log_l2_rem,
44
+ flag isa NoLogJac ? flag : logjac + 0.5 * log_r)
32
45
end
33
46
34
47
"""
35
48
(x, log_r′) = $SIGNATURES
36
49
37
50
Inverse of [`l2_remainder_transform`](@ref) in `x` and `y`.
38
51
"""
39
- @inline l2_remainder_inverse (y, r) = atanh (y/√ r), r- y^ 2
52
+ @inline function l2_remainder_inverse (y, log_r)
53
+ x = atanh (y / exp (log_r / 2 ))
54
+ log_r′ = logsubexp (log_r, 2 * log (abs (y)))
55
+ x, log_r′
56
+ end
40
57
41
58
# ###
42
59
# ### UnitVector
65
82
function transform_with (flag:: LogJacFlag , t:: UnitVector , x:: AbstractVector , index)
66
83
@unpack n = t
67
84
T = robust_eltype (x)
68
- r = one (T)
85
+ log_r = zero (T)
69
86
y = Vector {T} (undef, n)
70
87
ℓ = logjac_zero (flag, T)
71
88
@inbounds for i in 1 : (n - 1 )
72
89
xi = x[index]
73
90
index += 1
74
- y[i], r , ℓi = l2_remainder_transform (flag, xi, r )
91
+ y[i], log_r , ℓi = l2_remainder_transform (flag, xi, log_r )
75
92
ℓ += ℓi
76
93
end
77
- y[end ] = √ r
94
+ y[end ] = exp (log_r / 2 )
78
95
y, ℓ, index
79
96
end
80
97
@@ -83,9 +100,9 @@ inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y)
83
100
function inverse_at! (x:: AbstractVector , index, t:: UnitVector , y:: AbstractVector )
84
101
@unpack n = t
85
102
@argcheck length (y) == n
86
- r = one (eltype (y))
103
+ log_r = zero (eltype (y))
87
104
@inbounds for yi in axes (y, 1 )[1 : (end - 1 )]
88
- x[index], r = l2_remainder_inverse (y[yi], r )
105
+ x[index], log_r = l2_remainder_inverse (y[yi], log_r )
89
106
index += 1
90
107
end
91
108
index
@@ -244,14 +261,14 @@ function calculate_corr_cholesky_factor!(U::AbstractMatrix{T}, flag::LogJacFlag,
244
261
n = size (U, 1 )
245
262
ℓ = logjac_zero (flag, T)
246
263
@inbounds for col_index in 1 : n
247
- r = one (T)
264
+ log_r = zero (T)
248
265
for row_index in 1 : (col_index- 1 )
249
266
xi = x[index]
250
- U[row_index, col_index], r , ℓi = l2_remainder_transform (flag, xi, r )
267
+ U[row_index, col_index], log_r , ℓi = l2_remainder_transform (flag, xi, log_r )
251
268
ℓ += ℓi
252
269
index += 1
253
270
end
254
- U[col_index, col_index] = √ r
271
+ U[col_index, col_index] = exp (log_r / 2 )
255
272
end
256
273
U, ℓ, index
257
274
end
@@ -285,9 +302,9 @@ function inverse_at!(x::AbstractVector, index,
285
302
n = result_size (t)
286
303
@argcheck size (U, 1 ) == n
287
304
@inbounds for col in 1 : n
288
- r = one (eltype (U))
305
+ log_r = zero (eltype (U))
289
306
for row in 1 : (col- 1 )
290
- x[index], r = l2_remainder_inverse (U[row, col], r )
307
+ x[index], log_r = l2_remainder_inverse (U[row, col], log_r )
291
308
index += 1
292
309
end
293
310
end
0 commit comments