Skip to content

Commit

Permalink
Merge pull request #280 from FluxML/cl/gather
Browse files Browse the repository at this point in the history
add gather
  • Loading branch information
CarloLucibello authored Mar 12, 2021
2 parents d84d334 + 7307904 commit cf8cba0
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 10 deletions.
5 changes: 3 additions & 2 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import ChainRulesCore: rrule
using Base.Broadcast: broadcasted
using Statistics: mean

const IntOrTuple = Union{Integer,Tuple}
const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}
const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}

# Include APIs
Expand Down Expand Up @@ -35,8 +35,9 @@ include("conv_bias_act.jl")
include("pooling.jl")
include("padding.jl")
include("upsample.jl")
include("utils.jl")
include("gather.jl")
include("scatter.jl")
include("utils.jl")

## Include implementations
include("impl/padding_edges.jl")
Expand Down
84 changes: 84 additions & 0 deletions src/gather.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
export gather, gather!

"""
    gather!(dst, src, idx)

Reverse operation of [`scatter!`](@ref). Gathers data from source `src`
and writes it in destination `dst` according to the index array `idx`.
For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to

    dst[:, ... , k] .= src[:, ... , idx[k]...]

Notice that if `idx` is a vector containing integers,
and both `dst` and `src` are matrices, previous expression simplifies to

    dst[:, k] .= src[:, idx[k]]

and `k` will run over `1:length(idx)`.

The elements of `idx` can be integers or integer tuples and may be repeated.
A single `src` column can end up being copied into zero, one,
or multiple `dst` columns.

Throws an `ArgumentError` when the shapes of `dst`, `src` and `idx` are
incompatible with each other. An element of `idx` that falls outside `src`
raises a `BoundsError`.

Returns `dst`.

See [`gather`](@ref) for an allocating version.
"""
function gather!(dst::AbstractArray{Tdst,Ndst},
                 src::AbstractArray{Tsrc,Nsrc},
                 idx::AbstractArray{Tidx, Nidx}) where
                 {Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple}

    # `typelength` is defined elsewhere in the package; presumably it is the number
    # of integers carried by one element of `idx` (1 for an integer, N for an
    # N-tuple) -- confirm against its definition.
    M = typelength(Tidx)
    # Number of leading dimensions that are copied whole from `src` to `dst`.
    d = Ndst - Nidx
    # `d` must be non-negative: a negative `d` can still satisfy `d == Nsrc - M`
    # (e.g. 1d `dst`/`src` with a 2d `idx` of 2-tuples) and would otherwise surface
    # as a `BoundsError` from the tuple slicing below instead of an `ArgumentError`.
    (d >= 0 && d == Nsrc - M) || throw(ArgumentError("Incompatible input shapes."))
    size(dst)[1:d] == size(src)[1:d] || throw(ArgumentError("Incompatible input shapes."))
    size(dst)[d+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))

    colons = ntuple(i -> Colon(), d)
    for k in CartesianIndices(idx)
        # Views avoid allocating a temporary copy of each gathered slice.
        view(dst, colons..., k) .= view(src, colons..., idx[k]...)
    end
    return dst
end

"""
gather(src, idx) -> dst
Reverse operation of [`scatter`](@ref). Gathers data from source `src`
and writes it in a destination `dst` according to the index
array `idx`.
For each `k` in `CartesianIndices(idx)`, assign values to `dst`
according to
dst[:, ... , k] .= src[:, ... , idx[k]...]
Notice that if `idx` is a vector containing integers
and `src` is a matrix, previous expression simplifies to
dst[:, k] .= src[:, idx[k]]
and `k` will run over `1:length(idx)`.
The elements of `idx` can be integers or integer tuples and may be repeated.
A single `src` column can end up being copied into zero, one,
or multiple `dst` columns.
See [`gather!`](@ref) for an in-place version.
"""
function gather(src::AbstractArray{Tsrc, Nsrc},
                idx::AbstractArray{Tidx, Nidx}) where
                {Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple}
    # Leading `src` dimensions that are carried over unchanged into the result;
    # the remaining result dimensions mirror the shape of `idx`.
    nlead = Nsrc - typelength(Tidx)
    leading = size(src)[1:nlead]
    out = similar(src, Tsrc, (leading..., size(idx)...))
    # The in-place method validates the shapes and fills `out`.
    return gather!(out, src, idx)
end

# Shortcut for arrays of plain integer indices, written directly with `getindex`.
# Performance matches the generic method (which also accepts integer indices);
# it is kept to make the relationship between `gather` and `getindex` explicit.
function gather(src::AbstractArray{Tsrc, Nsrc},
                idx::AbstractArray{<:Integer}) where {Tsrc, Nsrc}
    # Take every leading dimension whole and index only the last one with `idx`.
    leading = ntuple(_ -> (:), Nsrc - 1)
    return getindex(src, leading..., idx)
end
16 changes: 8 additions & 8 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ Once the dimensions match, arrays are aligned automatically. The value of `idx`
function scatter!(op,
dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrTuple,Ndst,Nsrc,Nidx}
idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrIntTuple,Ndst,Nsrc,Nidx}
M = typelength(Tidx)
dims = _check_dims(Ndst, Nsrc, M, Nidx)
scatter!(op, dst, src, idx, Val(dims))
end

function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrTuple},
function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrIntTuple},
dims::Val{N}) where {Tdst,Tsrc,N}
colons = Base.ntuple(_->Colon(), dims)
for k in CartesianIndices(idx)
Expand All @@ -67,7 +67,7 @@ end
function scatter!(op::typeof(mean),
dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
Ns = scatter!(+, zero(dst), one.(src), idx)
dst_ = scatter!(+, zero(dst), src, idx)
dst .+= safe_div.(dst_, Ns)
Expand Down Expand Up @@ -96,7 +96,7 @@ function scatter end
for op in [+, -]
@eval function scatter(op::typeof($op),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
Expand All @@ -108,7 +108,7 @@ end
for op in [*, /]
@eval function scatter(op::typeof($op),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
Expand All @@ -119,7 +119,7 @@ end

function scatter(op::typeof(max),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
Expand All @@ -129,7 +129,7 @@ end

function scatter(op::typeof(min),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
Expand All @@ -139,7 +139,7 @@ end

function scatter(op::typeof(mean),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrTuple,Nidx}) where {T,Nsrc,Nidx}
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
FT = float(T)
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
Expand Down
94 changes: 94 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
@testset "gather scalar index" begin
    T = Float32

    # Case 1: vector source, matrix of integer indices -> matrix result.
    data = T[3, 4, 5, 6, 7]
    positions = [1 2 3 4; 4 2 1 3; 3 5 5 3]
    expected = T[3 4 5 6; 6 4 3 5; 5 7 7 5]

    gathered = gather(data, positions)
    @test gathered isa Array{T,2}
    @test size(gathered) == size(positions)
    @test gathered == expected
    # The in-place variant must fill a preallocated destination with the same values.
    @test gather!(zeros(T, size(positions)), data, positions) == expected
    # A destination whose shape differs from the index array is rejected.
    @test_throws ArgumentError gather!(zeros(T, 3, 5), data, positions)

    # Same indices except for one entry (6) past the end of the 5-element source.
    out_of_range = [1 2 3 4; 4 2 1 3; 3 6 5 3]
    @test_throws BoundsError gather!(zeros(T, size(positions)), data, out_of_range)

    # Case 2: vector source, 3d array of integer indices -> 3d result.
    positions3 = reshape(positions, size(positions)..., 1)
    expected3 = reshape(expected, size(expected)..., 1)

    gathered3 = gather(data, positions3)
    @test gathered3 isa Array{T,3}
    @test size(gathered3) == size(positions3)
    @test gathered3 == expected3

    # Case 3: matrix source, matrix of integer indices -> 3d result.
    table = T[3 5 7; 4 6 8]
    cols = [1 2 3; 2 2 1; 3 1 3]
    # Slice [:, :, j] holds the source columns selected by cols[:, j].
    expected_cube = cat(T[3 5 7; 4 6 8], T[5 5 3; 6 6 4], T[7 3 7; 8 4 8]; dims=3)

    cube = gather(table, cols)
    width = NNlib.typelength(eltype(cols))
    rank = ndims(table)
    @test cube isa Array{T,3}
    @test size(cube) == (size(table)[1:rank-width]..., size(cols)...)
    @test cube == expected_cube
end

@testset "gather tuple index" begin
    T = Float32

    ## 2d src, 1d index of 2-tuples -> 1d output
    src = T[3 5 7
            4 6 8]

    index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)]

    output = T[3, 5, 7, 4, 6, 8]

    y = gather(src, index)
    M = NNlib.typelength(eltype(index))
    Nsrc = ndims(src)
    @test y isa Array{T,1}
    @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
    @test y == output

    ## 3d src, 2d index of 2-tuples -> 3d output
    n1, nsrc, nidx = 2, 3, 6
    # Deterministic input with all-distinct values, so that a slice gathered from
    # the wrong place cannot compare equal by accident and failures reproduce.
    src = reshape(T.(1:n1*nsrc*nsrc), n1, nsrc, nsrc)
    # `mod1` keeps every index component inside 1:nsrc while repeating entries.
    index = [(mod1(i, nsrc), mod1(i + j, nsrc)) for i=1:nidx, j=1:nidx]

    y = gather(src, index)
    M = NNlib.typelength(eltype(index))
    Nsrc = ndims(src)
    @test y isa Array{T,3}
    @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
    # Check the gathered values, not only type and shape: each slice of `y` must
    # equal the slice of `src` addressed by the corresponding index tuple.
    @test all(y[:, k] == src[:, index[k]...] for k in CartesianIndices(index))
    # The in-place variant must agree with the allocating one.
    @test gather!(zeros(T, size(y)), src, index) == y
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,15 @@ end
include("upsample.jl")
end

@testset "Gather" begin
include("gather.jl")
end

@testset "Scatter" begin
include("scatter.jl")
end

@testset "Utilities" begin
include("utils.jl")
end

0 comments on commit cf8cba0

Please sign in to comment.