Skip to content

Commit b8093de

Browse files
authored
Broadcast binary ops involving strided triangular (#55798)
Currently, we evaluate expressions like `(A::UpperTriangular) + (B::UpperTriangular)` using broadcasting if both `A` and `B` have strided parents, and forward the summation to the parents otherwise. This PR changes this to use broadcasting if either of the two has a strided parent. This avoids accessing the parent corresponding to the structural zero elements, as the index might not be initialized. Fixes https://github.com/JuliaLang/julia/issues/55590 This isn't a general fix, as we still sum the parents if neither is strided. However, it will address common cases. This also improves performance, as we only need to loop over one half: ```julia julia> using LinearAlgebra julia> U = UpperTriangular(zeros(100,100)); julia> B = Bidiagonal(zeros(100), zeros(99), :U); julia> @Btime $U + $B; 35.530 μs (4 allocations: 78.22 KiB) # nightly 13.441 μs (4 allocations: 78.22 KiB) # This PR ```
1 parent a73ba3b commit b8093de

File tree

3 files changed

+94
-30
lines changed

3 files changed

+94
-30
lines changed

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -687,10 +687,10 @@ for f in (:+, :-)
687687
@eval begin
688688
$f(A::Hermitian, B::Symmetric{<:Real}) = $f(A, Hermitian(parent(B), sym_uplo(B.uplo)))
689689
$f(A::Symmetric{<:Real}, B::Hermitian) = $f(Hermitian(parent(A), sym_uplo(A.uplo)), B)
690-
$f(A::SymTridiagonal, B::Symmetric) = Symmetric($f(A, B.data), sym_uplo(B.uplo))
691-
$f(A::Symmetric, B::SymTridiagonal) = Symmetric($f(A.data, B), sym_uplo(A.uplo))
692-
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = Hermitian($f(A, B.data), sym_uplo(B.uplo))
693-
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
690+
$f(A::SymTridiagonal, B::Symmetric) = $f(Symmetric(A, sym_uplo(B.uplo)), B)
691+
$f(A::Symmetric, B::SymTridiagonal) = $f(A, Symmetric(B, sym_uplo(A.uplo)))
692+
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = $f(Hermitian(A, sym_uplo(B.uplo)), B)
693+
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = $f(A, Hermitian(B, sym_uplo(A.uplo)))
694694
end
695695
end
696696

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -850,35 +850,74 @@ fillstored!(A::UpperTriangular, x) = (fillband!(A.data, x, 0, size(A,2)-1);
850850
fillstored!(A::UnitUpperTriangular, x) = (fillband!(A.data, x, 1, size(A,2)-1); A)
851851

852852
# Binary operations
853-
+(A::UpperTriangular, B::UpperTriangular) = UpperTriangular(A.data + B.data)
854-
+(A::LowerTriangular, B::LowerTriangular) = LowerTriangular(A.data + B.data)
855-
+(A::UpperTriangular, B::UnitUpperTriangular) = UpperTriangular(A.data + triu(B.data, 1) + I)
856-
+(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(A.data + tril(B.data, -1) + I)
857-
+(A::UnitUpperTriangular, B::UpperTriangular) = UpperTriangular(triu(A.data, 1) + B.data + I)
858-
+(A::UnitLowerTriangular, B::LowerTriangular) = LowerTriangular(tril(A.data, -1) + B.data + I)
859-
+(A::UnitUpperTriangular, B::UnitUpperTriangular) = UpperTriangular(triu(A.data, 1) + triu(B.data, 1) + 2I)
860-
+(A::UnitLowerTriangular, B::UnitLowerTriangular) = LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
853+
# use broadcasting if the parents are strided, where we loop only over the triangular part
854+
function +(A::UpperTriangular, B::UpperTriangular)
855+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
856+
UpperTriangular(A.data + B.data)
857+
end
858+
function +(A::LowerTriangular, B::LowerTriangular)
859+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
860+
LowerTriangular(A.data + B.data)
861+
end
862+
function +(A::UpperTriangular, B::UnitUpperTriangular)
863+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
864+
UpperTriangular(A.data + triu(B.data, 1) + I)
865+
end
866+
function +(A::LowerTriangular, B::UnitLowerTriangular)
867+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
868+
LowerTriangular(A.data + tril(B.data, -1) + I)
869+
end
870+
function +(A::UnitUpperTriangular, B::UpperTriangular)
871+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
872+
UpperTriangular(triu(A.data, 1) + B.data + I)
873+
end
874+
function +(A::UnitLowerTriangular, B::LowerTriangular)
875+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
876+
LowerTriangular(tril(A.data, -1) + B.data + I)
877+
end
878+
function +(A::UnitUpperTriangular, B::UnitUpperTriangular)
879+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
880+
UpperTriangular(triu(A.data, 1) + triu(B.data, 1) + 2I)
881+
end
882+
function +(A::UnitLowerTriangular, B::UnitLowerTriangular)
883+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
884+
LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
885+
end
861886
+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) + copyto!(similar(parent(B)), B)
862887

863-
-(A::UpperTriangular, B::UpperTriangular) = UpperTriangular(A.data - B.data)
864-
-(A::LowerTriangular, B::LowerTriangular) = LowerTriangular(A.data - B.data)
865-
-(A::UpperTriangular, B::UnitUpperTriangular) = UpperTriangular(A.data - triu(B.data, 1) - I)
866-
-(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(A.data - tril(B.data, -1) - I)
867-
-(A::UnitUpperTriangular, B::UpperTriangular) = UpperTriangular(triu(A.data, 1) - B.data + I)
868-
-(A::UnitLowerTriangular, B::LowerTriangular) = LowerTriangular(tril(A.data, -1) - B.data + I)
869-
-(A::UnitUpperTriangular, B::UnitUpperTriangular) = UpperTriangular(triu(A.data, 1) - triu(B.data, 1))
870-
-(A::UnitLowerTriangular, B::UnitLowerTriangular) = LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
871-
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)
872-
873-
# use broadcasting if the parents are strided, where we loop only over the triangular part
874-
for op in (:+, :-)
875-
for TM1 in (:LowerTriangular, :UnitLowerTriangular), TM2 in (:LowerTriangular, :UnitLowerTriangular)
876-
@eval $op(A::$TM1{<:Any, <:StridedMaybeAdjOrTransMat}, B::$TM2{<:Any, <:StridedMaybeAdjOrTransMat}) = broadcast($op, A, B)
877-
end
878-
for TM1 in (:UpperTriangular, :UnitUpperTriangular), TM2 in (:UpperTriangular, :UnitUpperTriangular)
879-
@eval $op(A::$TM1{<:Any, <:StridedMaybeAdjOrTransMat}, B::$TM2{<:Any, <:StridedMaybeAdjOrTransMat}) = broadcast($op, A, B)
880-
end
888+
function -(A::UpperTriangular, B::UpperTriangular)
889+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
890+
UpperTriangular(A.data - B.data)
891+
end
892+
function -(A::LowerTriangular, B::LowerTriangular)
893+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
894+
LowerTriangular(A.data - B.data)
895+
end
896+
function -(A::UpperTriangular, B::UnitUpperTriangular)
897+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
898+
UpperTriangular(A.data - triu(B.data, 1) - I)
899+
end
900+
function -(A::LowerTriangular, B::UnitLowerTriangular)
901+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
902+
LowerTriangular(A.data - tril(B.data, -1) - I)
881903
end
904+
function -(A::UnitUpperTriangular, B::UpperTriangular)
905+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
906+
UpperTriangular(triu(A.data, 1) - B.data + I)
907+
end
908+
function -(A::UnitLowerTriangular, B::LowerTriangular)
909+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
910+
LowerTriangular(tril(A.data, -1) - B.data + I)
911+
end
912+
function -(A::UnitUpperTriangular, B::UnitUpperTriangular)
913+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
914+
UpperTriangular(triu(A.data, 1) - triu(B.data, 1))
915+
end
916+
function -(A::UnitLowerTriangular, B::UnitLowerTriangular)
917+
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
918+
LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
919+
end
920+
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)
882921

883922
function kron(A::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat})
884923
C = UpperTriangular(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B)))

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,4 +1135,29 @@ end
11351135
end
11361136
end
11371137

1138+
@testset "partly iniitalized matrices" begin
1139+
a = Matrix{BigFloat}(undef, 2,2)
1140+
a[1] = 1; a[3] = 1; a[4] = 1
1141+
h = Hermitian(a)
1142+
s = Symmetric(a)
1143+
d = Diagonal([1,1])
1144+
symT = SymTridiagonal([1 1;1 1])
1145+
@test h+d == Array(h) + Array(d)
1146+
@test h+symT == Array(h) + Array(symT)
1147+
@test s+d == Array(s) + Array(d)
1148+
@test s+symT == Array(s) + Array(symT)
1149+
@test h-d == Array(h) - Array(d)
1150+
@test h-symT == Array(h) - Array(symT)
1151+
@test s-d == Array(s) - Array(d)
1152+
@test s-symT == Array(s) - Array(symT)
1153+
@test d+h == Array(d) + Array(h)
1154+
@test symT+h == Array(symT) + Array(h)
1155+
@test d+s == Array(d) + Array(s)
1156+
@test symT+s == Array(symT) + Array(s)
1157+
@test d-h == Array(d) - Array(h)
1158+
@test symT-h == Array(symT) - Array(h)
1159+
@test d-s == Array(d) - Array(s)
1160+
@test symT-s == Array(symT) - Array(s)
1161+
end
1162+
11381163
end # module TestSymmetric

0 commit comments

Comments
 (0)