Skip to content

Commit 58e4f29

Browse files
committed
Dispatch on block matrix types
1 parent ea064f7 commit 58e4f29

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,19 +191,21 @@ end
191191
Return the appropriate zero element `A[i, j]` corresponding to a banded matrix `A`.
192192
"""
193193
diagzero(::Diagonal{T}, i, j) where {T} = zero(T)
194-
diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = diagzero(T, axes(D.diag[i], 1), axes(D.diag[j], 2))
194+
diagzero(D::Diagonal{M, <:AbstractVector{M}}, i, j) where {T,M<:AbstractMatrix{T}} =
195+
diagzero(M, axes(D.diag[i], 1), axes(D.diag[j], 2))
195196
# dispatching on the axes permits specializing on the axis types to return something other than an Array
196-
diagzero(T::Type, ax::Union{AbstractUnitRange, Integer}...) = diagzero(T, ax)
197-
diagzero(T::Type, ::Tuple{}) = zeros(T)
197+
diagzero(M::Type, ax::Union{AbstractUnitRange, Integer}...) = diagzero(M, ax)
198+
diagzero(M::Type, ::Tuple{}) = zeros(eltype(M))
198199
"""
199-
diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}})
200+
diagzero(::Type{M}, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) where {M<:AbstractMatrix}
200201
201-
Return an appropriate zero-ed array with either the axes `ax`, or the `size` `map(length, ax)`,
202-
which may be used as a structural zero element of a banded matrix. By default, this falls back to
202+
Return an appropriate zero-ed matrix similar to `M`, with either
203+
the axes `ax`, or the `size` `map(length, ax)`.
204+
This will be used as a structural zero element of a banded matrix. By default, `diagzero` falls back to
203205
using the size along each axis to construct the result.
204206
"""
205-
diagzero(T::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = diagzero(T, map(length, ax))
206-
diagzero(T::Type, sz::Tuple{Integer, Vararg{Integer}}) = zeros(T, sz)
207+
diagzero(M::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = diagzero(M, map(length, ax))
208+
diagzero(::Type{M}, sz::Tuple{Integer, Vararg{Integer}}) where {M<:AbstractMatrix} = zeros(eltype(M), sz)
207209

208210
@inline function getindex(D::Diagonal, b::BandIndex)
209211
@boundscheck checkbounds(D, b)

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,9 +817,6 @@ end
817817
@test fill(S,3,2)' * D == fill(S' * S, 2, 3)
818818

819819
@testset "indexing with non-standard-axes" begin
820-
LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) =
821-
zeros(T, ax)
822-
823820
s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
824821
D = Diagonal(fill(s,3))
825822
@test @inferred(D[1,2]) isa typeof(s)

test/testhelpers/SizedArrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,7 @@ mul!(dest::AbstractMatrix, S1::SizedMatrix, S2::SizedMatrix, α::Number, β::Num
9999
mul!(dest::AbstractVector, M::AbstractMatrix, v::SizedVector, α::Number, β::Number) =
100100
mul!(dest, M, _data(v), α, β)
101101

102+
LinearAlgebra.diagzero(::Type{S}, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) where {S<:SizedArray} =
103+
zeros(eltype(S), ax)
104+
102105
end

0 commit comments

Comments
 (0)