From 373e61b79ed53ae50c0b65e47818664d10884830 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Thu, 20 Nov 2014 15:38:57 -0600 Subject: [PATCH] ngenerate/nsplat: multidimensional algorithms on AbstractArrays --- base/multidimensional.jl | 230 +++++++++++++++++++++------------------ 1 file changed, 126 insertions(+), 104 deletions(-) diff --git a/base/multidimensional.jl b/base/multidimensional.jl index 77a99f8ad33e1..e8d4dba1b3276 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -181,9 +181,12 @@ using .IteratorsMD ### From array.jl -@ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...) - @nexprs N d->(size(A, d) == length(I_d) || throw(DimensionMismatch("index $d has length $(length(I_d)), but size(A, $d) = $(size(A,d))"))) - nothing +stagedfunction checksize(A::AbstractArray, I...) + N = length(I) + quote + @nexprs $N d->(size(A, d) == length(I[d]) || throw(DimensionMismatch("index $d has length $(length(I[d])), but size(A, $d) = $(size(A,d))"))) + nothing + end end @inline unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind) @@ -259,17 +262,19 @@ end end -@ngenerate N NTuple{N,Vector{Int}} function findn{T,N}(A::AbstractArray{T,N}) - nnzA = countnz(A) - @nexprs N d->(I_d = Array(Int, nnzA)) - k = 1 - @nloops N i A begin - @inbounds if (@nref N A i) != zero(T) - @nexprs N d->(I_d[k] = i_d) - k += 1 +stagedfunction findn{T,N}(A::AbstractArray{T,N}) + quote + nnzA = countnz(A) + @nexprs $N d->(I_d = Array(Int, nnzA)) + k = 1 + @nloops $N i A begin + @inbounds if (@nref $N A i) != zero(T) + @nexprs $N d->(I_d[k] = i_d) + k += 1 + end end + @ntuple $N I end - @ntuple N I end @@ -386,57 +391,70 @@ end cumsum(A::AbstractArray, axis::Integer=1) = cumsum!(similar(A, Base._cumsum_type(A)), A, axis) +cumsum!(B, A::AbstractArray) = cumsum!(B, A, 1) cumprod(A::AbstractArray, axis::Integer=1) = cumprod!(similar(A), A, axis) +cumprod!(B, A) = cumprod!(B, A, 1) for (f, op) in ((:cumsum!, :+), (:cumprod!, :*)) @eval begin - @ngenerate N typeof(B) function ($f){T,N}(B, A::AbstractArray{T,N}, axis::Integer=1) - if size(B, axis) < 1 - return B - end - size(B) == size(A) || throw(DimensionMismatch("size of B must match A")) - if axis == 1 - # We can accumulate to a temporary variable, which allows register usage and will be slightly faster - @inbounds @nloops N i d->(d > 1 ? (1:size(A,d)) : (1:1)) begin - tmp = convert(eltype(B), @nref(N, A, i)) - @nref(N, B, i) = tmp - for i_1 = 2:size(A,1) - tmp = ($op)(tmp, @nref(N, A, i)) - @nref(N, B, i) = tmp - end + stagedfunction ($f){T,N}(B, A::AbstractArray{T,N}, axis::Integer) + quote + if size(B, axis) < 1 + return B end - else - @nexprs N d->(isaxis_d = axis == d) - # Copy the initial element in each 1d vector along dimension `axis` - @inbounds @nloops N i d->(d == axis ? (1:1) : (1:size(A,d))) @nref(N, B, i) = @nref(N, A, i) - # Accumulate - @inbounds @nloops N i d->((1+isaxis_d):size(A, d)) d->(j_d = i_d - isaxis_d) begin - @nref(N, B, i) = ($op)(@nref(N, B, j), @nref(N, A, i)) + size(B) == size(A) || throw(DimensionMismatch("Size of B must match A")) + if axis == 1 + # We can accumulate to a temporary variable, which allows register usage and will be slightly faster + @inbounds @nloops $N i d->(d > 1 ? (1:size(A,d)) : (1:1)) begin + tmp = convert(eltype(B), @nref($N, A, i)) + @nref($N, B, i) = tmp + for i_1 = 2:size(A,1) + tmp = ($($op))(tmp, @nref($N, A, i)) + @nref($N, B, i) = tmp + end + end + else + @nexprs $N d->(isaxis_d = axis == d) + # Copy the initial element in each 1d vector along dimension `axis` + @inbounds @nloops $N i d->(d == axis ? (1:1) : (1:size(A,d))) @nref($N, B, i) = @nref($N, A, i) + # Accumulate + @inbounds @nloops $N i d->((1+isaxis_d):size(A, d)) d->(j_d = i_d - isaxis_d) begin + @nref($N, B, i) = ($($op))(@nref($N, B, j), @nref($N, A, i)) + end end + B end - B end end end ### from abstractarray.jl -@ngenerate N typeof(A) function fill!{T,N}(A::AbstractArray{T,N}, x) - xT = convert(T, x) - @nloops N i A begin - @inbounds (@nref N A i) = xT +function fill!(A::AbstractArray, x) + for I in eachindex(A) + @inbounds A[I] = x end A end -@ngenerate N typeof(dest) function copy!{T,N}(dest::AbstractArray{T,N}, src::AbstractArray{T,N}) - if @nall N d->(size(dest,d) == size(src,d)) - @nloops N i dest begin - @inbounds (@nref N dest i) = (@nref N src i) +function copy!{T,N}(dest::AbstractArray{T,N}, src::AbstractArray{T,N}) + samesize = true + for d = 1:N + if size(dest,d) != size(src,d) + samesize = false + break + end + end + if samesize + for I in eachindex(dest) + @inbounds dest[I] = src[I] end else - invoke(copy!, (typeof(dest), Any), dest, src) + length(dest) == length(src) || throw(DimensionMismatch("Inconsistent lengths")) + for (Idest, Isrc) in zip(eachindex(dest),eachindex(src)) + @inbounds dest[Idest] = src[Isrc] + end end dest end @@ -697,19 +715,21 @@ end ## findn -@ngenerate N NTuple{N,Vector{Int}} function findn{N}(B::BitArray{N}) - nnzB = countnz(B) - I = ntuple(N, x->Array(Int, nnzB)) - if nnzB > 0 - count = 1 - @nloops N i B begin - if (@nref N B i) # TODO: should avoid bounds checking - @nexprs N d->(I[d][count] = i_d) - count += 1 +stagedfunction findn{N}(B::BitArray{N}) + quote + nnzB = countnz(B) + I = ntuple($N, x->Array(Int, nnzB)) + if nnzB > 0 + count = 1 + @nloops $N i B begin + if (@nref $N B i) # TODO: should avoid bounds checking + @nexprs $N d->(I[d][count] = i_d) + count += 1 + end end end + return I end - return I end ## isassigned @@ -774,70 +794,72 @@ immutable Prehashed end hash(x::Prehashed) = x.hash -@ngenerate N typeof(A) function unique{T,N}(A::AbstractArray{T,N}, dim::Int) - 1 <= dim <= N || return copy(A) - hashes = zeros(UInt, size(A, dim)) +stagedfunction unique{T,N}(A::AbstractArray{T,N}, dim::Int) + quote + 1 <= dim <= $N || return copy(A) + hashes = zeros(UInt, size(A, dim)) - # Compute hash for each row - k = 0 - @nloops N i A d->(if d == dim; k = i_d; end) begin - @inbounds hashes[k] = hash(hashes[k], hash((@nref N A i))) - end + # Compute hash for each row + k = 0 + @nloops $N i A d->(if d == dim; k = i_d; end) begin + @inbounds hashes[k] = hash(hashes[k], hash((@nref $N A i))) + end - # Collect index of first row for each hash - uniquerow = Array(Int, size(A, dim)) - firstrow = Dict{Prehashed,Int}() - for k = 1:size(A, dim) - uniquerow[k] = get!(firstrow, Prehashed(hashes[k]), k) - end - uniquerows = collect(values(firstrow)) + # Collect index of first row for each hash + uniquerow = Array(Int, size(A, dim)) + firstrow = Dict{Prehashed,Int}() + for k = 1:size(A, dim) + uniquerow[k] = get!(firstrow, Prehashed(hashes[k]), k) + end + uniquerows = collect(values(firstrow)) - # Check for collisions - collided = falses(size(A, dim)) - @inbounds begin - @nloops N i A d->(if d == dim + # Check for collisions + collided = falses(size(A, dim)) + @inbounds begin + @nloops $N i A d->(if d == dim k = i_d j_d = uniquerow[k] else j_d = i_d end) begin - if (@nref N A j) != (@nref N A i) - collided[k] = true - end + if (@nref $N A j) != (@nref $N A i) + collided[k] = true + end + end end - end - if any(collided) - nowcollided = BitArray(size(A, dim)) - while any(collided) - # Collect index of first row for each collided hash - empty!(firstrow) - for j = 1:size(A, dim) - collided[j] || continue - uniquerow[j] = get!(firstrow, Prehashed(hashes[j]), j) - end - for v in values(firstrow) - push!(uniquerows, v) - end + if any(collided) + nowcollided = BitArray(size(A, dim)) + while any(collided) + # Collect index of first row for each collided hash + empty!(firstrow) + for j = 1:size(A, dim) + collided[j] || continue + uniquerow[j] = get!(firstrow, Prehashed(hashes[j]), j) + end + for v in values(firstrow) + push!(uniquerows, v) + end - # Check for collisions - fill!(nowcollided, false) - @nloops N i A d->begin - if d == dim - k = i_d - j_d = uniquerow[k] - (!collided[k] || j_d == k) && continue - else - j_d = i_d - end - end begin - if (@nref N A j) != (@nref N A i) - nowcollided[k] = true + # Check for collisions + fill!(nowcollided, false) + @nloops $N i A d->begin + if d == dim + k = i_d + j_d = uniquerow[k] + (!collided[k] || j_d == k) && continue + else + j_d = i_d + end + end begin + if (@nref $N A j) != (@nref $N A i) + nowcollided[k] = true + end end + (collided, nowcollided) = (nowcollided, collided) end - (collided, nowcollided) = (nowcollided, collided) end - end - @nref N A d->d == dim ? sort!(uniquerows) : (1:size(A, d)) + @nref $N A d->d == dim ? sort!(uniquerows) : (1:size(A, d)) + end end