Skip to content

Commit

Permalink
Change find() to return the same index type as pairs()
Browse files Browse the repository at this point in the history
This does not change anything for AbstractVectors and general iterables,
which continue to use linear indices. For other AbstractArrays, return
CartesianIndexes (rather than linear indices). For Dicts, return keys
(previously not supported at all).

Relying on collect() to choose the return element type allows supporting
any definition of pairs(), including that for Dict, which creates a standard
Generator for which eltype() returns Any.
  • Loading branch information
nalimilan committed Nov 25, 2017
1 parent f8243b7 commit 0a8fbcb
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 45 deletions.
88 changes: 44 additions & 44 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1723,48 +1723,60 @@ findlast(testf::Function, A) = findprev(testf, A, endof(A))
"""
find(f::Function, A)
Return a vector `I` of the linear indexes of `A` where `f(A[I])` returns `true`.
Return a vector `I` of the indices or keys of `A` where `f(A[I])` returns `true`.
If there are no such elements of `A`, return an empty array.
Indices or keys are of the same type as those returned by [`keys(A)`](@ref)
and [`pairs(A)`](@ref) for `AbstractArray` and `Associative` objects,
and are linear indices of type `Int` for other iterables.
# Examples
```jldoctest
julia> A = [1 2 0; 3 4 0]
2×3 Array{Int64,2}:
1 2 0
3 4 0
julia> x = [1, 3, 4]
3-element Array{Int64,1}:
1
3
4
julia> find(isodd, A)
julia> find(isodd, x)
2-element Array{Int64,1}:
1
2
julia> A = [1 2 0; 3 4 0]
2×3 Array{Int64,2}:
1 2 0
3 4 0
julia> find(isodd, A)
2-element Array{CartesianIndex{2},1}:
CartesianIndex(1, 1)
CartesianIndex(2, 1)
julia> find(!iszero, A)
4-element Array{Int64,1}:
1
2
3
4
4-element Array{CartesianIndex{2},1}:
CartesianIndex(1, 1)
CartesianIndex(2, 1)
CartesianIndex(1, 2)
CartesianIndex(2, 2)
julia> d = Dict(:A => 10, :B => -1, :C => 0)
Dict{Symbol,Int64} with 3 entries:
:A => 10
:B => -1
:C => 0
julia> find(x -> x >= 0, d)
2-element Array{Symbol,1}:
:A
:C
julia> find(isodd, [2, 4])
0-element Array{Int64,1}
```
"""
function find(testf::Function, A)
# use a dynamic-length array to store the indexes, then copy to a non-padded
# array for the return
tmpI = Vector{Int}()
inds = _index_remapper(A)
for (i,a) = enumerate(A)
if testf(a)
push!(tmpI, inds[i])
end
end
I = Vector{Int}(uninitialized, length(tmpI))
copy!(I, tmpI)
return I
end
_index_remapper(A::AbstractArray) = linearindices(A)
_index_remapper(iter) = OneTo(typemax(Int)) # safe for objects that don't implement length
find(testf::Function, A) = collect(first(p) for p in _pairs(A) if testf(last(p)))

_pairs(A::Union{AbstractArray, Associative}) = pairs(A)
_pairs(iter) = zip(OneTo(typemax(Int)), iter) # safe for objects that don't implement length

"""
find(A)
Expand All @@ -1789,22 +1801,10 @@ julia> find(falses(3))
```
"""
function find(A)
nnzA = count(t -> t != 0, A)
I = Vector{Int}(uninitialized, nnzA)
cnt = 1
inds = _index_remapper(A)
warned = false
for (i,a) in enumerate(A)
if !warned && !(a isa Bool)
depwarn("In the future `find(A)` will only work on boolean collections. Use `find(x->x!=0, A)` instead.", :find)
warned = true
end
if a != 0
I[cnt] = inds[i]
cnt += 1
end
if !(eltype(A) === Bool) && !all(x -> x isa Bool, A)
depwarn("In the future `find(A)` will only work on boolean collections. Use `find(x->x!=0, A)` instead.", :find)
end
return I
collect(first(p) for p in _pairs(A) if last(p) != 0)
end

find(x::Bool) = x ? [1] : Vector{Int}()
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1264,7 +1264,7 @@ function find(p::Function, S::SparseMatrixCSC)
end
sz = size(S)
I, J = _findn(p, S)
return sub2ind(sz, I, J)
return CartesianIndex.(I, J)
end

findn(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = _findn(x->x!=0, S)
Expand Down
10 changes: 10 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,16 @@ end
@test findprev(isodd, [2,4,5,3,9,2,0], 7) == 5
@test findprev(isodd, [2,4,5,3,9,2,0], 2) == 0
end
@testset "find with Matrix" begin
A = [1 2 0; 3 4 0]
@test find(isodd, A) == [CartesianIndex(1, 1), CartesianIndex(2, 1)]
@test find(!iszero, A) == [CartesianIndex(1, 1), CartesianIndex(2, 1),
CartesianIndex(1, 2), CartesianIndex(2, 2)]
end
@testset "find with Dict" begin
d = Dict(:A => 10, :B => -1, :C => 0)
@test sort(find(x -> x >= 0, d)) == [:A, :C]
end
@testset "find with general iterables" begin
s = "julia"
@test find(c -> c == 'l', s) == [3]
Expand Down

0 comments on commit 0a8fbcb

Please sign in to comment.