Skip to content

Commit b6ea18c

Browse files
committed
CuSparseMatrixCSR plus and minus other sparse format
1 parent 3125790 commit b6ea18c

File tree

3 files changed

+40
-4
lines changed

3 files changed

+40
-4
lines changed

lib/cusparse/array.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ mutable struct CuSparseMatrixCSC{Tv} <: AbstractCuSparseMatrix{Tv}
4545
end
4646
end
4747

48+
CuSparseMatrixCSC(A::CuSparseMatrixCSC) = A
49+
4850
function CUDA.unsafe_free!(xs::CuSparseMatrixCSC)
4951
unsafe_free!(xs.colPtr)
5052
unsafe_free!(rowvals(xs))
@@ -72,6 +74,8 @@ mutable struct CuSparseMatrixCSR{Tv} <: AbstractCuSparseMatrix{Tv}
7274
end
7375
end
7476

77+
CuSparseMatrixCSR(A::CuSparseMatrixCSR) = A
78+
7579
function CUDA.unsafe_free!(xs::CuSparseMatrixCSR)
7680
unsafe_free!(xs.rowPtr)
7781
unsafe_free!(xs.colVal)
@@ -100,6 +104,8 @@ mutable struct CuSparseMatrixBSR{Tv} <: AbstractCuSparseMatrix{Tv}
100104
end
101105
end
102106

107+
CuSparseMatrixBSR(A::CuSparseMatrixBSR) = A
108+
103109
function CUDA.unsafe_free!(xs::CuSparseMatrixBSR)
104110
unsafe_free!(xs.rowPtr)
105111
unsafe_free!(xs.colVal)
@@ -126,6 +132,8 @@ mutable struct CuSparseMatrixCOO{Tv} <: AbstractCuSparseMatrix{Tv}
126132
end
127133
end
128134

135+
CuSparseMatrixCOO(A::CuSparseMatrixCOO) = A
136+
129137
"""
130138
Utility union type of [`CuSparseMatrixCSC`](@ref), [`CuSparseMatrixCSR`](@ref),
131139
[`CuSparseMatrixBSR`](@ref), [`CuSparseMatrixCOO`](@ref).

lib/cusparse/interfaces.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,22 +131,22 @@ function Base.:(-)(A::Transpose{T,<:CuSparseMatrixCSR}, B::Transpose{T,<:CuSpars
131131
return CuSparseMatrixCSR(cscC.colPtr, cscC.rowVal, cscC.nzVal, size(cscC))
132132
end
133133

134-
function Base.:(+)(A::CuSparseMatrixCSR, B::CuSparseMatrixCSC)
134+
function Base.:(+)(A::CuSparseMatrixCSR, B::CuSparseMatrix)
135135
csrB = CuSparseMatrixCSR(B)
136136
return geam(one(eltype(A)), A, one(eltype(A)), csrB, 'O')
137137
end
138138

139-
function Base.:(-)(A::CuSparseMatrixCSR, B::CuSparseMatrixCSC)
139+
function Base.:(-)(A::CuSparseMatrixCSR, B::CuSparseMatrix)
140140
csrB = CuSparseMatrixCSR(B)
141141
return geam(one(eltype(A)), A, -one(eltype(A)), csrB, 'O')
142142
end
143143

144-
function Base.:(+)(A::CuSparseMatrixCSC, B::CuSparseMatrixCSR)
144+
function Base.:(+)(A::CuSparseMatrix, B::CuSparseMatrixCSR)
145145
csrA = CuSparseMatrixCSR(A)
146146
return geam(one(eltype(A)), csrA, one(eltype(A)), B, 'O')
147147
end
148148

149-
function Base.:(-)(A::CuSparseMatrixCSC, B::CuSparseMatrixCSR)
149+
function Base.:(-)(A::CuSparseMatrix, B::CuSparseMatrixCSR)
150150
csrA = CuSparseMatrixCSR(A)
151151
return geam(one(eltype(A)), csrA, -one(eltype(A)), B, 'O')
152152
end

test/cusparse/interfaces.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,34 @@ using LinearAlgebra, SparseArrays
2424
@test C collect(dC)
2525
end
2626

27+
@testset "$f(B) $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64],
28+
f in (CuSparseMatrixCSR, CuSparseMatrixCSC, CuSparseMatrixCOO, x->CuSparseMatrixBSR(x,1))
29+
n = 10
30+
alpha = rand()
31+
beta = rand()
32+
A = sprand(elty, n, n, rand())
33+
B = sprand(elty, n, n, rand())
34+
35+
dA = CuSparseMatrixCSR(A)
36+
dB = CuSparseMatrixCSR(B)
37+
38+
C = A + B
39+
dC = dA + f(dB)
40+
@test C collect(dC)
41+
42+
C = B + A
43+
dC = f(dB) + dA
44+
@test C collect(dC)
45+
46+
C = A - B
47+
dC = dA - f(dB)
48+
@test C collect(dC)
49+
50+
C = B - A
51+
dC = f(dB) - dA
52+
@test C collect(dC)
53+
end
54+
2755
@testset "dense(A)$(op)sparse(B) $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64],
2856
op in [+, -]
2957
n = 10

0 commit comments

Comments
 (0)