Skip to content

Commit

Permalink
Merge pull request #8 from yuehhua/gather
Browse files Browse the repository at this point in the history
Gather for CUDA support
  • Loading branch information
CarloLucibello authored May 24, 2021
2 parents 94ea8b3 + c5da628 commit 2a4b220
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 24 deletions.
66 changes: 42 additions & 24 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "TimerOutputs"]
git-tree-sha1 = "d4fa6486e94c4087f1d081d7be2d501a170bd51d"
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.1.0"
version = "3.2.1"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "a66109c73612c63b10923ac446fddb0f0d21a593"
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.40"
version = "0.9.44"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956"
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.27.0"
version = "3.30.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
Expand All @@ -72,6 +72,12 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.4"

[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Expand All @@ -82,16 +88,16 @@ uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.3"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "9c95b2fd5c16bc7f97371e9f92f0fef77e0f5957"
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "6.2.2"
version = "6.4.1"

[[GPUCompiler]]
deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "6eadd2321dc3ac0fc9d530ab01c2caa7fe5d74c6"
git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.11.4"
version = "0.11.5"

[[InteractiveUtils]]
deps = ["Markdown"]
Expand All @@ -105,9 +111,9 @@ version = "1.3.0"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194"
git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "3.6.0"
version = "3.7.1"

[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
Expand Down Expand Up @@ -136,6 +142,12 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[LogExpFunctions]]
deps = ["DocStringExtensions", "LinearAlgebra"]
git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.2.4"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

Expand Down Expand Up @@ -181,19 +193,19 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.4+0"

[[OrderedCollections]]
git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf"
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.0"
version = "1.4.1"

[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Preferences]]
deps = ["TOML"]
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902"
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.1"
version = "1.2.2"

[[Printf]]
deps = ["Unicode"]
Expand All @@ -207,6 +219,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Random123]]
deps = ["Libdl", "Random", "RandomNumbers"]
git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
uuid = "74087812-796a-5b5d-8853-05524746bad3"
version = "1.3.1"

[[RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
Expand Down Expand Up @@ -248,10 +266,10 @@ deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
git-tree-sha1 = "c467f25b6ec4167ea3a9a4351c66c2e1cba5da33"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.3.0"
version = "1.4.1"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand All @@ -270,10 +288,10 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Printf"]
git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.8"
version = "0.5.9"

[[UUIDs]]
deps = ["Random", "SHA"]
Expand Down
1 change: 1 addition & 0 deletions src/NNlibCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ include("upsample.jl")
include("activations.jl")
include("batchedmul.jl")
include("scatter.jl")
include("gather.jl")
include("cudnn/cudnn.jl")
include("cudnn/conv.jl")
include("cudnn/pooling.jl")
Expand Down
53 changes: 53 additions & 0 deletions src/gather.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Shape validation for indices whose element type is an integer or a tuple of
# integers. Returns `dims`, the number of leading dimensions on which `X` and
# `Y` must agree. Throws `ArgumentError` if the first `dims` sizes of `X` and
# `Y` differ, or if the remaining sizes of `Y` are not equal to `size(idx)`.
function gather_check_dims(X::AbstractArray{Tx,Nx},
                           Y::AbstractArray{Ty,Ny},
                           idx::AbstractArray{Tidx,Nidx}) where
                           {Tx,Ty,Tidx<:IntOrIntTuple,Nx,Ny,Nidx}
    # number of components carried by a single index element
    idx_len = NNlib.typelength(Tidx)
    dims = gather_check_dims(Nx, Ny, idx_len, Nidx)

    sz_x, sz_y = size(X), size(Y)
    if sz_x[1:dims] != sz_y[1:dims]
        throw(ArgumentError("Incompatible input shapes."))
    end
    if sz_y[dims+1:end] != size(idx)
        throw(ArgumentError("Incompatible input shapes."))
    end
    return dims
end

# Shape validation for indices stored as `CartesianIndex{M}`; `M` is taken
# directly from the index element type. Returns `dims`, the number of leading
# dimensions on which `X` and `Y` must agree. Throws `ArgumentError` if the
# first `dims` sizes of `X` and `Y` differ, or if the remaining sizes of `Y`
# are not equal to `size(idx)`.
function gather_check_dims(X::AbstractArray{Tx,Nx},
                           Y::AbstractArray{Ty,Ny},
                           idx::AbstractArray{CartesianIndex{M},Nidx}) where
                           {Tx,Ty,Nx,Ny,M,Nidx}
    dims = gather_check_dims(Nx, Ny, M, Nidx)

    sz_x, sz_y = size(X), size(Y)
    if sz_x[1:dims] != sz_y[1:dims]
        throw(ArgumentError("Incompatible input shapes."))
    end
    if sz_y[dims+1:end] != size(idx)
        throw(ArgumentError("Incompatible input shapes."))
    end
    return dims
end

# Rank-only validation shared by the array methods of `gather_check_dims`.
#
# Arguments
# - `Nx`:   number of dimensions of the array passed as `X` (the source in `gather!`)
# - `Ny`:   number of dimensions of the array passed as `Y` (the destination in `gather!`)
# - `M`:    number of components in one index element
# - `Nidx`: number of dimensions of the index array
#
# Returns `dims = Nx - M`, the number of leading dimensions that are copied
# as-is rather than being selected through the index.
#
# Throws `ArgumentError` when `Nx - M != Ny - Nidx` or when `dims` is negative.
function gather_check_dims(Nx, Ny, M, Nidx)
    dims = Nx - M
    # Explicit throw instead of `@assert`: this is input validation, which must
    # not disappear if assertions are disabled, and it should raise the same
    # exception type as the other shape checks.
    dims == Ny - Nidx || throw(ArgumentError(
        "Incompatible input shapes of (src, dst, idx) = ($Nx, $Ny, $Nidx)."))
    dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims."))
    return dims
end

# CUDA kernel for `gather!`: every thread whose linear index is within
# `max_idx` writes exactly one element of `dst`.
#
# - `max_idx`:      total number of elements to write
# - `max_dims_idx`: number of elements in the leading (copied) dimensions
# - `dims_size`:    sizes of those leading dimensions
function gather_kernel!(dst, src, idx, max_idx, max_dims_idx, dims_size)
    # 1-based linear index of this thread across all blocks
    tid = (blockIdx().x - 1) * blockDim().x + threadIdx().x

    # surplus threads in the last block have nothing to do
    tid > max_idx && return nothing

    @inbounds begin
        # split the linear index into (position in idx, position in leading dims)
        idx_pos, lead_pos = divrem(tid - 1, max_dims_idx)
        lead = CartesianIndices(dims_size)[lead_pos + 1]
        dst[tid] = src[lead, idx[idx_pos + 1]...]
    end
    return nothing
end

# GPU implementation of `NNlib.gather!`.
#
# Fills `dst` so that, for every position in `idx`, the leading `dims`
# dimensions of `src` selected by that index element are copied into the
# matching slice of `dst`. Shapes are validated by `gather_check_dims`, which
# throws `ArgumentError` on mismatch. Returns `dst`.
function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
    dims = gather_check_dims(src, dst, idx)
    dims_size = size(src)[1:dims]
    max_dims_idx = prod(dims_size)
    max_idx = max_dims_idx * length(idx)

    # Nothing to write (empty `idx` or a zero-sized leading dimension).
    # Without this guard `threads` would be 0 and `cld(max_idx, threads)`
    # would throw a `DivideError`.
    max_idx == 0 && return dst

    args = dst, src, idx, max_idx, max_dims_idx, dims_size

    # Compile first without launching so the launch configuration can be queried.
    kernel = @cuda launch=false gather_kernel!(args...)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = min(max_idx, config.threads)
    blocks = cld(max_idx, threads)
    kernel(args...; threads=threads, blocks=blocks)
    return dst
end
60 changes: 60 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Tests for the CUDA implementation of `NNlib.gather` / `NNlib.gather!`.
# Every case checks the element type, the shape AND the gathered values
# against a hand-written expected `output`.
@testset "gather" begin
    T = Float32
    CT = CuArray{Float32}

    ## 1d src, 2d index of ints -> 2d output
    src = CT([3, 4, 5, 6, 7])
    index = cu([1 2 3 4;
                4 2 1 3;
                3 5 5 3])
    output = CT([3 4 5 6;
                 6 4 3 5;
                 5 7 7 5])

    y = NNlib.gather(src, index)
    @test y isa CuArray{Float32,2}
    @test size(y) == size(index)
    @test y == output
    gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
    @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
    @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
    # same shape mismatch, but with a GPU destination so the CUDA method's
    # own shape check is the one being exercised
    @test_throws ArgumentError NNlib.gather!(CUDA.zeros(T, 3, 5), src, index)

    ## 1d src, 3d index of ints -> 3d output
    src = CT([3, 4, 5, 6, 7])
    index = cu([1 2 3 4;
                4 2 1 3;
                3 5 5 3][:,:,1:1])
    output = CT([3 4 5 6;
                 6 4 3 5;
                 5 7 7 5][:,:,1:1])

    y = NNlib.gather(src, index)
    @test y isa CuArray{Float32,3}
    @test size(y) == size(index)
    @test y == output
    gputest(src -> NNlib.gather(src, index), src, checkgrad=false)


    ## 2d src, 2d index of ints -> 3d output
    src = CT([3 5 7
              4 6 8])
    index = cu([1 2 3;
                2 2 1;
                3 1 3])

    # expected result, built on the host: output[:, i, j] == src[:, index[i, j]]
    output = zeros(T, 2, 3, 3)

    output[:,:,1] = [3 5 7
                     4 6 8]

    output[:,:,2] = [5 5 3
                     6 6 4]

    output[:,:,3] = [7 3 7
                     8 4 8]

    y = NNlib.gather(src, index)
    M = NNlib.typelength(eltype(index))
    Nsrc = ndims(src)
    @test y isa CuArray{Float32,3}
    @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
    # `output` lives on the host, so bring the result back before comparing
    @test Array(y) == output
    gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ if CUDA.has_cuda()
include("softmax.jl")
include("batchnorm.jl")
include("scatter.jl")
include("gather.jl")
end

0 comments on commit 2a4b220

Please sign in to comment.