Skip to content

More generic SparseMatrixCSC and SparseVector #33918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion stdlib/SparseArrays/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[extras]
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Dates", "Test", "InteractiveUtils"]
test = ["Dates", "Test", "InteractiveUtils", "Printf"]
39 changes: 23 additions & 16 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,26 @@ Matrix type for storing sparse matrices in the
of constructing SparseMatrixCSC is through the [`sparse`](@ref) function.
See also [`spzeros`](@ref), [`spdiagm`](@ref) and [`sprand`](@ref).
"""
struct SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrixCSC{Tv,Ti}
struct SparseMatrixCSC{Tv,Ti<:Integer,
ColPtr <: AbstractVector{Ti},
RowVal <: AbstractVector{Ti},
NZVal <: AbstractVector{Tv}} <: AbstractSparseMatrixCSC{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::Vector{Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::Vector{Ti} # Row indices of stored values
nzval::Vector{Tv} # Stored values, typically nonzeros
colptr::ColPtr # Column i is in colptr[i]:(colptr[i+1]-1)
rowval::RowVal # Row indices of stored values
nzval::NZVal # Stored values, typically nonzeros

function SparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::Vector{Ti},
rowval::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti<:Integer}
function SparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::AbstractVector{Ti},
rowval::AbstractVector{Ti}, nzval::AbstractVector{Tv}) where {Tv,Ti<:Integer}
@noinline throwsz(str, lbl, k) =
throw(ArgumentError("number of $str ($lbl) must be ≥ 0, got $k"))
m < 0 && throwsz("rows", 'm', m)
n < 0 && throwsz("columns", 'n', n)
new(Int(m), Int(n), colptr, rowval, nzval)
new{Tv, Ti, typeof(colptr), typeof(rowval), typeof(nzval)}(Int(m), Int(n), colptr, rowval, nzval)
end
end
function SparseMatrixCSC(m::Integer, n::Integer, colptr::Vector, rowval::Vector, nzval::Vector)
function SparseMatrixCSC(m::Integer, n::Integer, colptr::AbstractVector, rowval::AbstractVector, nzval::AbstractVector)
Tv = eltype(nzval)
Ti = promote_type(eltype(colptr), eltype(rowval))
sparse_check_Ti(m, n, Ti)
Expand All @@ -52,7 +55,7 @@ function sparse_check_Ti(m::Integer, n::Integer, Ti::Type)
0 ≤ n && (!isbitstype(Ti) || n ≤ typemax(Ti)) || throwTi("number of columns", "n", n)
end

function sparse_check(n::Integer, colptr::Vector{Ti}, rowval, nzval) where Ti
function sparse_check(n::Integer, colptr::AbstractVector{Ti}, rowval, nzval) where Ti
sparse_check_length("colptr", colptr, n+1, String) # don't check upper bound
ckp = Ti(1)
ckp == colptr[1] || throw(ArgumentError("$ckp == colptr[1] != 1"))
Expand Down Expand Up @@ -248,9 +251,9 @@ end

## Reshape

function sparse_compute_reshaped_colptr_and_rowval(colptrS::Vector{Ti}, rowvalS::Vector{Ti},
mS::Int, nS::Int, colptrA::Vector{Ti},
rowvalA::Vector{Ti}, mA::Int, nA::Int) where Ti
function sparse_compute_reshaped_colptr_and_rowval(colptrS::AbstractVector{Ti}, rowvalS::AbstractVector{Ti},
mS::Int, nS::Int, colptrA::AbstractVector{Ti},
rowvalA::AbstractVector{Ti}, mA::Int, nA::Int) where Ti
lrowvalA = length(rowvalA)
maxrowvalA = (lrowvalA > 0) ? maximum(rowvalA) : zero(Ti)
((length(colptrA) == (nA+1)) && (maximum(colptrA) <= (lrowvalA+1)) && (maxrowvalA <= mA)) || throw(BoundsError())
Expand Down Expand Up @@ -3468,11 +3471,15 @@ function _spdiagm(size, kv::Pair{<:Integer,<:AbstractVector}...)
end

## expand a colptr or rowptr into a dense index vector
function expandptr(V::Vector{<:Integer})
if V[1] != 1 throw(ArgumentError("first index must be one")) end
function expandptr(V::AbstractVector{<:Integer})
if V[1] != 1
throw(ArgumentError("first index must be one"))
end
res = similar(V, (Int64(V[end]-1),))
for i in 1:(length(V)-1), j in V[i]:(V[i+1] - 1); res[j] = i end
res
for i in 1:(length(V)-1), j in V[i]:(V[i+1] - 1)
res[j] = i
end
return res
end


Expand Down
12 changes: 6 additions & 6 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ import LinearAlgebra: promote_to_array_type, promote_to_arrays_

Vector type for storing sparse vectors.
"""
struct SparseVector{Tv,Ti<:Integer} <: AbstractSparseVector{Tv,Ti}
struct SparseVector{Tv,Ti<:Integer,NZInd<:AbstractVector{Ti},NZVal<:AbstractVector{Tv}} <: AbstractSparseVector{Tv,Ti}
n::Int # Length of the sparse vector
nzind::Vector{Ti} # Indices of stored values
nzval::Vector{Tv} # Stored values, typically nonzeros
nzind::NZInd # Indices of stored values
nzval::NZVal # Stored values, typically nonzeros

function SparseVector{Tv,Ti}(n::Integer, nzind::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti<:Integer}
function SparseVector{Tv,Ti}(n::Integer, nzind::AbstractVector{Ti}, nzval::AbstractVector{Tv}) where {Tv,Ti<:Integer}
n >= 0 || throw(ArgumentError("The number of elements must be non-negative."))
length(nzind) == length(nzval) ||
throw(ArgumentError("index and value vectors must be the same length"))
new(convert(Int, n), nzind, nzval)
new{Tv, Ti, typeof(nzind), typeof(nzval)}(convert(Int, n), nzind, nzval)
end
end

SparseVector(n::Integer, nzind::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti} =
SparseVector(n::Integer, nzind::AbstractVector{Ti}, nzval::AbstractVector{Tv}) where {Tv,Ti} =
SparseVector{Tv,Ti}(n, nzind, nzval)

# Define an alias for a view of a whole column of a SparseMatrixCSC. Many methods can be written for the
Expand Down