diff --git a/Manifest.toml b/Manifest.toml index dc33e60..4db5124 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -33,22 +33,22 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.1" [[CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "TimerOutputs"] -git-tree-sha1 = "d4fa6486e94c4087f1d081d7be2d501a170bd51d" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "Memoize", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "364179416eabc34c9ca32126a6bdb431680c3bad" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "3.1.0" +version = "3.2.1" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "a66109c73612c63b10923ac446fddb0f0d21a593" +git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.40" +version = "0.9.44" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956" +git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.27.0" +version = "3.30.0" [[CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] @@ -72,6 +72,12 @@ uuid = 
"8bb1440f-4735-579b-a4ab-409b98df4dab" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[DocStringExtensions]] +deps = ["LibGit2", "Markdown", "Pkg", "Test"] +git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.4" + [[Downloads]] deps = ["ArgTools", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" @@ -82,16 +88,16 @@ uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" version = "0.1.3" [[GPUArrays]] -deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] -git-tree-sha1 = "9c95b2fd5c16bc7f97371e9f92f0fef77e0f5957" +deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"] +git-tree-sha1 = "df5b8569904c5c10e84c640984cfff054b18c086" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "6.2.2" +version = "6.4.1" [[GPUCompiler]] deps = ["DataStructures", "ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "Serialization", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "6eadd2321dc3ac0fc9d530ab01c2caa7fe5d74c6" +git-tree-sha1 = "42d635f6d87af125b86288df3819f805fb4d851a" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.11.4" +version = "0.11.5" [[InteractiveUtils]] deps = ["Markdown"] @@ -105,9 +111,9 @@ version = "1.3.0" [[LLVM]] deps = ["CEnum", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "b616937c31337576360cb9fb872ec7633af7b194" +git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "3.6.0" +version = "3.7.1" [[LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -136,6 +142,12 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" deps = ["Libdl"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[LogExpFunctions]] +deps = ["DocStringExtensions", "LinearAlgebra"] +git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = 
"0.2.4" + [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -181,9 +193,9 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.4+0" [[OrderedCollections]] -git-tree-sha1 = "4fa2ba51070ec13fcc7517db714445b4ab986bdf" +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.0" +version = "1.4.1" [[Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -191,9 +203,9 @@ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[Preferences]] deps = ["TOML"] -git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902" +git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.2.1" +version = "1.2.2" [[Printf]] deps = ["Unicode"] @@ -207,6 +219,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[Random123]] +deps = ["Libdl", "Random", "RandomNumbers"] +git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.3.1" + [[RandomNumbers]] deps = ["Random", "Requires"] git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f" @@ -248,10 +266,10 @@ deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] -deps = ["ChainRulesCore", "OpenSpecFun_jll"] -git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902" +deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] +git-tree-sha1 = "c467f25b6ec4167ea3a9a4351c66c2e1cba5da33" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.3.0" +version = "1.4.1" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -270,10 +288,10 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[TimerOutputs]] -deps = ["Printf"] 
-git-tree-sha1 = "32cdbe6cd2d214c25a0b88f985c9e0092877c236"
+deps = ["ExprTools", "Printf"]
+git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
 uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.8"
+version = "0.5.9"
 
 [[UUIDs]]
 deps = ["Random", "SHA"]
diff --git a/src/NNlibCUDA.jl b/src/NNlibCUDA.jl
index b4cc05c..f98c08f 100644
--- a/src/NNlibCUDA.jl
+++ b/src/NNlibCUDA.jl
@@ -10,6 +10,7 @@ include("upsample.jl")
 include("activations.jl")
 include("batchedmul.jl")
 include("scatter.jl")
+include("gather.jl")
 include("cudnn/cudnn.jl")
 include("cudnn/conv.jl")
 include("cudnn/pooling.jl")
diff --git a/src/gather.jl b/src/gather.jl
new file mode 100644
index 0000000..dcfd29b
--- /dev/null
+++ b/src/gather.jl
@@ -0,0 +1,71 @@
+# Validate the shapes of a gather from `X` (source) into `Y` (destination) using
+# indices of type `IntOrIntTuple` stored in `idx`. Returns `dims`, the number of
+# leading dimensions shared by `X` and `Y`; throws `ArgumentError` on a mismatch.
+function gather_check_dims(X::AbstractArray{Tx,Nx},
+                           Y::AbstractArray{Ty,Ny},
+                           idx::AbstractArray{Tidx,Nidx}) where
+                           {Tx,Ty,Tidx<:IntOrIntTuple,Nx,Ny,Nidx}
+    M = NNlib.typelength(Tidx)
+    dims = gather_check_dims(Nx, Ny, M, Nidx)
+    size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
+    size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
+    return dims
+end
+
+# Same validation for indices stored as `CartesianIndex{M}`.
+function gather_check_dims(X::AbstractArray{Tx,Nx},
+                           Y::AbstractArray{Ty,Ny},
+                           idx::AbstractArray{CartesianIndex{M},Nidx}) where
+                           {Tx,Ty,Nx,Ny,M,Nidx}
+    dims = gather_check_dims(Nx, Ny, M, Nidx)
+    size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
+    size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
+    return dims
+end
+
+# Rank-only part of the check: `Nx`, `Ny` and `Nidx` are the ranks of the source,
+# destination and index arrays, `M` is the number of source dimensions addressed
+# by one index. Returns `Nx - M`; throws `ArgumentError` if the ranks disagree.
+function gather_check_dims(Nx, Ny, M, Nidx)
+    # Thrown explicitly instead of via `@assert`: assertions may be disabled, and
+    # every other shape check in this file raises `ArgumentError`.
+    Nx - M == Ny - Nidx || throw(ArgumentError(
+        "Incompatible input shapes of (src, dst, idx) = ($Nx, $Ny, $Nidx)."))
+    dims = Nx - M
+    dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims."))
+    return dims
+end
+
+# CUDA kernel: each thread writes one linear position of `dst`. `max_idx` is the
+# number of positions to write, `max_dims_idx` the number of elements copied per
+# index and `dims_size` the size of the leading dimensions copied per index.
+function gather_kernel!(dst, src, idx, max_idx, max_dims_idx, dims_size)
+    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+    # Threads whose position lies past `max_idx` do nothing.
+    @inbounds if index <= max_idx
+        # j selects the entry of `idx` (zero-based), k the offset inside its block.
+        j, k = divrem(index-1, max_dims_idx)
+        dims_i = CartesianIndices(dims_size)[k+1]
+        dst[index] = src[dims_i, idx[j+1]...]
+    end
+    return nothing
+end
+
+# Fill `dst` with the parts of `src` selected by `idx` using a CUDA kernel; returns `dst`.
+function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
+    dims = gather_check_dims(src, dst, idx)
+    dims_size = size(src)[1:dims]
+    max_dims_idx = prod(dims_size)
+    max_idx = max_dims_idx * length(idx)
+    # Nothing to copy; this also avoids `cld(0, 0)` below, which throws `DivideError`.
+    max_idx == 0 && return dst
+    args = dst, src, idx, max_idx, max_dims_idx, dims_size
+
+    kernel = @cuda launch=false gather_kernel!(args...)
+    config = launch_configuration(kernel.fun; max_threads=256)
+    threads = min(max_idx, config.threads)
+    blocks = cld(max_idx, threads)
+    kernel(args...; threads=threads, blocks=blocks)
+    return dst
+end
diff --git a/test/gather.jl b/test/gather.jl
new file mode 100644
index 0000000..8dd2f20
--- /dev/null
+++ b/test/gather.jl
@@ -0,0 +1,64 @@
+@testset "gather" begin
+    T = Float32
+    CT = CuArray{Float32}
+
+    ## 1d src, 2d index of ints -> 2d output
+    src = CT([3, 4, 5, 6, 7])
+    index = cu([1 2 3 4;
+                4 2 1 3;
+                3 5 5 3])
+    output = CT([3 4 5 6;
+                 6 4 3 5;
+                 5 7 7 5])
+
+    y = NNlib.gather(src, index)
+    @test y isa CuArray{Float32,2}
+    @test size(y) == size(index)
+    gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
+    @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
+    @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
+    ## empty index -> empty output, no kernel launch
+    @test isempty(NNlib.gather!(CUDA.zeros(T, 0), src, cu(Int[])))
+
+    ## 1d src, 3d index of ints -> 3d output
+    src = CT([3, 4, 5, 6, 7])
+    index = cu([1 2 3 4;
+                4 2 1 3;
+                3 5 5 3][:,:,1:1])
+    output = CT([3 4 5 6;
+                 6 4 3 5;
+                 5 7 7 5][:,:,1:1])
+
+    y = NNlib.gather(src, index)
+    @test y isa CuArray{Float32,3}
+    @test size(y) == size(index)
+    @test y == output
+    gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
+
+
+    ## 2d src, 2d index of ints -> 3d output
+    src = CT([3 5 7
+              4 6 8])
+    index = cu([1 2 3;
+                2 2 1;
+                3 1 3])
+
+    output = zeros(T, 2, 3, 3)
+
+    output[:,:,1] = [3 5 7
+                     4 6 8]
+
+    output[:,:,2] = [5 5 3
+                     6 6 4]
+
+    output[:,:,3] = [7 3 7
+                     8 4 8]
+
+    y = NNlib.gather(src, index)
+    M = NNlib.typelength(eltype(index))
+    Nsrc = ndims(src)
+    @test y isa CuArray{Float32,3}
+    @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
+    @test Array(y) == output
+    gputest(src -> NNlib.gather(src, index), src, checkgrad=false)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 778b875..0909bcd 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -18,4 +18,5 @@ if CUDA.has_cuda()
     include("softmax.jl")
     include("batchnorm.jl")
     include("scatter.jl")
+    include("gather.jl")
 end