Skip to content

Commit 606343e

Browse files
yuehhuamaleadt
authored andcommitted
dense(A)*sparse(B)
1 parent fe506d4 commit 606343e

File tree

3 files changed

+40
-11
lines changed

3 files changed

+40
-11
lines changed

Manifest.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,10 @@ deps = ["Libdl"]
142142
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
143143

144144
[[LogExpFunctions]]
145-
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
146-
git-tree-sha1 = "3d682c07e6dd250ed082f883dc88aee7996bf2cc"
145+
deps = ["ChainRulesCore", "DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
146+
git-tree-sha1 = "1f5097e3bce576e1cdf6dc9f051ab8c6e196b29e"
147147
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
148-
version = "0.3.0"
148+
version = "0.3.1"
149149

150150
[[Logging]]
151151
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

lib/cusparse/interfaces.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,22 @@ for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
7979
end
8080
end
8181

82+
function LinearAlgebra.:(*)(A::Union{CuMatrix, CuSparseMatrix}, B::CuSparseMatrix{T}) where {T}
83+
return mul!(similar(CuMatrix{T}, (size(A,1), size(B,2))), A, CuArray(B))
84+
end
85+
86+
function LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, B::CuSparseMatrix{T}) where {T}
87+
mul!(C, B', A', one(T), zero(T))
88+
return C'
89+
end
90+
8291
Base.:(+)(A::CuSparseMatrixCSR, B::CuSparseMatrixCSR) = geam(one(eltype(A)), A, one(eltype(A)), B, 'O')
8392
Base.:(-)(A::CuSparseMatrixCSR, B::CuSparseMatrixCSR) = geam(one(eltype(A)), A, -one(eltype(A)), B, 'O')
8493

85-
Base.:(+)(A::CuSparseMatrixCSR, B::Adjoint{T,<:CuSparseMatrixCSR}) where {T} =
86-
A + Transpose(conj(B.parent))
87-
Base.:(-)(A::CuSparseMatrixCSR, B::Adjoint{T,<:CuSparseMatrixCSR}) where {T} =
88-
A - Transpose(conj(B.parent))
89-
Base.:(+)(A::Adjoint{T,<:CuSparseMatrixCSR}, B::CuSparseMatrixCSR) where {T} =
90-
Transpose(conj(A.parent)) + B
91-
Base.:(-)(A::Adjoint{T,<:CuSparseMatrixCSR}, B::CuSparseMatrixCSR) where {T} =
92-
Transpose(conj(A.parent)) - B
94+
Base.:(+)(A::CuSparseMatrixCSR, B::Adjoint{T,<:CuSparseMatrixCSR}) where {T} = A + Transpose(conj(B.parent))
95+
Base.:(-)(A::CuSparseMatrixCSR, B::Adjoint{T,<:CuSparseMatrixCSR}) where {T} = A - Transpose(conj(B.parent))
96+
Base.:(+)(A::Adjoint{T,<:CuSparseMatrixCSR}, B::CuSparseMatrixCSR) where {T} = Transpose(conj(A.parent)) + B
97+
Base.:(-)(A::Adjoint{T,<:CuSparseMatrixCSR}, B::CuSparseMatrixCSR) where {T} = Transpose(conj(A.parent)) - B
9398
Base.:(+)(A::Adjoint{T,<:CuSparseMatrixCSR}, B::Adjoint{T,<:CuSparseMatrixCSR}) where {T} =
9499
Transpose(conj(A.parent)) + B
95100
Base.:(-)(A::Adjoint{T,<:CuSparseMatrixCSR}, B::Adjoint{T,<:CuSparseMatrixCSR}) where {T} =

test/cusparse/interfaces.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,30 @@ using LinearAlgebra, SparseArrays
137137
@test C collect(dC)
138138
end
139139

140+
@testset "dense(A)*sparse(B) $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
141+
n = 10
142+
A = rand(elty, n, n)
143+
B = sprand(elty, n, n, rand())
144+
145+
dA = CuArray(A)
146+
dB = CUSPARSE.CuSparseMatrixCSR(B)
147+
148+
C = A * B
149+
dC = dA * dB
150+
@test C collect(dC)
151+
@test dC isa CuMatrix{elty}
152+
153+
C = B * A
154+
dC = dB * dA
155+
@test C collect(dC)
156+
@test dC isa CuMatrix{elty}
157+
158+
C = B * B
159+
dC = dB * dB
160+
@test C collect(dC)
161+
@test dC isa CuMatrix{elty}
162+
end
163+
140164
@testset "issue #1095 ($elty)" for elty in [Float32, Float64, ComplexF32, ComplexF64]
141165
# Test non-square matrices
142166
n, m, p = 10, 20, 4

0 commit comments

Comments
 (0)