Skip to content

Commit a163483

Browse files
authored
Fix tr for block SymTridiagonal (#55371)
This ensures that `tr` for a block `SymTridiagonal` symmetrizes the diagonal elements.
1 parent f4d1381 commit a163483

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(ad
181181
ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S))
182182
issymmetric(S::SymTridiagonal) = true
183183

184-
tr(S::SymTridiagonal) = sum(S.dv)
184+
tr(S::SymTridiagonal) = sum(symmetric, S.dv)
185185

186186
@noinline function throw_diag_outofboundserror(n, sz)
187187
sz1, sz2 = sz

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ end
471471
end
472472

473473
@testset "SymTridiagonal/Tridiagonal block matrix" begin
474-
M = [1 2; 2 4]
474+
M = [1 2; 3 4]
475475
n = 5
476476
A = SymTridiagonal(fill(M, n), fill(M, n-1))
477477
@test @inferred A[1,1] == Symmetric(M)
@@ -485,6 +485,9 @@ end
485485
@test_throws ArgumentError diag(A, n+1)
486486
@test_throws ArgumentError diag(A, -n-1)
487487

488+
@test tr(A) == sum(diag(A))
489+
@test issymmetric(tr(A))
490+
488491
A = Tridiagonal(fill(M, n-1), fill(M, n), fill(M, n-1))
489492
@test @inferred A[1,1] == M
490493
@test @inferred A[1,2] == M

0 commit comments

Comments
 (0)