-
-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #280 from FluxML/cl/gather
add gather
- Loading branch information
Showing
5 changed files
with
194 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
export gather, gather! | ||
|
||
""" | ||
gather!(dst, src, idx) | ||
Reverse operation of [`scatter!`](@ref). Gathers data from source `src` | ||
and writes it in destination `dst` according to the index array `idx`. | ||
For each `k` in `CartesianIndices(idx)`, assign values to `dst` according to | ||
dst[:, ... , k] .= src[:, ... , idx[k]...] | ||
Notice that if `idx` is a vector containing integers, | ||
and both `dst` and `src` are matrices, previous expression simplifies to | ||
dst[:, k] .= src[:, idx[k]] | ||
and `k` will run over `1:length(idx)`. | ||
The elements of `idx` can be integers or integer tuples and may be repeated. | ||
A single `src` column can end up being copied into zero, one, | ||
or multiple `dst` columns. | ||
See [`gather`](@ref) for an allocating version. | ||
""" | ||
function gather!(dst::AbstractArray{Tdst,Ndst}, | ||
src::AbstractArray{Tsrc,Nsrc}, | ||
idx::AbstractArray{Tidx, Nidx}) where | ||
{Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple} | ||
|
||
M = typelength(Tidx) | ||
d = Ndst - Nidx | ||
d == Nsrc - M || throw(ArgumentError("Incompatible input shapes.")) | ||
size(dst)[1:d] == size(src)[1:d] || throw(ArgumentError("Incompatible input shapes.")) | ||
size(dst)[d+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes.")) | ||
|
||
colons = ntuple(i -> Colon(), d) | ||
for k in CartesianIndices(idx) | ||
view(dst, colons..., k) .= view(src, colons..., idx[k]...) | ||
end | ||
return dst | ||
end | ||
|
||
""" | ||
gather(src, idx) -> dst | ||
Reverse operation of [`scatter`](@ref). Gathers data from source `src` | ||
and writes it in a destination `dst` according to the index | ||
array `idx`. | ||
For each `k` in `CartesianIndices(idx)`, assign values to `dst` | ||
according to | ||
dst[:, ... , k] .= src[:, ... , idx[k]...] | ||
Notice that if `idx` is a vector containing integers | ||
and `src` is a matrix, previous expression simplifies to | ||
dst[:, k] .= src[:, idx[k]] | ||
and `k` will run over `1:length(idx)`. | ||
The elements of `idx` can be integers or integer tuples and may be repeated. | ||
A single `src` column can end up being copied into zero, one, | ||
or multiple `dst` columns. | ||
See [`gather!`](@ref) for an in-place version. | ||
""" | ||
function gather(src::AbstractArray{Tsrc, Nsrc}, | ||
idx::AbstractArray{Tidx, Nidx}) where | ||
{Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple} | ||
|
||
M = typelength(Tidx) | ||
dstsize = (size(src)[1:Nsrc-M]..., size(idx)...) | ||
dst = similar(src, Tsrc, dstsize) | ||
return gather!(dst, src, idx) | ||
end | ||
|
||
# Simple implementation with getindex for integer array. | ||
# Perf equivalent to the one above (which can also handle the integer case) | ||
# leave it here to show the simple connection with getindex. | ||
function gather(src::AbstractArray{Tsrc, Nsrc}, | ||
idx::AbstractArray{<:Integer}) where {Tsrc, Nsrc} | ||
colons = ntuple(i -> Colon(), Nsrc-1) | ||
return src[colons..., idx] | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
@testset "gather scalar index" begin | ||
T = Float32 | ||
|
||
## 1d src, 2d index of ints -> 2d output | ||
src = T[3, 4, 5, 6, 7] | ||
index = [1 2 3 4; | ||
4 2 1 3; | ||
3 5 5 3] | ||
output = T[3 4 5 6; | ||
6 4 3 5; | ||
5 7 7 5] | ||
|
||
y = gather(src, index) | ||
@test y isa Array{T,2} | ||
@test size(y) == size(index) | ||
@test y == output | ||
@test gather!(T.(zero(index)), src, index) == output | ||
@test_throws ArgumentError gather!(zeros(T, 3, 5), src, index) | ||
|
||
index2 = [1 2 3 4; | ||
4 2 1 3; | ||
3 6 5 3] | ||
@test_throws BoundsError gather!(T.(zero(index)), src, index2) | ||
|
||
## 1d src, 3d index of ints -> 3d output | ||
src = T[3, 4, 5, 6, 7] | ||
index = [1 2 3 4; | ||
4 2 1 3; | ||
3 5 5 3][:,:,1:1] | ||
output = T[3 4 5 6; | ||
6 4 3 5; | ||
5 7 7 5][:,:,1:1] | ||
|
||
y = gather(src, index) | ||
@test y isa Array{T,3} | ||
@test size(y) == size(index) | ||
@test y == output | ||
|
||
|
||
## 2d src, 2d index of ints -> 3d output | ||
src = T[3 5 7 | ||
4 6 8] | ||
index = [1 2 3; | ||
2 2 1; | ||
3 1 3] | ||
|
||
output = zeros(T, 2, 3, 3) | ||
|
||
output[:,:,1] = [3 5 7 | ||
4 6 8] | ||
|
||
output[:,:,2] = [5 5 3 | ||
6 6 4] | ||
|
||
output[:,:,3] = [7 3 7 | ||
8 4 8] | ||
|
||
y = gather(src, index) | ||
M = NNlib.typelength(eltype(index)) | ||
Nsrc = ndims(src) | ||
@test y isa Array{T,3} | ||
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) | ||
@test y == output | ||
end | ||
|
||
@testset "gather tuple index" begin | ||
T = Float32 | ||
|
||
## 2d src, 1d index of 2-tuples -> 1d output | ||
src = T[3 5 7 | ||
4 6 8] | ||
|
||
index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] | ||
|
||
output = T[3, 5, 7, 4, 6, 8] | ||
|
||
y = gather(src, index) | ||
M = NNlib.typelength(eltype(index)) | ||
Nsrc = ndims(src) | ||
@test y isa Array{T,1} | ||
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) | ||
@test y == output | ||
|
||
## 3d src, 2d index of 2-tuples -> 3d output | ||
n1, nsrc, nidx = 2, 3, 6 | ||
src = rand(Float32, n1, nsrc, nsrc) | ||
index = [(rand(1:nsrc), rand(1:nsrc)) for i=1:nidx, j=1:nidx] | ||
|
||
y = gather(src, index) | ||
M = NNlib.typelength(eltype(index)) | ||
Nsrc = ndims(src) | ||
@test y isa Array{T,3} | ||
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters