Skip to content

Commit 318affa

Browse files
stevengjandreasnoack
authored andcommitted
bug fixes in matrix log (#32327)
* bug fixes in matrix log * patches to matrix log (#33245) * patches to matrix log Avoid integer overflow if `s > 63`. Correct logic for `s == 0`. Only use fancy divided difference formulae if eigenvalues are close - avoids dangerous roundoff error if they are in opposite sectors. * add tests
1 parent b13c64a commit 318affa

File tree

2 files changed

+36
-29
lines changed

2 files changed

+36
-29
lines changed

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,32 +2294,14 @@ function log(A0::UpperTriangular{T}) where T<:BlasFloat
22942294
end
22952295

22962296
# Compute accurate superdiagonal of T
2297-
p = 1 / 2^s
2298-
for k = 1:n-1
2299-
Ak = A0[k,k]
2300-
Akp1 = A0[k+1,k+1]
2301-
Akp = Ak^p
2302-
Akp1p = Akp1^p
2303-
A[k,k] = Akp
2304-
A[k+1,k+1] = Akp1p
2305-
if Ak == Akp1
2306-
A[k,k+1] = p * A0[k,k+1] * Ak^(p-1)
2307-
elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak)
2308-
A[k,k+1] = A0[k,k+1] * (Akp1p - Akp) / (Akp1 - Ak)
2309-
else
2310-
logAk = log(Ak)
2311-
logAkp1 = log(Akp1)
2312-
w = atanh((Akp1 - Ak)/(Akp1 + Ak)) + im*pi*ceil((imag(logAkp1-logAk)-pi)/(2*pi))
2313-
dd = 2 * exp(p*(logAk+logAkp1)/2) * sinh(p*w) / (Akp1 - Ak)
2314-
A[k,k+1] = A0[k,k+1] * dd
2315-
end
2316-
end
2297+
blockpower!(A, A0, 0.5^s)
23172298

23182299
# Compute accurate diagonal of T
23192300
for i = 1:n
23202301
a = A0[i,i]
23212302
if s == 0
2322-
r = a - 1
2303+
A[i,i] = a - 1
2304+
continue
23232305
end
23242306
s0 = s
23252307
if angle(a) >= pi / 2
@@ -2356,7 +2338,7 @@ function log(A0::UpperTriangular{T}) where T<:BlasFloat
23562338
end
23572339

23582340
# Scale back
2359-
lmul!(2^s, Y)
2341+
lmul!(2.0^s, Y)
23602342

23612343
# Compute accurate diagonal and superdiagonal of log(T)
23622344
for k = 1:n-1
@@ -2368,11 +2350,16 @@ function log(A0::UpperTriangular{T}) where T<:BlasFloat
23682350
Y[k+1,k+1] = logAkp1
23692351
if Ak == Akp1
23702352
Y[k,k+1] = A0[k,k+1] / Ak
2371-
elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak)
2353+
elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak) || iszero(Akp1 + Ak)
23722354
Y[k,k+1] = A0[k,k+1] * (logAkp1 - logAk) / (Akp1 - Ak)
23732355
else
2374-
w = atanh((Akp1 - Ak)/(Akp1 + Ak) + im*pi*(ceil((imag(logAkp1-logAk) - pi)/(2*pi))))
2375-
Y[k,k+1] = 2 * A0[k,k+1] * w / (Akp1 - Ak)
2356+
z = (Akp1 - Ak)/(Akp1 + Ak)
2357+
if abs(z) > 1
2358+
Y[k,k+1] = A0[k,k+1] * (logAkp1 - logAk) / (Akp1 - Ak)
2359+
else
2360+
w = atanh(z) + im * pi * (unw(logAkp1-logAk) - unw(log1p(z)-log1p(-z)))
2361+
Y[k,k+1] = 2 * A0[k,k+1] * w / (Akp1 - Ak)
2362+
end
23762363
end
23772364
end
23782365

@@ -2519,14 +2506,19 @@ function blockpower!(A::UpperTriangular, A0::UpperTriangular, p)
25192506

25202507
if Ak == Akp1
25212508
A[k,k+1] = p * A0[k,k+1] * Ak^(p-1)
2522-
elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak)
2509+
elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak) || iszero(Akp1 + Ak)
25232510
A[k,k+1] = A0[k,k+1] * (Akp1p - Akp) / (Akp1 - Ak)
25242511
else
25252512
logAk = log(Ak)
25262513
logAkp1 = log(Akp1)
2527-
w = atanh((Akp1 - Ak)/(Akp1 + Ak)) + im * pi * unw(logAkp1-logAk)
2528-
dd = 2 * exp(p*(logAk+logAkp1)/2) * sinh(p*w) / (Akp1 - Ak);
2529-
A[k,k+1] = A0[k,k+1] * dd
2514+
z = (Akp1 - Ak)/(Akp1 + Ak)
2515+
if abs(z) > 1
2516+
A[k,k+1] = A0[k,k+1] * (Akp1p - Akp) / (Akp1 - Ak)
2517+
else
2518+
w = atanh(z) + im * pi * (unw(logAkp1-logAk) - unw(log1p(z)-log1p(-z)))
2519+
dd = 2 * exp(p*(logAk+logAkp1)/2) * sinh(p*w) / (Akp1 - Ak);
2520+
A[k,k+1] = A0[k,k+1] * dd
2521+
end
25302522
end
25312523
end
25322524
end

stdlib/LinearAlgebra/test/dense.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,4 +903,19 @@ end
903903
@test adjoint(factorize(adjoint(a))) == factorize(a)
904904
end
905905

906+
@testset "Matrix log issue #32313" begin
907+
for A in ([30 20; -50 -30], [10.0im 0; 0 -10.0im], randn(6,6))
908+
@test exp(log(A)) A
909+
end
910+
end
911+
912+
@testset "Matrix log PR #33245" begin
913+
# edge case for divided difference
914+
A1 = triu(ones(3,3),1) + diagm([1.0, -2eps()-1im, -eps()+0.75im])
915+
@test exp(log(A1)) A1
916+
# case where no sqrt is needed (s=0)
917+
A2 = [1.01 0.01 0.01; 0 1.01 0.01; 0 0 1.01]
918+
@test exp(log(A2)) A2
919+
end
920+
906921
end # module TestDense

0 commit comments

Comments
 (0)