Skip to content

faster circshift! for SparseMatrixCSC #30317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stdlib/SparseArrays/src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
vcat, hcat, hvcat, cat, imag, argmax, kron, length, log, log1p, max, min,
maximum, minimum, one, promote_eltype, real, reshape, rot180,
rotl90, rotr90, round, setindex!, similar, size, transpose,
vec, permute!, map, map!, Array, diff
vec, permute!, map, map!, Array, diff, circshift!, circshift

using Random: GLOBAL_RNG, AbstractRNG, randsubseq, randsubseq!

Expand Down
43 changes: 43 additions & 0 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3511,3 +3511,46 @@ end
(+)(A::SparseMatrixCSC, J::UniformScaling) = A + sparse(J, size(A)...)
(-)(A::SparseMatrixCSC, J::UniformScaling) = A - sparse(J, size(A)...)
(-)(J::UniformScaling, A::SparseMatrixCSC) = sparse(J, size(A)...) - A

## circular shift

function circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, (r,c)::Base.DimsInteger{2})
nnz = length(X.nzval)

iszero(nnz) && return copy!(O, X)

##### column shift
c = mod(c, X.n)
if iszero(c)
copy!(O, X)
else
##### readjust output
resize!(O.colptr, X.n + 1)
resize!(O.rowval, nnz)
resize!(O.nzval, nnz)
O.colptr[X.n + 1] = nnz + 1

# exchange left and right blocks
nleft = X.colptr[X.n - c + 1] - 1
nright = nnz - nleft
@inbounds for i=c+1:X.n
O.colptr[i] = X.colptr[i-c] + nright
end
@inbounds for i=1:c
O.colptr[i] = X.colptr[X.n - c + i] - nleft
end
# rotate rowval and nzval by the right number of elements
circshift!(O.rowval, X.rowval, (nright,))
circshift!(O.nzval, X.nzval, (nright,))
end
##### row shift
r = mod(r, X.m)
iszero(r) && return O
@inbounds for i=1:O.n
subvector_shifter!(O.rowval, O.nzval, O.colptr[i], O.colptr[i+1]-1, O.m, r)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skip this loop if iszero(r). Similarly the code above can be replace with a copy if iszero(c).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @stevengj. I implemented the suggestions, let me know if I interpreted correctly. I also moved subvector_shifter! to sparsevector.jl, it seemed more appropriate (should I prepend the name with _ given that it's a helper, or better not, as it is used also by sparsematrix.jl?).

return O
end

circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, (r,)::Base.DimsInteger{1}) = circshift!(O, X, (r,0))
circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, r::Real) = circshift!(O, X, (Integer(r),0))
39 changes: 39 additions & 0 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1975,3 +1975,42 @@ function fill!(A::Union{SparseVector, SparseMatrixCSC}, x)
end
return A
end



# in-place swaps (dense) blocks start:split and split+1:fin in col
function _swap!(col::AbstractVector, start::Integer, fin::Integer, split::Integer)
split == fin && return
reverse!(col, start, split)
reverse!(col, split + 1, fin)
reverse!(col, start, fin)
return
end


# in-place shifts a sparse subvector by r. Used also by sparsematrix.jl
function subvector_shifter!(R::AbstractVector, V::AbstractVector, start::Integer, fin::Integer, m::Integer, r::Integer)
split = fin
@inbounds for j = start:fin
# shift positions ...
R[j] += r
if R[j] <= m
split = j
else
R[j] -= m
end
end
# ...but rowval should be sorted within columns
_swap!(R, start, fin, split)
_swap!(V, start, fin, split)
end


function circshift!(O::SparseVector, X::SparseVector, (r,)::Base.DimsInteger{1})
O .= X
Copy link
Member

@stevengj stevengj Dec 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a bug in copy! for this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I opened a separate issue #30443

subvector_shifter!(O.nzind, O.nzval, 1, length(O.nzind), O.n, mod(r, X.n))
return O
end


circshift!(O::SparseVector, X::SparseVector, r::Real,) = circshift!(O, X, (Integer(r),))
27 changes: 27 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2410,4 +2410,31 @@ end
@test one(A) isa SparseMatrixCSC{Int}
end

@testset "circshift" begin
m,n = 17,15
A = sprand(m, n, 0.5)
for rshift in (-1, 0, 1, 10), cshift in (-1, 0, 1, 10)
shifts = (rshift, cshift)
# using dense circshift to compare
B = circshift(Matrix(A), shifts)
# sparse circshift
C = circshift(A, shifts)
@test C == B
# sparse circshift should not add structural zeros
@test nnz(C) == nnz(A)
# test circshift!
D = similar(A)
circshift!(D, A, shifts)
@test D == B
@test nnz(D) == nnz(A)
# test different in/out types
A2 = floor.(100A)
E1 = spzeros(Int64, m, n)
E2 = spzeros(Int64, m, n)
circshift!(E1, A2, shifts)
circshift!(E2, Matrix(A2), shifts)
@test E1 == E2
end
end

end # module
22 changes: 22 additions & 0 deletions stdlib/SparseArrays/test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1265,4 +1265,26 @@ end
end
end

@testset "SparseVector circshift" begin
n = 100
v = sprand(n, 0.5)
for shift in (0,-1,1,5,-7,n+10)
x = circshift(Vector(v), shift)
w = circshift(v, shift)
@test nnz(v) == nnz(w)
@test w == x
# test circshift!
v1 = similar(v)
circshift!(v1, v, shift)
@test v1 == x
# test different in/out types
y1 = spzeros(Int64, n)
y2 = spzeros(Int64, n)
v2 = floor.(100v)
circshift!(y1, v2, shift)
circshift!(y2, Vector(v2), shift)
@test y1 == y2
end
end

end # module