Skip to content

make SparseMatrixCSC and SparseVector work on non-numerical values #30580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ function SparseMatrixCSC{Tv,Ti}(M::AbstractMatrix) where {Tv,Ti}
end

function SparseMatrixCSC{Tv,Ti}(M::StridedMatrix) where {Tv,Ti}
nz = count(t -> t != 0, M)
nz = count(!iszero, M)
colptr = zeros(Ti, size(M, 2) + 1)
nzval = Vector{Tv}(undef, nz)
rowval = Vector{Ti}(undef, nz)
Expand All @@ -394,7 +394,7 @@ function SparseMatrixCSC{Tv,Ti}(M::StridedMatrix) where {Tv,Ti}
@inbounds for j in 1:size(M, 2)
for i in 1:size(M, 1)
v = M[i, j]
if v != 0
if !iszero(v)
rowval[cnt] = i
nzval[cnt] = v
cnt += 1
Expand Down Expand Up @@ -1241,7 +1241,7 @@ Removes stored numerical zeros from `A`, optionally trimming resulting excess sp
For an out-of-place version, see [`dropzeros`](@ref). For
algorithmic information, see `fkeep!`.
"""
dropzeros!(A::SparseMatrixCSC; trim::Bool = true) = fkeep!(A, (i, j, x) -> x != 0, trim)
dropzeros!(A::SparseMatrixCSC; trim::Bool = true) = fkeep!(A, (i, j, x) -> !iszero(x), trim)
"""
dropzeros(A::SparseMatrixCSC; trim::Bool = true)

Expand Down Expand Up @@ -2330,7 +2330,7 @@ function _setindex_scalar!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integ
end
# Column j does not contain entry A[i,j]. If v is nonzero, insert entry A[i,j] = v
# and return. If to the contrary v is zero, then simply return.
if v != 0
if !iszero(v)
insert!(A.rowval, searchk, i)
insert!(A.nzval, searchk, v)
@simd for m in (j + 1):(A.n + 1)
Expand Down Expand Up @@ -3184,7 +3184,7 @@ function is_hermsym(A::SparseMatrixCSC, check::Function)
# We therefore "catch up" here while making sure that
# the elements are actually zero.
while row2 < col
if nzval[offset] != 0
if !iszero(nzval[offset])
return false
end
offset += 1
Expand Down Expand Up @@ -3222,7 +3222,7 @@ function istriu(A::SparseMatrixCSC)
if rowval[l1-i] <= col
break
end
if nzval[l1-i] != 0
if !iszero(nzval[l1-i])
return false
end
end
Expand All @@ -3241,7 +3241,7 @@ function istril(A::SparseMatrixCSC)
if rowval[i] >= col
break
end
if nzval[i] != 0
if !iszero(nzval[i])
return false
end
end
Expand Down
6 changes: 3 additions & 3 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ function setindex!(x::SparseVector{Tv,Ti}, v::Tv, i::Ti) where {Tv,Ti<:Integer}
if 1 <= k <= m && nzind[k] == i # i found
nzval[k] = v
else # i not found
if v != 0
if !iszero(v)
insert!(nzind, k, i)
insert!(nzval, k, v)
end
Expand Down Expand Up @@ -392,7 +392,7 @@ function _dense2indval!(nzind::Vector{Ti}, nzval::Vector{Tv}, s::AbstractArray{T
c = 0
@inbounds for i = 1:n
v = s[i]
if v != 0
if !iszero(v)
if c >= cap
cap *= 2
resize!(nzind, cap)
Expand Down Expand Up @@ -1929,7 +1929,7 @@ Removes stored numerical zeros from `x`, optionally trimming resulting excess sp
For an out-of-place version, see [`dropzeros`](@ref). For
algorithmic information, see `fkeep!`.
"""
dropzeros!(x::SparseVector; trim::Bool = true) = fkeep!(x, (i, x) -> x != 0, trim)
dropzeros!(x::SparseVector; trim::Bool = true) = fkeep!(x, (i, x) -> !iszero(x), trim)

"""
dropzeros(x::SparseVector; trim::Bool = true)
Expand Down
7 changes: 2 additions & 5 deletions stdlib/SparseArrays/test/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,17 +287,14 @@ end
A, fA = sparse(1.0I, N, N), Matrix(1.0I, N, N)
B, fB = spzeros(1, N), zeros(1, N)
intorfloat_zeropres(xs...) = all(iszero, xs) ? zero(Float64) : Int(1)
stringorfloat_zeropres(xs...) = all(iszero, xs) ? zero(Float64) : "hello"
intorfloat_notzeropres(xs...) = all(iszero, xs) ? Int(1) : zero(Float64)
stringorfloat_notzeropres(xs...) = all(iszero, xs) ? "hello" : zero(Float64)
for fn in (intorfloat_zeropres, intorfloat_notzeropres,
stringorfloat_zeropres, stringorfloat_notzeropres)
for fn in (intorfloat_zeropres, intorfloat_notzeropres)
@test map(fn, A) == sparse(map(fn, fA))
@test broadcast(fn, A) == sparse(broadcast(fn, fA))
@test broadcast(fn, A, B) == sparse(broadcast(fn, fA, fB))
@test broadcast(fn, B, A) == sparse(broadcast(fn, fB, fA))
end
for fn in (intorfloat_zeropres, stringorfloat_zeropres)
for fn in (intorfloat_zeropres,)
@test broadcast(fn, A, B, A) == sparse(broadcast(fn, fA, fB, fA))
end
end
Expand Down
16 changes: 3 additions & 13 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1263,15 +1263,16 @@ end
@test isequal(findmax(A, dims=tup), (rval, rind))
end

A = sparse(["a", "b"])
@test_throws MethodError findmin(A, dims=1)
# sparse arrays of types without zero(T) are forbidden
@test_throws MethodError sparse(["a", "b"])
end

# Support the case when user defined `zero` and `isless` for non-numerical type
struct CustomType
x::String
end
Base.zero(::Type{CustomType}) = CustomType("")
Base.zero(x::CustomType) = zero(CustomType)
Base.isless(x::CustomType, y::CustomType) = isless(x.x, y.x)
@testset "findmin/findmax for non-numerical type" begin
A = sparse([CustomType("a"), CustomType("b")])
Expand Down Expand Up @@ -2286,17 +2287,6 @@ end
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
end

w = [ "a" ""; "" "b"]
w_sp = sparse(w)

for i in keys(w)
@test findnext(!isequal(""), w,i) == findnext(!isequal(""), w_sp,i)
@test findprev(!isequal(""), w,i) == findprev(!isequal(""), w_sp,i)
@test findnext(isequal(""), w,i) == findnext(isequal(""), w_sp,i)
@test findprev(isequal(""), w,i) == findprev(isequal(""), w_sp,i)
end

end

# #20711
Expand Down
1 change: 1 addition & 0 deletions test/hashing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ Base.hash(x::CustomHashReal, h::UInt) = hash(x.x, h)
Base.:(==)(x::CustomHashReal, y::Number) = x.x == y
Base.:(==)(x::Number, y::CustomHashReal) = x == y.x
Base.zero(::Type{CustomHashReal}) = CustomHashReal(0.0)
Base.zero(x::CustomHashReal) = zero(CustomHashReal)

let a = sparse([CustomHashReal(0), CustomHashReal(3), CustomHashReal(3)])
@test hash(a) == hash(Array(a))
Expand Down
3 changes: 3 additions & 0 deletions test/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ end

# issue #12960
mutable struct T12960 end
import Base.zero
Base.zero(::Type{T12960}) = T12960()
Base.zero(x::T12960) = T12960()
let
A = sparse(1.0I, 3, 3)
B = similar(A, T12960)
Expand Down