Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: make SparseMatrixCSC immutable #16371

Merged
merged 1 commit into from
May 18, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 64 additions & 59 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Assumes that row values in rowval for each column are sorted
# issorted(rowval[colptr[i]:(colptr[i+1]-1)]) == true

type SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}
immutable SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}
m::Int # Number of rows
n::Int # Number of columns
colptr::Vector{Ti} # Column i is in colptr[i]:(colptr[i+1]-1)
Expand Down Expand Up @@ -2268,22 +2268,25 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
((I[end] > m) || (J[end] > n)) && throw(DimensionMismatch(""))
nnzA = nnz(A) + lenI * length(J)

colptr = A.colptr
rowvalA = rowval = A.rowval
nzvalA = nzval = A.nzval

rowidx = 1
nadd = 0
@inbounds for col in 1:n
rrange = colptr[col]:(colptr[col+1]-1)
(nadd > 0) && (colptr[col] = colptr[col] + nadd)
rrange = nzrange(A, col)
if nadd > 0
A.colptr[col] = A.colptr[col] + nadd
end

if col in J
if isempty(rrange) # set new vals only
nincl = lenI
if nadd == 0
rowvalA = Array(Ti, nnzA); copy!(rowvalA, 1, rowval, 1, length(rowval))
nzvalA = Array(Tv, nnzA); copy!(nzvalA, 1, nzval, 1, length(nzval))
rowval = copy(rowvalA)
nzval = copy(nzvalA)
resize!(rowvalA, nnzA)
resize!(nzvalA, nnzA)
end
r = rowidx:(rowidx+nincl-1)
rowvalA[r] = I
Expand All @@ -2309,8 +2312,10 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
old_ptr += 1
else
if nadd == 0
rowvalA = Array(Ti, nnzA); copy!(rowvalA, 1, rowval, 1, length(rowval))
nzvalA = Array(Tv, nnzA); copy!(nzvalA, 1, nzval, 1, length(nzval))
rowval = copy(rowvalA)
nzval = copy(nzvalA)
resize!(rowvalA, nnzA)
resize!(nzvalA, nnzA)
end
nadd += 1
end
Expand All @@ -2323,8 +2328,10 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
if old_ptr > old_stop
if new_ptr <= new_stop
if nadd == 0
rowvalA = Array(Ti, nnzA); copy!(rowvalA, 1, rowval, 1, length(rowval))
nzvalA = Array(Tv, nnzA); copy!(nzvalA, 1, nzval, 1, length(nzval))
rowval = copy(rowvalA)
nzval = copy(nzvalA)
resize!(rowvalA, nnzA)
resize!(nzvalA, nnzA)
end
r = rowidx:(rowidx+(new_stop-new_ptr))
rowvalA[r] = I[new_ptr:new_stop]
Expand Down Expand Up @@ -2353,12 +2360,9 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
end

if nadd > 0
colptr[n+1] = rowidx
A.colptr[n+1] = rowidx
deleteat!(rowvalA, rowidx:nnzA)
deleteat!(nzvalA, rowidx:nnzA)

A.rowval = rowvalA
A.nzval = nzvalA
end
return A
end
Expand All @@ -2373,14 +2377,16 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}

((I[end] > m) || (J[end] > n)) && throw(DimensionMismatch(""))

colptr = A.colptr
rowval = rowvalA = A.rowval
nzval = nzvalA = A.nzval
rowidx = 1
ndel = 0
@inbounds for col in 1:n
rrange = colptr[col]:(colptr[col+1]-1)
(ndel > 0) && (colptr[col] = colptr[col] - ndel)
rrange = nzrange(A, col)
if ndel > 0
A.colptr[col] = A.colptr[col] - ndel
end

if isempty(rrange) || !(col in J)
nincl = length(rrange)
if(ndel > 0) && !isempty(rrange)
Expand All @@ -2392,8 +2398,8 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}
for ridx in rrange
if rowval[ridx] in I
if ndel == 0
rowvalA = copy(rowval)
nzvalA = copy(nzval)
rowval = copy(rowvalA)
nzval = copy(nzvalA)
end
ndel += 1
else
Expand All @@ -2408,12 +2414,9 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}
end

if ndel > 0
colptr[n+1] = rowidx
A.colptr[n+1] = rowidx
deleteat!(rowvalA, rowidx:nnzA)
deleteat!(nzvalA, rowidx:nnzA)

A.rowval = rowvalA
A.nzval = nzvalA
end
return A
end
Expand Down Expand Up @@ -2458,11 +2461,14 @@ function setindex!{Tv,Ti,T<:Integer}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixC
colptrB = B.colptr; rowvalB = B.rowval; nzvalB = B.nzval
Copy link
Member

@Sacha0 Sacha0 May 14, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems you've removed all references to colptrB, rowvalB, and nzvalB in this method. Perhaps nix this line? Or did I miss remaining references?

Edit: The same seems to hold for colptrA, rowvalA, and nzvalA?

Copy link
Member Author

@KristofferC KristofferC May 14, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I mixed up two commits a bit. That was supposed to go in the (potential) later PR that removes manual hoistings. I will change back to using the colptrB variables.


nnzS = nnz(A) + nnz(B)
colptrS = Array(Ti, n+1)
rowvalS = Array(Ti, nnzS)
nzvalS = Array(Tv, nnzS)

colptrS[1] = 1
colptrS = copy(A.colptr)
rowvalS = copy(A.rowval)
nzvalS = copy(A.nzval)

resize!(rowvalA, nnzS)
resize!(nzvalA, nnzS)

colB = 1
asgn_col = J[colB]

Expand All @@ -2475,73 +2481,70 @@ function setindex!{Tv,Ti,T<:Integer}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixC

# Copy column of A if it is not being assigned into
if colB > nJ || col != J[colB]
colptrS[col+1] = colptrS[col] + (colptrA[col+1]-colptrA[col])
colptrA[col+1] = colptrA[col] + (colptrS[col+1]-colptrS[col])

for k = colptrA[col]:colptrA[col+1]-1
rowvalS[ptrS] = rowvalA[k]
nzvalS[ptrS] = nzvalA[k]
for k = colptrS[col]:colptrS[col+1]-1
rowvalA[ptrS] = rowvalS[k]
nzvalA[ptrS] = nzvalS[k]
ptrS += 1
end
continue
end

ptrA::Int = colptrA[col]
stopA::Int = colptrA[col+1]
ptrA::Int = colptrS[col]
stopA::Int = colptrS[col+1]
ptrB::Int = colptrB[colB]
stopB::Int = colptrB[colB+1]

while ptrA < stopA && ptrB < stopB
rowA = rowvalA[ptrA]
rowA = rowvalS[ptrA]
rowB = I[rowvalB[ptrB]]
if rowA < rowB
if ~I_asgn[rowA]
rowvalS[ptrS] = rowA
nzvalS[ptrS] = nzvalA[ptrA]
rowvalA[ptrS] = rowA
nzvalA[ptrS] = nzvalS[ptrA]
ptrS += 1
end
ptrA += 1
elseif rowB < rowA
rowvalS[ptrS] = rowB
nzvalS[ptrS] = nzvalB[ptrB]
rowvalA[ptrS] = rowB
nzvalA[ptrS] = nzvalB[ptrB]
ptrS += 1
ptrB += 1
else
rowvalS[ptrS] = rowB
nzvalS[ptrS] = nzvalB[ptrB]
rowvalA[ptrS] = rowB
nzvalA[ptrS] = nzvalB[ptrB]
ptrS += 1
ptrB += 1
ptrA += 1
end
end

while ptrA < stopA
rowA = rowvalA[ptrA]
rowA = rowvalS[ptrA]
if ~I_asgn[rowA]
rowvalS[ptrS] = rowA
nzvalS[ptrS] = nzvalA[ptrA]
rowvalA[ptrS] = rowA
nzvalA[ptrS] = nzvalS[ptrA]
ptrS += 1
end
ptrA += 1
end

while ptrB < stopB
rowB = I[rowvalB[ptrB]]
rowvalS[ptrS] = rowB
nzvalS[ptrS] = nzvalB[ptrB]
rowvalA[ptrS] = rowB
nzvalA[ptrS] = nzvalB[ptrB]
ptrS += 1
ptrB += 1
end

colptrS[col+1] = ptrS
colptrA[col+1] = ptrS
colB += 1
end

deleteat!(rowvalS, colptrS[end]:length(rowvalS))
deleteat!(nzvalS, colptrS[end]:length(nzvalS))
deleteat!(rowvalA, colptrA[end]:length(rowvalA))
deleteat!(nzvalA, colptrA[end]:length(nzvalA))

A.colptr = colptrS
A.rowval = rowvalS
A.nzval = nzvalS
return A
end

Expand Down Expand Up @@ -2597,10 +2600,12 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})

if (mode > 1) && (nadd == 0) && (ndel == 0)
# copy storage to take changes
colptrB = copy(colptrA)
colptrA = copy(colptrB)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
rowvalA = copy(rowvalB)
nzvalA = copy(nzvalB)
resize!(rowvalB, length(rowvalA)+memreq)
resize!(nzvalB, length(rowvalA)+memreq)
Copy link
Member

@Sacha0 Sacha0 May 14, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that in this method colptrB, rowvalB, and nzvalB are now persistent references to A.colptr, A.rowval, and A.nzval, while colptrA, rowvalA, and nzvalA can become disjoint storage for the original contents of A.colptr, A.rowval, and A.nzval? If so this seems a bit labyrinthine, and I might advocate for renaming things commensurate with their use (as in the preceding methods).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You basically want to swap xxxA and xxxB?

I agree, these functions are indeed difficult to reason about. They try to reuse the storage in the input matrix, but if they find out that they need temporary data they allocate copies of the matrix fields and rebind the variables. This makes it difficult to keep track of what data each variable points to. I didn't want to dig too deeply into these functions, so I mostly mechanically made them set the new values in place in the matrix fields instead of in the newly allocated vectors.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You basically want to swap xxxA and xxxB?

Eventually something along those lines might be nice, yes. But for now please ignore this comment (see below.)

I agree, these functions are indeed difficult to reason about. They try to reuse the storage in the input matrix, but if they find out that they need temporary data they allocate copies of the matrix fields and rebind the variables. This makes it difficult to keep track of what data each variable points to. I didn't want to dig too deeply into these functions, so I mostly mechanically made them set the new values in place in the matrix fields instead of in the newly allocated vectors.

Cheers, I would have done the same and for the same reasons. I like how you've minimized modifications in this PR. Better to defer my suggestion above to a separate PR to retain that approach and see this merged efficiently.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's open an issue about it later so we don't forget it. Hopefully making SparseMatrixCSC an immutable opens up some cleanup possibilities, since dereferencing fields can then be hoisted out of, for example, utility functions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good! :)

end
if mode == 1
rowvalB[bidx] = row
Expand Down Expand Up @@ -2653,7 +2658,6 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
deleteat!(nzvalB, bidx:n)
deleteat!(rowvalB, bidx:n)
end
A.nzval = nzvalB; A.rowval = rowvalB; A.colptr = colptrB
end
A
end
Expand Down Expand Up @@ -2719,10 +2723,12 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto

if (mode > 1) && (nadd == 0) && (ndel == 0)
# copy storage to take changes
colptrB = copy(colptrA)
colptrA = copy(colptrB)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
rowvalA = copy(rowvalB)
nzvalA = copy(nzvalB)
resize!(rowvalB, length(rowvalA)+memreq)
resize!(nzvalB, length(rowvalA)+memreq)
Copy link
Member

@Sacha0 Sacha0 May 14, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above re. xxxxxxB and xxxxxxA naming. (Edit: Please also ignore this comment as above.)

end
if mode == 1
rowvalB[bidx] = row
Expand Down Expand Up @@ -2759,7 +2765,6 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
deleteat!(nzvalB, bidx:n)
deleteat!(rowvalB, bidx:n)
end
A.nzval = nzvalB; A.rowval = rowvalB; A.colptr = colptrB
end
A
end
Expand Down