Skip to content

Commit

Permalink
Merge pull request #9057 from JuliaLang/teh/cartesian_iteration2
Browse files Browse the repository at this point in the history
IteratorsMD: properly handle 0-dimensional arrays
  • Loading branch information
timholy committed Nov 20, 2014
2 parents 6a3763e + 24ea5e1 commit cf1f26e
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 51 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ New language features
and macros in packages and user code ([#8791]). Type `?@doc` at the repl
to see the current syntax and more information.

* New multidimensional iterators and index types for efficient
iteration over general AbstractArrays

Language changes
----------------

Expand Down
118 changes: 67 additions & 51 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ abstract CartesianIndex{N} # the state for all multidimensional iterators
abstract IndexIterator{N} # Iterator that visits the index associated with each element

stagedfunction Base.call{N}(::Type{CartesianIndex},index::NTuple{N,Int})
indextype,itertype=gen_cartesian(N)
indextype, itertype = gen_cartesian(N)
return :($indextype(index))
end
stagedfunction Base.call{N}(::Type{IndexIterator},index::NTuple{N,Int})
indextype,itertype=gen_cartesian(N)
indextype, itertype = gen_cartesian(N)
return :($itertype(index))
end

let implemented = IntSet()
global gen_cartesian
function gen_cartesian(N::Int, with_shared=Base.is_unix(OS_NAME))
function gen_cartesian(N::Int)
# Create the types
indextype = symbol("CartesianIndex_$N")
itertype = symbol("IndexIterator_$N")
Expand All @@ -33,50 +33,18 @@ function gen_cartesian(N::Int, with_shared=Base.is_unix(OS_NAME))
fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N]
extype = Expr(:type, false, Expr(:(<:), indextype, Expr(:curly, :CartesianIndex, N)), Expr(:block, fields...))
exindices = Expr[:(index[$i]) for i = 1:N]

onesN = ones(Int, N)
infsN = fill(typemax(Int), N)
anyzero = Expr(:(||), [:(iter.dims.$(fieldnames[i]) == 0) for i = 1:N]...)

# Some necessary ambiguity resolution
exrange = N != 1 ? nothing : quote
next(R::StepRange, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
next{T}(R::UnitRange{T}, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
end
exshared = !with_shared ? nothing : quote
getindex{T}(S::SharedArray{T,$N}, I::$indextype) = S.s[I]
setindex!{T}(S::SharedArray{T,$N}, v, I::$indextype) = S.s[I] = v
end
totalex = quote
# type definition
# type definition of state
$extype
# extra constructor from tuple
# constructor from tuple
$indextype(index::NTuple{$N,Int}) = $indextype($(exindices...))

# type definition of iterator
immutable $itertype <: IndexIterator{$N}
dims::$indextype
end
# constructor from tuple
$itertype(dims::NTuple{$N,Int})=$itertype($indextype(dims))

# getindex and setindex!
$exshared
getindex{T}(A::AbstractArray{T,$N}, index::$indextype) = @nref $N A d->getfield(index,d)
setindex!{T}(A::AbstractArray{T,$N}, v, index::$indextype) = (@nref $N A d->getfield(index,d)) = v

# next iteration
$exrange
@inline function next{T}(A::AbstractArray{T,$N}, state::$indextype)
@inbounds v = A[state]
newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
v, newstate
end
@inline function next(iter::$itertype, state::$indextype)
newstate = @nif $N d->(getfield(state,d) < getfield(iter.dims,d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
state, newstate
end

# start
start(iter::$itertype) = $anyzero ? $indextype($(infsN...)) : $indextype($(onesN...))
end
eval(totalex)
push!(implemented,N)
Expand All @@ -85,26 +53,74 @@ function gen_cartesian(N::Int, with_shared=Base.is_unix(OS_NAME))
end
end

# Iteration
# indexing
stagedfunction getindex{N}(A::AbstractArray, index::CartesianIndex{N})
:(@nref $N A d->getfield(index,d))
end
stagedfunction setindex!{N}(A::AbstractArray, v, index::CartesianIndex{N})
:((@nref $N A d->getfield(index,d)) = v)
end

# Prevent an ambiguity warning
gen_cartesian(1) # to make sure the next two lines are valid
next(R::StepRange, state::(Bool, CartesianIndex{1})) = R[state[2].I_1], (state[2].I_1==length(R), CartesianIndex_1(state[2].I_1+1))
next{T}(R::UnitRange{T}, state::(Bool, CartesianIndex{1})) = R[state[2].I_1], (state[2].I_1==length(R), CartesianIndex_1(state[2].I_1+1))

# iteration
eachindex(A::AbstractArray) = IndexIterator(size(A))

# start iteration
stagedfunction _start{T,N}(A::AbstractArray{T,N},::LinearSlow)
args = fill(:s, N)
stagedfunction start{N}(iter::IndexIterator{N})
indextype, _ = gen_cartesian(N)
args = fill(1, N)
fieldnames = [symbol("I_$i") for i = 1:N]
anyzero = Expr(:(||), [:(iter.dims.$(fieldnames[i]) == 0) for i = 1:N]...)
quote
z = $anyzero
return z, $indextype($(args...))
end
end

stagedfunction _start{T,N}(A::AbstractArray{T,N}, ::LinearSlow)
indextype, _ = gen_cartesian(N)
args = fill(1, N)
quote
z = isempty(A)
return z, $indextype($(args...))
end
end

stagedfunction next{T,N}(A::AbstractArray{T,N}, state::(Bool, CartesianIndex{N}))
indextype, _ = gen_cartesian(N)
finishedex = (N==0 ? true : :(getfield(newindex, $N) > size(A, $N)))
meta = Expr(:meta, :inline)
quote
$meta
index=state[2]
@inbounds v = A[index]
newindex=@nif $N d->(getfield(index,d) < size(A, d)) d->@ncall($N, $indextype, k->(k>d ? getfield(index,k) : k==d ? getfield(index,k)+1 : 1))
finished=$finishedex
v, (finished,newindex)
end
end
stagedfunction next{N}(iter::IndexIterator{N}, state::(Bool, CartesianIndex{N}))
indextype, _ = gen_cartesian(N)
finishedex = (N==0 ? true : :(getfield(newindex, $N) > getfield(iter.dims, $N)))
meta = Expr(:meta, :inline)
quote
s = ifelse(isempty(A), typemax(Int), 1)
$indextype($(args...))
$meta
index=state[2]
newindex=@nif $N d->(getfield(index,d) < getfield(iter.dims, d)) d->@ncall($N, $indextype, k->(k>d ? getfield(index,k) : k==d ? getfield(index,k)+1 : 1))
finished=$finishedex
index, (finished,newindex)
end
end

# Ambiguity resolution
done(R::StepRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R)
done(R::UnitRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R)
done(R::FloatRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R)
done(R::StepRange, state::(Bool, CartesianIndex{1})) = state[1]
done(R::UnitRange, state::(Bool, CartesianIndex{1})) = state[1]
done(R::FloatRange, state::(Bool, CartesianIndex{1})) = state[1]

done{T,N}(A::AbstractArray{T,N}, I::CartesianIndex{N}) = getfield(I, N) > size(A, N)
done{N}(iter::IndexIterator{N}, I::CartesianIndex{N}) = getfield(I, N) > getfield(iter.dims, N)
done{T,N}(A::AbstractArray{T,N}, state::(Bool, CartesianIndex{N})) = state[1]
done{N}(iter::IndexIterator{N}, state::(Bool, CartesianIndex{N})) = state[1]

end # IteratorsMD

Expand Down
26 changes: 26 additions & 0 deletions doc/stdlib/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4123,6 +4123,32 @@ Basic functions

Returns the number of elements in A

.. function:: eachindex(A)

Creates an iterable object for visiting each multi-dimensional index of the AbstractArray ``A``. Example for a 2-d array::

julia> A = rand(2,3)
2x3 Array{Float64,2}:
0.960084 0.629326 0.625155
0.432588 0.955903 0.991614

julia> for iter in eachindex(A)
@show iter.I_1, iter.I_2
@show A[iter]
end
(iter.I_1,iter.I_2) = (1,1)
A[iter] = 0.9600836263003063
(iter.I_1,iter.I_2) = (2,1)
A[iter] = 0.4325878255452178
(iter.I_1,iter.I_2) = (1,2)
A[iter] = 0.6293256402775211
(iter.I_1,iter.I_2) = (2,2)
A[iter] = 0.9559027084099654
(iter.I_1,iter.I_2) = (1,3)
A[iter] = 0.6251548453735303
(iter.I_1,iter.I_2) = (2,3)
A[iter] = 0.9916142534546522

.. function:: countnz(A)

Counts the number of nonzero values in array A (dense or sparse). Note that this is not a constant-time operation. For sparse matrices, one should usually use ``nnz``, which returns the number of stored values.
Expand Down
18 changes: 18 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,19 @@ b718cbc = 5
@test_throws InexactError b718cbc[1.1]

# Multidimensional iterators
for a in ([1:5], reshape([2]))
counter = 0
for I in eachindex(a)
counter += 1
end
@test counter == length(a)
counter = 0
for aa in a
counter += 1
end
@test counter == length(a)
end

function mdsum(A)
s = 0.0
for a in A
Expand Down Expand Up @@ -970,9 +983,14 @@ for i = 2:10
insert!(shp, 2, 1)
end

a = reshape([2])
@test mdsum(a) == 2
@test mdsum2(a) == 2

a = ones(0,5)
b = sub(a, :, :)
@test mdsum(b) == 0
a = ones(5,0)
b = sub(a, :, :)
@test mdsum(b) == 0

0 comments on commit cf1f26e

Please sign in to comment.