From 04259daf8a339f99d7cd8503d9c1f154b247e4e1 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 21 Oct 2024 11:31:36 +0530 Subject: [PATCH] Reroute` (Upper/Lower)Triangular * Diagonal` through `__muldiag` (#55984) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, `::Diagonal * ::AbstractMatrix` calls the method `LinearAlgebra.__muldiag!` in general that scales the rows, and similarly for the diagonal on the right. The implementation of `__muldiag` was duplicating the logic in `LinearAlgebra.modify!` and the methods for `MulAddMul`. This PR replaces the various branches with calls to `modify!` instead. I've also extracted the multiplication logic into its own function `__muldiag_nonzeroalpha!` so that this may be specialized for matrix types, such as triangular ones. Secondly, `::Diagonal * ::UpperTriangular` (and similarly, other triangular matrices) was specialized to forward the multiplication to the parent of the triangular. For strided matrices, however, it makes more sense to use the structure and scale only the filled half of the matrix. Firstly, this improves performance, and secondly, this avoids errors in case the parent isn't fully initialized corresponding to the structural zero elements. Performance improvement: ```julia julia> D = Diagonal(1:400); julia> U = UpperTriangular(zeros(size(D))); julia> @btime $D * $U; 314.944 μs (3 allocations: 1.22 MiB) # v"1.12.0-DEV.1288" 195.960 μs (3 allocations: 1.22 MiB) # This PR ``` Fix: ```julia julia> M = Matrix{BigFloat}(undef, 2, 2); julia> M[1,1] = M[2,2] = M[1,2] = 3; julia> U = UpperTriangular(M) 2×2 UpperTriangular{BigFloat, Matrix{BigFloat}}: 3.0 3.0 ⋅ 3.0 julia> D = Diagonal(1:2); julia> U * D # works after this PR 2×2 UpperTriangular{BigFloat, Matrix{BigFloat}}: 3.0 6.0 ⋅ 6.0 ``` --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 2 + stdlib/LinearAlgebra/src/diagonal.jl | 158 +++++++++++++--------- stdlib/LinearAlgebra/test/diagonal.jl | 40 +++++- 3 files changed, 134 insertions(+), 66 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 15354603943c2..88fc3476c9d7f 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -655,6 +655,8 @@ matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = _matprod_dest_diag(A, TS) matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = _matprod_dest_diag(B, TS) matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS) _matprod_dest_diag(A, TS) = similar(A, TS) +_matprod_dest_diag(A::UnitUpperTriangular, TS) = UpperTriangular(similar(parent(A), TS)) +_matprod_dest_diag(A::UnitLowerTriangular, TS) = LowerTriangular(similar(parent(A), TS)) function _matprod_dest_diag(A::SymTridiagonal, TS) n = size(A, 1) ev = similar(A, TS, max(0, n-1)) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 6e8ce96259fc1..8ba4c3d457e83 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -396,82 +396,120 @@ function lmul!(D::Diagonal, T::Tridiagonal) return T end -function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} +@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul) + @inbounds for j in axes(B, 2) + @simd for i in axes(B, 1) + _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) + end + end + out +end +_maybe_unwrap_tri(out, A) = out, A +_maybe_unwrap_tri(out::UpperTriangular, A::UpperOrUnitUpperTriangular) = parent(out), parent(A) +_maybe_unwrap_tri(out::LowerTriangular, A::LowerOrUnitLowerTriangular) = parent(out), parent(A) +@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul) + isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular} + # if both B and out have the same upper/lower triangular structure, + # we may directly read and write from the parents + out_maybeparent, B_maybeparent = _maybe_unwrap_tri(out, B) + for j in axes(B, 2) + if isunit + _modify!(_add, D.diag[j] * B[j,j], out, (j,j)) + end + rowrange = B isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(B,1))) : (j+isunit:size(B,1)) + @inbounds @simd for i in rowrange + _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) + end + end + out +end +function __muldiag!(out, D::Diagonal, B, _add::MulAddMul) require_one_based_indexing(out, B) alpha, beta = _add.alpha, _add.beta if iszero(alpha) _rmul_or_fill!(out, beta) else - if bis0 - @inbounds for j in axes(B, 2) - @simd for i in axes(B, 1) - out[i,j] = D.diag[i] * B[i,j] * alpha - end - end - else - @inbounds for j in axes(B, 2) - @simd for i in axes(B, 1) - out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta - end - end - end + __muldiag_nonzeroalpha!(out, D, B, _add) end return out end -function __muldiag!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} + +@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} + beta = _add.beta + _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) + @inbounds for j in axes(A, 2) + dja = _add(D.diag[j]) + @simd for i in axes(A, 1) + _modify!(_add_aisone, A[i,j] * dja, out, (i,j)) + end + end + out +end +@inline function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} + isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular} + beta = _add.beta + # since alpha is multiplied to the diagonal element of D, + # we may skip alpha in the second multiplication by setting ais1 to true + _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) + # if both A and out have the same upper/lower triangular structure, + # we may directly read and write from the parents + out_maybeparent, A_maybeparent = _maybe_unwrap_tri(out, A) + @inbounds for j in axes(A, 2) + dja = _add(D.diag[j]) + if isunit + _modify!(_add_aisone, A[j,j] * dja, out, (j,j)) + end + rowrange = A isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(A,1))) : (j+isunit:size(A,1)) + @simd for i in rowrange + _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j)) + end + end + out +end +function __muldiag!(out, A, D::Diagonal, _add::MulAddMul) require_one_based_indexing(out, A) alpha, beta = _add.alpha, _add.beta if iszero(alpha) _rmul_or_fill!(out, beta) else - if bis0 - @inbounds for j in axes(A, 2) - dja = D.diag[j] * alpha - @simd for i in axes(A, 1) - out[i,j] = A[i,j] * dja - end - end - else - @inbounds for j in axes(A, 2) - dja = D.diag[j] * alpha - @simd for i in axes(A, 1) - out[i,j] = A[i,j] * dja + out[i,j] * beta - end - end - end + __muldiag_nonzeroalpha!(out, A, D, _add) end return out end -function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} + +@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul) d1 = D1.diag d2 = D2.diag + outd = out.diag + @inbounds @simd for i in eachindex(d1, d2, outd) + _modify!(_add, d1[i] * d2[i], outd, i) + end + out +end +function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul) alpha, beta = _add.alpha, _add.beta if iszero(alpha) _rmul_or_fill!(out.diag, beta) else - if bis0 - @inbounds @simd for i in eachindex(out.diag) - out.diag[i] = d1[i] * d2[i] * alpha - end - else - @inbounds @simd for i in eachindex(out.diag) - out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta - end - end + __muldiag_nonzeroalpha!(out, D1, D2, _add) end return out end -function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} - require_one_based_indexing(out) - alpha, beta = _add.alpha, _add.beta - mA = size(D1, 1) +@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul) d1 = D1.diag d2 = D2.diag + @inbounds @simd for i in eachindex(d1, d2) + _modify!(_add, d1[i] * d2[i], out, (i,i)) + end + out +end +function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1} + require_one_based_indexing(out) + alpha, beta = _add.alpha, _add.beta _rmul_or_fill!(out, beta) if !iszero(alpha) - @inbounds @simd for i in 1:mA - out[i,i] += d1[i] * d2[i] * alpha - end + _add_bis1 = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha,true) + __muldiag_nonzeroalpha!(out, D1, D2, _add_bis1) end return out end @@ -658,31 +696,21 @@ for Tri in (:UpperTriangular, :LowerTriangular) @eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D)) @eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag)) end + @eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) = + @invoke *(A::AbstractMatrix, D::Diagonal) + @eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) = + @invoke *(A::AbstractMatrix, D::Diagonal) for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv)) @eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data)) @eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag)) end + @eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) = + @invoke *(D::Diagonal, A::AbstractMatrix) + @eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) = + @invoke *(D::Diagonal, A::AbstractMatrix) # 3-arg ldiv! @eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data)) @eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag)) - # 3-arg mul! is disambiguated in special.jl - # 5-arg mul! - @eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta)) - @eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} - α, β = _add.alpha, _add.beta - iszero(α) && return _rmul_or_fill!(C, β) - diag′ = bis0 ? nothing : diag(C) - data = mul!(C.data, D, A.data, α, β) - $Tri(_setdiag!(data, _add, D.diag, diag′)) - end - @eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta)) - @eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} - α, β = _add.alpha, _add.beta - iszero(α) && return _rmul_or_fill!(C, β) - diag′ = bis0 ? nothing : diag(C) - data = mul!(C.data, A.data, D, α, β) - $Tri(_setdiag!(data, _add, D.diag, diag′)) - end end @inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 1c3a9dfa676ac..380a0465028d1 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1188,7 +1188,7 @@ end @test oneunit(D3) isa typeof(D3) end -@testset "AbstractTriangular" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular)) +@testset "$Tri" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular)) A = randn(4, 4) TriA = Tri(A) UTriA = UTri(A) @@ -1218,6 +1218,44 @@ end @test outTri === mul!(outTri, D, UTriA, 2, 1)::Tri == mul!(out, D, Matrix(UTriA), 2, 1) @test outTri === mul!(outTri, TriA, D, 2, 1)::Tri == mul!(out, Matrix(TriA), D, 2, 1) @test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1) + + # we may write to a Unit triangular if the diagonal is preserved + ID = Diagonal(ones(size(UTriA,2))) + @test mul!(copy(UTriA), UTriA, ID) == UTriA + @test mul!(copy(UTriA), ID, UTriA) == UTriA + + @testset "partly filled parents" begin + M = Matrix{BigFloat}(undef, 2, 2) + M[1,1] = M[2,2] = 3 + isupper = Tri == UpperTriangular + M[1+!isupper, 1+isupper] = 3 + D = Diagonal(1:2) + T = Tri(M) + TA = Array(T) + @test T * D == TA * D + @test D * T == D * TA + @test mul!(copy(T), T, D, 2, 3) == 2T * D + 3T + @test mul!(copy(T), D, T, 2, 3) == 2D * T + 3T + + U = UTri(M) + UA = Array(U) + @test U * D == UA * D + @test D * U == D * UA + @test mul!(copy(T), U, D, 2, 3) == 2 * UA * D + 3TA + @test mul!(copy(T), D, U, 2, 3) == 2 * D * UA + 3TA + + M2 = Matrix{BigFloat}(undef, 2, 2) + M2[1+!isupper, 1+isupper] = 3 + U = UTri(M2) + UA = Array(U) + @test U * D == UA * D + @test D * U == D * UA + ID = Diagonal(ones(size(U,2))) + @test mul!(copy(U), U, ID) == U + @test mul!(copy(U), ID, U) == U + @test mul!(copy(U), U, ID, 2, -1) == U + @test mul!(copy(U), ID, U, 2, -1) == U + end end struct SMatrix1{T} <: AbstractArray{T,2}