Skip to content

Commit 8a4dd86

Browse files
jishnubKristofferC
authored andcommitted
Fix tr for block SymTridiagonal (#55371)
This ensures that `tr` for a block `SymTridiagonal` symmetrizes the diagonal elements. (cherry picked from commit a163483)
1 parent cb7a962 commit 8a4dd86

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(ad
174174
ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S))
175175
issymmetric(S::SymTridiagonal) = true
176176

177-
tr(S::SymTridiagonal) = sum(S.dv)
177+
tr(S::SymTridiagonal) = sum(symmetric, S.dv)
178178

179179
@noinline function throw_diag_outofboundserror(n, sz)
180180
sz1, sz2 = sz

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ end
468468
end
469469

470470
@testset "SymTridiagonal/Tridiagonal block matrix" begin
471-
M = [1 2; 2 4]
471+
M = [1 2; 3 4]
472472
n = 5
473473
A = SymTridiagonal(fill(M, n), fill(M, n-1))
474474
@test @inferred A[1,1] == Symmetric(M)
@@ -482,6 +482,9 @@ end
482482
@test_throws ArgumentError diag(A, n+1)
483483
@test_throws ArgumentError diag(A, -n-1)
484484

485+
@test tr(A) == sum(diag(A))
486+
@test issymmetric(tr(A))
487+
485488
A = Tridiagonal(fill(M, n-1), fill(M, n), fill(M, n-1))
486489
@test @inferred A[1,1] == M
487490
@test @inferred A[1,2] == M

0 commit comments

Comments
 (0)