Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scatter for CUDA support #1

Merged
merged 13 commits into from
May 2, 2021
3 changes: 3 additions & 0 deletions src/NNlibCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ using NNlib
using CUDA
using Random, Statistics

const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}

include("upsample.jl")
include("activations.jl")
include("batchedmul.jl")
include("scatter.jl")
include("cudnn/cudnn.jl")
include("cudnn/conv.jl")
include("cudnn/pooling.jl")
Expand Down
40 changes: 40 additions & 0 deletions src/scatter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
for op in [+, -, *, /, max, min, &, |]
@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx)
yuehhua marked this conversation as resolved.
Show resolved Hide resolved
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@inbounds if index <= length(idx)
@atomic dst[idx[index]...] = $(op)(dst[idx[index]...], src[index])
end
return nothing
end

@eval function scatter_kernel!(op::typeof($(op)), dst, src, idx, dims::Val{N}, max_idx, max_dims_idx, dims_size) where {N}
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@inbounds if index <= max_idx
j, k = divrem(index-1, max_dims_idx)
dims_i = CartesianIndices(dims_size)[k+1]
@atomic dst[Tuple(dims_i)..., idx[j+1]...] = $(op)(dst[Tuple(dims_i)..., idx[j+1]...], src[index])
end
return nothing
end

@eval function NNlib.scatter!(op::typeof($(op)), dst::CuArray{Tdst}, src::CuArray{Tsrc}, idx::CuArray{<:IntOrIntTuple}, dims::Val{N}) where {Tdst,Tsrc,N}
yuehhua marked this conversation as resolved.
Show resolved Hide resolved
args = if N == 0
max_idx = length(idx)
op, dst, src, idx
else
dims_size = size(dst)[1:N]
max_dims_idx = prod(dims_size)
max_idx = max_dims_idx * length(idx)
op, dst, src, idx, dims, max_idx, max_dims_idx, dims_size
end

kernel = @cuda launch=false scatter_kernel!(args...)
config = launch_configuration(kernel.fun; max_threads=256)
threads = Base.min(max_idx, config.threads)
yuehhua marked this conversation as resolved.
Show resolved Hide resolved
blocks = ceil(Int, max_idx / threads)
kernel(args...; threads=threads, blocks=blocks)
return dst
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using NNlib
using Zygote
using NNlibCUDA
using ForwardDiff: Dual
using Statistics: mean
using CUDA
CUDA.allowscalar(false)

Expand All @@ -16,4 +17,5 @@ if CUDA.has_cuda()
include("pooling.jl")
include("softmax.jl")
include("batchnorm.jl")
include("scatter.jl")
end
149 changes: 149 additions & 0 deletions test/scatter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
dsts = Dict(
0 => cu([3, 4, 5, 6, 7]),
1 => cu([3 3 4 4 5;
5 5 6 6 7]),
)
srcs = Dict(
(0, true) => cu(ones(Int, 3, 4)),
(0, false) => cu(ones(Int, 3) * collect(1:4)'),
(1, true) => cu(ones(Int, 2, 3, 4)),
(1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),
)
idxs = [
cu([1 2 3 4;
4 2 1 3;
3 5 5 3]), # integer index
cu([(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)]), # tuple index
]
res = Dict(
(+, 0, true) => cu([5, 6, 9, 8, 9]),
(+, 1, true) => cu([5 5 8 6 7;
7 7 10 8 9]),
(+, 0, false) => cu([4, 4, 12, 5, 5]),
(+, 1, false) => cu([4 4 12 5 5;
8 8 24 10 10]),
(-, 0, true) => cu([1, 2, 1, 4, 5]),
(-, 1, true) => cu([1 1 0 2 3;
3 3 2 4 5]),
(-, 0, false) => cu([-4, -4, -12, -5, -5]),
(-, 1, false) => cu([-4 -4 -12 -5 -5;
-8 -8 -24 -10 -10]),
(max, 0, true) => cu([3, 4, 5, 6, 7]),
(max, 1, true) => cu([3 3 4 4 5;
5 5 6 6 7]),
(max, 0, false) => cu([3, 2, 4, 4, 3]),
(max, 1, false) => cu([3 2 4 4 3;
6 4 8 8 6]),
(min, 0, true) => cu([1, 1, 1, 1, 1]),
(min, 1, true) => cu([1 1 1 1 1;
1 1 1 1 1]),
(min, 0, false) => cu([1, 2, 1, 1, 2]),
(min, 1, false) => cu([1 2 1 1 2;
2 4 2 2 4]),
(*, 0, true) => cu([3, 4, 5, 6, 7]),
(*, 1, true) => cu([3 3 4 4 5;
5 5 6 6 7]),
(*, 0, false) => cu([3, 4, 48, 4, 6]),
(*, 1, false) => cu([3 4 48 4 6;
12 16 768 16 24]),
(/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
(/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
1.25 1.25 0.375 1.5 1.75]),
(/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
(/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
1//12 1//16 1//768 1//16 1//24]),
(mean, 0, true) => cu([4., 5., 6., 7., 8.]),
(mean, 1, true) => cu([4. 4. 5. 5. 6.;
6. 6. 7. 7. 8.]),
(mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
(mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
4. 4. 6. 5. 5.]),
)

types = [CuArray{UInt32}, CuArray{UInt64},
CuArray{Int32}, CuArray{Int64},
CuArray{Float32}, CuArray{Float64}]


@testset "scatter" begin
for T = types
@testset "$(T)" begin
@testset "+" begin
for idx = idxs, dims = [0, 1]
mutated = true
@test NNlib.scatter!(+, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)])

mutated = false
# @test scatter(+, srcs[(dims, mutated)], idx) == T(res[(+, dims, mutated)])
end
end

@testset "-" begin
for idx = idxs, dims = [0, 1]
mutated = true
@test NNlib.scatter!(-, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)])

mutated = false
# @test scatter(-, srcs[(dims, mutated)], idx) == T(res[(-, dims, mutated)])
end
end

@testset "max" begin
for idx = idxs, dims = [0, 1]
mutated = true
@test NNlib.scatter!(max, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)])

mutated = false
# @test scatter(max, srcs[(dims, mutated)], idx) == T(res[(max, dims, mutated)])
end
end

@testset "min" begin
for idx = idxs, dims = [0, 1]
mutated = true
@test NNlib.scatter!(min, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)])

mutated = false
# @test scatter(min, srcs[(dims, mutated)], idx) == T(res[(min, dims, mutated)])
end
end
end
end


for T = [CuArray{Float32}, CuArray{Float64}]
@testset "$(T)" begin
@testset "*" begin
for idx = idxs, dims = [0, 1]
mutated = true
@test NNlib.scatter!(*, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)])

mutated = false
# @test scatter(*, srcs[(dims, mutated)], idx) == T(res[(*, dims, mutated)])
end
end

@testset "/" begin
for idx = idxs, dims = [0, 1]
mutated = true
@test NNlib.scatter!(/, T(dsts[dims]), T(srcs[(dims, mutated)].*2), idx) == T(res[(/, dims, mutated)])

mutated = false
# @test scatter(/, srcs[(dims, mutated)], idx) == T(res[(/, dims, mutated)])
end
end

@testset "mean" begin
for idx = idxs, dims = [0, 1]
mutated = true
@test NNlib.scatter!(mean, T(dsts[dims]), T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)])

mutated = false
# @test scatter(mean, srcs[(dims, mutated)], idx) == T(res[(mean, dims, mutated)])
end
end
end
end
end