Skip to content

Commit

Permalink
Fix dispatch of SparseMatrixCSC*Diagonal multiplication (#29045)
Browse files Browse the repository at this point in the history
* Fix type signature of mul! methods for multiplying SparseMatrixCSCs with Diagonal matrices. Type signature for diagonal matrices was wrong, causing fallback to generic Matmul.

* Add SparseMatrixCSC*Diagonal dispatch test

* Fix trailing whitespace

* Don't copy with deepcopy

(cherry picked from commit 8d99356)
  • Loading branch information
Pbellive authored and KristofferC committed Sep 8, 2018
1 parent 4137472 commit 8729c63
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ function copyinds!(C::SparseMatrixCSC, A::SparseMatrixCSC)
end

# multiply by diagonal matrix as vector
function mul!(C::SparseMatrixCSC, A::SparseMatrixCSC, D::Diagonal{<:Vector})
function mul!(C::SparseMatrixCSC, A::SparseMatrixCSC, D::Diagonal{T, <:Vector}) where T
m, n = size(A)
b = D.diag
(n==length(b) && size(A)==size(C)) || throw(DimensionMismatch())
Expand All @@ -982,7 +982,7 @@ function mul!(C::SparseMatrixCSC, A::SparseMatrixCSC, D::Diagonal{<:Vector})
C
end

function mul!(C::SparseMatrixCSC, D::Diagonal{<:Vector}, A::SparseMatrixCSC)
function mul!(C::SparseMatrixCSC, D::Diagonal{T, <:Vector}, A::SparseMatrixCSC) where T
m, n = size(A)
b = D.diag
(m==length(b) && size(A)==size(C)) || throw(DimensionMismatch())
Expand Down
12 changes: 12 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using LinearAlgebra
using Base.Printf: @printf
using Random
using Test: guardseed
using InteractiveUtils: @which

@testset "issparse" begin
@test issparse(sparse(fill(1,5,5)))
Expand Down Expand Up @@ -2295,4 +2296,15 @@ end
@test typeof(a) === typeof(na)
end

#PR #29045
@testset "Issue #28934" begin
A = sprand(5,5,0.5)
D = Diagonal(rand(5))
C = copy(A)
m1 = @which mul!(C,A,D)
m2 = @which mul!(C,D,A)
@test m1.module == SparseArrays
@test m2.module == SparseArrays
end

end # module

0 comments on commit 8729c63

Please sign in to comment.