Skip to content

Commit 0b0cd8c

Browse files
EveyKristofferC
authored andcommitted
Fix type instability in matrix log and add missing exp(::Matrix{Complex{<:Integer}}) (#23707)
* Matrices: exp and log minor changes Add `exp` for matrices of complex integers and fix type instability in `log` by wrapping sym/herm in `full`. * Add tests * Avoid multiple testset loops
1 parent ed2de3c commit 0b0cd8c

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

base/linalg/dense.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -462,12 +462,17 @@ julia> exp(A)
462462
```
463463
"""
464464
exp(A::StridedMatrix{<:BlasFloat}) = exp!(copy(A))
465-
exp(A::StridedMatrix{<:Integer}) = exp!(float(A))
465+
exp(A::StridedMatrix{<:Union{Integer,Complex{<:Integer}}}) = exp!(float.(A))
466466

467467
## Destructive matrix exponential using algorithm from Higham, 2008,
468468
## "Functions of Matrices: Theory and Computation", SIAM
469469
function exp!(A::StridedMatrix{T}) where T<:BlasFloat
470470
n = checksquare(A)
471+
if T <: Real
472+
if issymmetric(A)
473+
return full(exp(Symmetric(A)))
474+
end
475+
end
471476
if ishermitian(A)
472477
return full(exp(Hermitian(A)))
473478
end
@@ -592,11 +597,13 @@ julia> log(A)
592597
"""
593598
function log(A::StridedMatrix{T}) where T
594599
# If possible, use diagonalization
595-
if issymmetric(A) && T <: Real
596-
return log(Symmetric(A))
600+
if T <: Real
601+
if issymmetric(A)
602+
return full(log(Symmetric(A)))
603+
end
597604
end
598605
if ishermitian(A)
599-
return log(Hermitian(A))
606+
return full(log(Hermitian(A)))
600607
end
601608

602609
# Use Schur decomposition

test/linalg/dense.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,14 @@ end
423423
@test exp(log(A7)) A7
424424
end
425425

426+
@testset "Integer promotion tests" begin
427+
for (elty1, elty2) in ((Int64, Float64), (Complex{Int64}, Complex{Float64}))
428+
A4int = convert(Matrix{elty1}, [1 2; 3 4])
429+
A4float = convert(Matrix{elty2}, A4int)
430+
@test exp(A4int) == exp(A4float)
431+
end
432+
end
433+
426434
A8 = 100 * [-1+1im 0 0 1e-8; 0 1 0 0; 0 0 1 0; 0 0 0 1]
427435
@test exp(log(A8)) A8
428436
end
@@ -450,6 +458,12 @@ end
450458
A12 = convert(Matrix{elty}, [1 -1; 1 -1])
451459
@test typeof(log(A12)) == Array{Complex{Float64}, 2}
452460

461+
A13 = convert(Matrix{elty}, [2 0; 0 2])
462+
@test typeof(log(A13)) == Array{elty, 2}
463+
464+
T = elty == Float64 ? Symmetric : Hermitian
465+
@test typeof(log(T(A13))) == T{elty, Array{elty, 2}}
466+
453467
A1 = convert(Matrix{elty}, [4 2 0; 1 4 1; 1 1 4])
454468
logA1 = convert(Matrix{elty}, [1.329661349 0.5302876358 -0.06818951543;
455469
0.2310490602 1.295566591 0.2651438179;

0 commit comments

Comments
 (0)