From 73079044dd9180d48686819196b58f046e35b91c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 5 Mar 2021 09:17:23 +0100 Subject: [PATCH] cleanup; getindex --- Project.toml | 1 - src/NNlib.jl | 2 +- src/gather.jl | 9 +++++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 2985a7e02..5b0923ad7 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,6 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.7.14" [deps] -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/NNlib.jl b/src/NNlib.jl index 24658016d..647eb400a 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -7,7 +7,7 @@ import ChainRulesCore: rrule using Base.Broadcast: broadcasted using Statistics: mean -const IntOrIntTuple = Union{Integer, NTuple{N,Integer} where N} +const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N} const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number} # Include APIs diff --git a/src/gather.jl b/src/gather.jl index 906ce15d6..46984f180 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -73,3 +73,12 @@ function gather(src::AbstractArray{Tsrc, Nsrc}, dst = similar(src, Tsrc, dstsize) return gather!(dst, src, idx) end + +# Simple implementation with getindex for integer array. +# Perf equivalent to the one above (which can also handle the integer case) +# leave it here to show the simple connection with getindex. +function gather(src::AbstractArray{Tsrc, Nsrc}, + idx::AbstractArray{<:Integer}) where {Tsrc, Nsrc} + colons = ntuple(i -> Colon(), Nsrc-1) + return src[colons..., idx] +end