Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions src/sparse/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ const ROCSparseMatrix{Tv, Ti} = Union{

const ROCSparseVecOrMat = Union{ROCSparseVector, ROCSparseMatrix}

# NOTE: we use Cint as default Ti on CUDA instead of Int to provide
# maximum compatiblity to old CUSPARSE APIs
# NOTE: we use Cint as default Ti on ROCm instead of Int to provide
# maximum compatiblity to old ROCSparse APIs
# The same pattern was followed for AMDGPU as well
function ROCSparseVector{Tv}(iPtr::ROCVector{<:Integer}, nzVal::ROCVector, len::Integer) where Tv
ROCSparseVector{Tv, Cint}(convert(ROCVector{Cint}, iPtr), nzVal, len)
Expand Down Expand Up @@ -284,6 +284,7 @@ SparseArrays.nnz(g::AbstractROCSparseArray) = g.nnz
SparseArrays.nonzeros(g::AbstractROCSparseArray) = g.nzVal

SparseArrays.nonzeroinds(g::AbstractROCSparseVector) = g.iPtr
SparseArrays.rowvals(g::AbstractROCSparseVector) = nonzeroinds(g)

SparseArrays.rowvals(g::ROCSparseMatrixCSC) = g.rowVal
SparseArrays.getcolptr(g::ROCSparseMatrixCSC) = g.colPtr
Expand Down Expand Up @@ -422,14 +423,8 @@ ROCSparseMatrixCSC(x::Transpose{T}) where {T} = ROCSparseMatrixCSC{T}(x)
ROCSparseMatrixCSC(x::Adjoint{T}) where {T} = ROCSparseMatrixCSC{T}(x)

# gpu to cpu
function SparseVector(x::ROCSparseVector)
SparseVector(length(x), Array(nonzeroinds(x)), Array(nonzeros(x)))
end

function SparseMatrixCSC(x::ROCSparseMatrixCSC)
SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(rowvals(x)), Array(nonzeros(x)))
end

SparseVector(x::ROCSparseVector) = SparseVector(length(x), Array(nonzeroinds(x)), Array(nonzeros(x)))
SparseMatrixCSC(x::ROCSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(rowvals(x)), Array(nonzeros(x)))
SparseMatrixCSC(x::ROCSparseMatrixCSR) = SparseMatrixCSC(ROCSparseMatrixCSC(x)) # no direct conversion
SparseMatrixCSC(x::ROCSparseMatrixBSR) = SparseMatrixCSC(ROCSparseMatrixCSR(x)) # no direct conversion
SparseMatrixCSC(x::ROCSparseMatrixCOO) = SparseMatrixCSC(ROCSparseMatrixCSR(x)) # no direct conversion
Expand Down Expand Up @@ -519,7 +514,7 @@ Base.copy(Mat::ROCSparseMatrixCOO) = copyto!(similar(Mat), Mat)

# input/output

for (gpu, cpu) in [ROCSparseVector => SparseVector]
for (gpu, cpu) in [:ROCSparseVector => :SparseVector]
@eval function Base.show(io::IO, ::MIME"text/plain", x::$gpu)
xnnz = length(nonzeros(x))
print(io, length(x), "-element ", typeof(x), " with ", xnnz,
Expand All @@ -531,10 +526,10 @@ for (gpu, cpu) in [ROCSparseVector => SparseVector]
end
end

for (gpu, cpu) in [ROCSparseMatrixCSC => SparseMatrixCSC,
ROCSparseMatrixCSR => SparseMatrixCSC,
ROCSparseMatrixBSR => SparseMatrixCSC,
ROCSparseMatrixCOO => SparseMatrixCSC]
for (gpu, cpu) in [:ROCSparseMatrixCSC => :SparseMatrixCSC,
:ROCSparseMatrixCSR => :SparseMatrixCSC,
:ROCSparseMatrixBSR => :SparseMatrixCSC,
:ROCSparseMatrixCOO => :SparseMatrixCSC]
@eval Base.show(io::IOContext, x::$gpu) = show(io, $cpu(x))

@eval function Base.show(io::IO, mime::MIME"text/plain", S::$gpu)
Expand Down
54 changes: 54 additions & 0 deletions src/sparse/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ function SparseArrays.sparse(
end
end

for SparseMatrixType in (:ROCSparseMatrixCSC, :ROCSparseMatrixCSR, :ROCSparseMatrixCOO)
@eval SparseArrays.sparse(A::$SparseMatrixType) = A
end

function sort_rows(coo::ROCSparseMatrixCOO{Tv,Ti}) where {Tv <: BlasFloat, Ti}
m,n = size(coo)
perm = ROCArray{Ti}(undef, nnz(coo))
Expand Down Expand Up @@ -487,3 +491,53 @@ function ROCSparseMatrixBSR(A::ROCMatrix; ind::SparseChar = 'O')
m, n = size(A) # TODO: always let the user choose, or provide defaults for other methods too
ROCSparseMatrixBSR(ROCSparseMatrixCSR(A; ind), gcd(m,n))
end


function AMDGPU.ROCMatrix{T}(coo::ROCSparseMatrixCOO{T}; index::SparseChar='O') where {T}
sparsetodense(coo, index)
end

function ROCSparseMatrixCOO(A::ROCMatrix{T}; index::SparseChar='O') where {T}
densetosparse(A, :coo, index)
end

## ROCSparseVector to ROCSparseMatrices and vice-versa
function ROCSparseVector(A::ROCSparseMatrixCSC{T}) where T
m, n = size(A)
(n == 1) || error("A doesn't have one column and can't be converted to a ROCSparseVector.")
ROCSparseVector{T}(A.rowVal, A.nzVal, m)
end

# no direct conversion
function ROCSparseVector(A::ROCSparseMatrixCSR{T}) where T
m, n = size(A)
(n == 1) || error("A doesn't have one column and can't be converted to a ROCSparseVector.")
B = ROCSparseMatrixCSC{T}(A)
ROCSparseVector(B)
end

function ROCSparseVector(A::ROCSparseMatrixCOO{T}) where T
m, n = size(A)
(n == 1) || error("A doesn't have one column and can't be converted to a ROCSparseVector.")
ROCSparseVector{T}(A.rowInd, A.nzVal, m)
end

function ROCSparseMatrixCSC(x::ROCSparseVector{T}) where T
n = length(x)
colPtr = CuVector{Int32}([1; nnz(x)+1])
ROCSparseMatrixCSC{T}(colPtr, x.iPtr, nonzeros(x), (n,1))
end

# no direct conversion
function ROCSparseMatrixCSR(x::ROCSparseVector{T}) where T
A = ROCSparseMatrixCSC(x)
ROCSparseMatrixCSR{T}(A)
end

function ROCSparseMatrixCOO(x::ROCSparseVector{T}) where T
n = length(x)
nnzx = nnz(x)
colInd = CuVector{Int32}(undef, nnzx)
fill!(colInd, one(Int32))
ROCSparseMatrixCOO{T}(x.iPtr, colInd, nonzeros(x), (n,1), nnzx)
end