Skip to content

Commit

Permalink
Simplify SparseMatrix setindex! (JuliaLang#27026)
Browse files Browse the repository at this point in the history
* Simplify SparseMatrix setindex!

Fix JuliaLang#27013, and add a whole slew of tests. `methods(setindex!, Tuple{SparseMatrixCSC, Vararg{Any}})` goes from a hodgepodge of 26 methods that were tough to reason about to just 6 methods: scalar, logical linear indexing, linear indexing with a vector, linear indexing with a matrix, generic nonscalar indexing, and a disambiguation method. This is much easier to reason about, and fixes a handful of bugs.

* : -> j
  • Loading branch information
mbauman authored May 9, 2018
1 parent a0b1e98 commit e4fa75d
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 36 deletions.
8 changes: 1 addition & 7 deletions stdlib/SparseArrays/src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,7 @@ import Base: asyncmap
@deprecate asyncmap(f, s::AbstractSparseArray...; kwargs...) sparse(asyncmap(f, map(Array, s)...; kwargs...))

# PR 26347: implicit scalar broadcasting within setindex!
@deprecate setindex!(A::SparseMatrixCSC, x::Number, i::Integer, J::AbstractVector{<:Integer}) (A[i, J] .= x; A)
@deprecate setindex!(A::SparseMatrixCSC, x::Number, I::AbstractVector{<:Integer}, j::Integer) (A[I, j] .= x; A)
@deprecate setindex!(A::SparseMatrixCSC, x, ::Colon) fill!(A, x)
@deprecate setindex!(A::SparseMatrixCSC, x, ::Colon, ::Colon) fill!(A, x)
@deprecate setindex!(A::SparseMatrixCSC, x, ::Colon, j::Union{Integer, AbstractVector}) (A[:, j] .= x; A)
@deprecate setindex!(A::SparseMatrixCSC, x, i::Union{Integer, AbstractVector}, ::Colon) (A[i, :] .= x; A)
@deprecate setindex!(A::SparseMatrixCSC, x::Number, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}) (A[I, J] .= x; A)
@deprecate setindex!(A::SparseMatrixCSC{<:Any,<:Any}, x, i::Union{Integer, AbstractVector{<:Integer}, Colon}, j::Union{Integer, AbstractVector{<:Integer}, Colon}) (A[i, j] .= x; A)

#25395 keywords unlocked
@deprecate dropzeros(x, trim) dropzeros(x, trim = trim)
Expand Down
42 changes: 13 additions & 29 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2326,7 +2326,7 @@ getindex(A::SparseMatrixCSC, I::AbstractVector{<:Integer}, J::AbstractVector{Boo
getindex(A::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{<:Integer}) = A[findall(I),J]

## setindex!
function setindex!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where Tv where Ti
function setindex!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer}
v = convert(Tv, _v)
i = convert(Ti, _i)
j = convert(Ti, _j)
Expand All @@ -2353,11 +2353,6 @@ function setindex!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) wher
return A
end

setindex!(A::SparseMatrixCSC, x::AbstractArray, ::Colon) = setindex!(A, x, 1:length(A))
setindex!(A::SparseMatrixCSC, x::AbstractArray, ::Colon, ::Colon) = setindex!(A, x, 1:size(A, 1), 1:size(A,2))
setindex!(A::SparseMatrixCSC, x::AbstractArray, ::Colon, j::Union{Integer, AbstractVector}) = setindex!(A, x, 1:size(A, 1), j)
setindex!(A::SparseMatrixCSC, x::AbstractArray, i::Union{Integer, AbstractVector}, ::Colon) = setindex!(A, x, i, 1:size(A, 2))

function Base.fill!(V::SubArray{Tv, <:Any, <:SparseMatrixCSC, Tuple{Vararg{Union{Integer, AbstractVector{<:Integer}},2}}}, x) where Tv
A = V.parent
I, J = V.indices
Expand Down Expand Up @@ -2520,21 +2515,15 @@ function _spsetnz_setindex!(A::SparseMatrixCSC{Tv}, x::Tv,
return A
end

setindex!(A::SparseMatrixCSC{Tv,Ti}, S::Matrix, I::Integer, J::Integer) where {Tv,Ti} = setindex!(A, convert(Tv, S), I, J)
setindex!(A::SparseMatrixCSC{Tv,Ti}, S::Matrix, I::Union{Integer, AbstractVector{<:Integer}}, J::Union{Integer, AbstractVector{<:Integer}}) where {Tv,Ti} =
setindex!(A, convert(SparseMatrixCSC{Tv,Ti}, S), I, J)

setindex!(A::SparseMatrixCSC, v::AbstractVector, I::Integer, J::Integer) = setindex!(A, convert(Tv, v), I, J)
setindex!(A::SparseMatrixCSC, v::AbstractVector, I::Union{Integer, AbstractVector{<:Integer}}, J::Union{Integer, AbstractVector{<:Integer}}) =
setindex!(A, reshape(v, length(I), length(J)), I, J)
# Nonscalar A[I,J] = B: Convert B to a SparseMatrixCSC of the appropriate shape first
_to_same_csc(::SparseMatrixCSC{Tv, Ti}, V::AbstractMatrix, I...) where {Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, V)
_to_same_csc(::SparseMatrixCSC{Tv, Ti}, V::AbstractVector, I...) where {Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, reshape(V, map(length, I)))

# Nonscalar A[I,J] = B
setindex!(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}, I::Integer, J::Integer) where {Tv,Ti} =
setindex!(A, convert(Tv, I, J), I, J)
function setindex!(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}, I::Union{Integer, AbstractVector{<:Integer}}, J::Union{Integer, AbstractVector{<:Integer}}) where {Tv,Ti}
if size(B,1) != length(I) || size(B,2) != length(J)
throw(DimensionMismatch(""))
end
setindex!(A::SparseMatrixCSC{Tv}, B::AbstractVecOrMat, I::Integer, J::Integer) where {Tv} = setindex!(A, convert(Tv, B), I, J)
function setindex!(A::SparseMatrixCSC{Tv,Ti}, V::AbstractVecOrMat, Ix::Union{Integer, AbstractVector{<:Integer}, Colon}, Jx::Union{Integer, AbstractVector{<:Integer}, Colon}) where {Tv,Ti<:Integer}
(I, J) = Base.ensure_indexable(to_indices(A, (Ix, Jx)))
checkbounds(A, I, J)
B = _to_same_csc(A, V, I, J)

issortedI = issorted(I)
issortedJ = issorted(J)
Expand Down Expand Up @@ -2581,7 +2570,7 @@ function setindex!(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}, I::Unio
asgn_col = J[colB]

I_asgn = falses(m)
I_asgn[I] .= true
fill!(view(I_asgn, I), true)

ptrS = 1

Expand Down Expand Up @@ -2658,19 +2647,12 @@ end

# Logical setindex!

setindex!(A::SparseMatrixCSC, x::Matrix, I::Integer, J::AbstractVector{Bool}) = setindex!(A, sparse(x), I, findall(J))
setindex!(A::SparseMatrixCSC, x::Matrix, I::AbstractVector{Bool}, J::Integer) = setindex!(A, sparse(x), findall(I), J)
setindex!(A::SparseMatrixCSC, x::Matrix, I::AbstractVector{Bool}, J::AbstractVector{Bool}) = setindex!(A, sparse(x), findall(I), findall(J))
setindex!(A::SparseMatrixCSC, x::Matrix, I::AbstractVector{<:Integer}, J::AbstractVector{Bool}) = setindex!(A, sparse(x), I, findall(J))
setindex!(A::SparseMatrixCSC, x::Matrix, I::AbstractVector{Bool}, J::AbstractVector{<:Integer}) = setindex!(A, sparse(x), findall(I),J)

setindex!(A::Matrix, x::SparseMatrixCSC, I::Integer, J::AbstractVector{Bool}) = setindex!(A, Array(x), I, findall(J))
setindex!(A::Matrix, x::SparseMatrixCSC, I::AbstractVector{Bool}, J::Integer) = setindex!(A, Array(x), findall(I), J)
setindex!(A::Matrix, x::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{Bool}) = setindex!(A, Array(x), findall(I), findall(J))
setindex!(A::Matrix, x::SparseMatrixCSC, I::AbstractVector{<:Integer}, J::AbstractVector{Bool}) = setindex!(A, Array(x), I, findall(J))
setindex!(A::Matrix, x::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{<:Integer}) = setindex!(A, Array(x), findall(I), J)

setindex!(A::SparseMatrixCSC, x::AbstractArray, I::AbstractVector{Bool}) = setindex!(A, x, findall(I))
function setindex!(A::SparseMatrixCSC, x::AbstractArray, I::AbstractMatrix{Bool})
checkbounds(A, I)
n = sum(I)
Expand Down Expand Up @@ -2771,7 +2753,9 @@ function setindex!(A::SparseMatrixCSC, x::AbstractArray, I::AbstractMatrix{Bool}
A
end

function setindex!(A::SparseMatrixCSC, x::AbstractArray, I::AbstractVector{<:Real})
function setindex!(A::SparseMatrixCSC, x::AbstractArray, Ix::AbstractVector{<:Integer})
(I,) = Base.ensure_indexable(to_indices(A, (Ix,)))
# We check bounds after sorting I
n = length(I)
(n == 0) && (return A)

Expand Down
29 changes: 29 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2244,4 +2244,33 @@ end
@test maximum(B) == 6
end

_length_or_count_or_five(::Colon) = 5
_length_or_count_or_five(x::AbstractVector{Bool}) = count(x)
_length_or_count_or_five(x) = length(x)
@testset "nonscalar setindex!" begin
for I in (1:4, :, 5:-1:2, [], trues(5), setindex!(falses(5), true, 2), 3),
J in (2:4, :, 4:-1:1, [], setindex!(trues(5), false, 3), falses(5), 4)
V = sparse(1 .+ zeros(_length_or_count_or_five(I)*_length_or_count_or_five(J)))
M = sparse(1 .+ zeros(_length_or_count_or_five(I), _length_or_count_or_five(J)))
if I isa Integer && J isa Integer
@test_throws MethodError spzeros(5,5)[I, J] = V
@test_throws MethodError spzeros(5,5)[I, J] = M
continue
end
@test setindex!(spzeros(5, 5), V, I, J) == setindex!(zeros(5,5), V, I, J)
@test setindex!(spzeros(5, 5), M, I, J) == setindex!(zeros(5,5), M, I, J)
@test setindex!(spzeros(5, 5), Array(M), I, J) == setindex!(zeros(5,5), M, I, J)
@test setindex!(spzeros(5, 5), Array(V), I, J) == setindex!(zeros(5,5), V, I, J)
end
@test setindex!(spzeros(5, 5), 1:25, :) == setindex!(zeros(5,5), 1:25, :) == reshape(1:25, 5, 5)
@test setindex!(spzeros(5, 5), (25:-1:1).+spzeros(25), :) == setindex!(zeros(5,5), (25:-1:1).+spzeros(25), :) == reshape(25:-1:1, 5, 5)
for X in (1:20, sparse(1:20), reshape(sparse(1:20), 20, 1), (1:20) .+ spzeros(20, 1), collect(1:20), collect(reshape(1:20, 20, 1)))
@test setindex!(spzeros(5, 5), X, 6:25) == setindex!(zeros(5,5), 1:20, 6:25)
@test setindex!(spzeros(5, 5), X, 21:-1:2) == setindex!(zeros(5,5), 1:20, 21:-1:2)
b = trues(25)
b[[6, 8, 13, 15, 23]] .= false
@test setindex!(spzeros(5, 5), X, b) == setindex!(zeros(5, 5), X, b)
end
end

end # module

0 comments on commit e4fa75d

Please sign in to comment.