Skip to content

Commit

Permalink
migrate from ScatterNNlib
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Mar 15, 2021
1 parent ca82fb2 commit 38bbc23
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/NNlibCUDA/src/NNlibCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Random, Statistics
include("upsample.jl")
include("activations.jl")
include("batchedmul.jl")
include("scatter.jl")
include("cudnn/cudnn.jl")
include("cudnn/conv.jl")
include("cudnn/pooling.jl")
Expand Down
109 changes: 109 additions & 0 deletions lib/NNlibCUDA/src/scatter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Integer
# Integer scatter: for every supported reduction `op`, generate a method
# `scatter_<op>!(ys, us, xs)` that folds slices of `us` into the columns of
# `ys` selected by the indices in `xs`, using `CUDA.atomic_<op>!` so that
# concurrent threads scattering to the same destination do not race.
for op = [:add, :sub, :max, :min, :and, :or, :xor]
    fn = Symbol("scatter_$(op)!")
    atm_op = Symbol("atomic_$(op)!")

    # Method for integer index arrays: `xs[li]` selects the destination column.
    @eval function $fn(ys::CuMatrix{T}, us::CuArray{T}, xs::CuArray{Int}) where {T<:Integer}
        function kernel!(ys, us, xs)
            # x thread dimension walks the rows of `ys`; y dimension walks `xs`.
            li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

            @inbounds if li <= length(xs) && i <= size(ys, 1)
                ind = CartesianIndices(xs)[li]
                # Linear index of ys[i, xs[li]], so a device pointer to the
                # destination element can be handed to the atomic op.
                j = Base._to_linear_index(ys, i, xs[li])
                CUDA.$atm_op(pointer(ys, j), us[i, ind])
            end

            return
        end

        # A zero-sized launch configuration is an error; with no indices (or
        # no rows) there is nothing to scatter, so return `ys` unchanged.
        (isempty(xs) || size(ys, 1) == 0) && return ys

        thread_x = min(MAX_THREADS, size(ys, 1))
        thread_y = min(MAX_THREADS ÷ thread_x, length(xs))
        threads = (thread_x, thread_y)
        blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
        @cuda blocks=blocks threads=threads kernel!(ys, us, xs)
        return ys
    end

    # Method for tuple index arrays: each `xs[li]` is a tuple of trailing
    # indices into `ys`, allowing scatter into destinations of any rank.
    @eval function $fn(ys::CuArray{T}, us::CuArray{T}, xs::CuArray{<:Tuple}) where {T<:Integer}
        function kernel!(ys, us, xs)
            li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

            @inbounds if li <= length(xs) && i <= size(ys, 1)
                ind = CartesianIndices(xs)[li]
                # Splat the index tuple to address ys[i, xs[li]...].
                j = Base._to_linear_index(ys, i, xs[li]...)
                CUDA.$atm_op(pointer(ys, j), us[i, ind])
            end

            return
        end

        # Guard against a zero-sized launch configuration (see above).
        (isempty(xs) || size(ys, 1) == 0) && return ys

        thread_x = min(MAX_THREADS, size(ys, 1))
        thread_y = min(MAX_THREADS ÷ thread_x, length(xs))
        threads = (thread_x, thread_y)
        blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
        @cuda blocks=blocks threads=threads kernel!(ys, us, xs)
        return ys
    end
end


# Floating point
# Floating-point scatter: same generation scheme as the integer version, but
# covering the reductions with floating-point atomics (mul/div instead of the
# bitwise ops). Each generated `scatter_<op>!(ys, us, xs)` folds slices of
# `us` into the destinations of `ys` selected by `xs` via `CUDA.atomic_<op>!`.
for op = [:add, :sub, :mul, :div, :max, :min]
    fn = Symbol("scatter_$(op)!")
    atm_op = Symbol("atomic_$(op)!")

    # Method for integer index arrays: `xs[j]` selects the destination column.
    @eval function $fn(ys::CuMatrix{T}, us::CuArray{T}, xs::CuArray{Int}) where {T<:AbstractFloat}
        function kernel!(ys::CuDeviceArray{T}, us::CuDeviceArray{T}, xs)
            # x thread dimension walks the rows of `ys`; y dimension walks `xs`.
            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
            j = threadIdx().y + (blockIdx().y - 1) * blockDim().y

            @inbounds if i <= size(ys, 1) && j <= length(xs)
                ind = CartesianIndices(xs)[j]
                # Linear index of ys[i, xs[j]] for the atomic device pointer.
                k = Base._to_linear_index(ys, i, xs[j])
                CUDA.$atm_op(pointer(ys, k), us[i, ind])
            end

            return
        end

        # A zero-sized launch configuration is an error; with no indices (or
        # no rows) there is nothing to scatter, so return `ys` unchanged.
        (isempty(xs) || size(ys, 1) == 0) && return ys

        thread_i = min(MAX_THREADS, size(ys, 1))
        thread_j = min(MAX_THREADS ÷ thread_i, length(xs))
        threads = (thread_i, thread_j)
        blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
        @cuda blocks=blocks threads=threads kernel!(ys, us, xs)
        return ys
    end

    # Method for tuple index arrays: each `xs[j]` is a tuple of trailing
    # indices into `ys`, allowing scatter into destinations of any rank.
    @eval function $fn(ys::CuArray{T}, us::CuArray{T}, xs::CuArray{<:Tuple}) where {T<:AbstractFloat}
        function kernel!(ys::CuDeviceArray{T}, us::CuDeviceArray{T}, xs)
            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
            j = threadIdx().y + (blockIdx().y - 1) * blockDim().y

            @inbounds if i <= size(ys, 1) && j <= length(xs)
                ind = CartesianIndices(xs)[j]
                # Splat the index tuple to address ys[i, xs[j]...].
                k = Base._to_linear_index(ys, i, xs[j]...)
                CUDA.$atm_op(pointer(ys, k), us[i, ind])
            end

            return
        end

        # Guard against a zero-sized launch configuration (see above).
        (isempty(xs) || size(ys, 1) == 0) && return ys

        thread_i = min(MAX_THREADS, size(ys, 1))
        thread_j = min(MAX_THREADS ÷ thread_i, length(xs))
        threads = (thread_i, thread_j)
        blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
        @cuda blocks=blocks threads=threads kernel!(ys, us, xs)
        return ys
    end
end


"""
    scatter_mean!(ys, us, xs)

Accumulate into `ys` the per-destination mean of the `us` entries scattered
by `xs`: each destination receives the sum of its contributions divided by
their count, added onto the existing value of `ys`. Built on two
`scatter_add!` passes (sums and counts) rather than a dedicated kernel.
Returns the mutated `ys`.
"""
function scatter_mean!(ys::CuMatrix{T}, us::CuArray{T}, xs::CuArray{<:IntOrTuple}) where {T<:AbstractFloat}
    # Scratch buffers: per-destination sums and per-destination hit counts.
    sums = CUDA.zero(ys)
    counts = CUDA.zero(ys)
    contrib_ones = CUDA.one.(us)
    scatter_add!(sums, us, xs)
    scatter_add!(counts, contrib_ones, xs)
    # save_div is a project helper — presumably a division that guards against
    # zero counts for destinations never hit by xs; TODO(review) confirm.
    ys .+= save_div.(sums, counts)
    return ys
end
1 change: 1 addition & 0 deletions lib/NNlibCUDA/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ if CUDA.has_cuda()
include("pooling.jl")
include("softmax.jl")
include("batchnorm.jl")
include("scatter.jl")
end
132 changes: 132 additions & 0 deletions lib/NNlibCUDA/test/scatter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Fixtures shared by the scatter tests below.
# `ys` is the 2x5 destination matrix, `us` the 2x3x4 source (all ones), and
# `xs` maps each of the 3x4 source positions to a destination column (1..5).
# Per-column hit counts in `xs` are 2, 2, 4, 2, 2 — the expected matrices in
# the test sets are hand-computed from these values.
ys = cu([3 3 4 4 5;
5 5 6 6 7])
us = cu(ones(Int, 2, 3, 4))
xs = CuArray{Int64}([1 2 3 4;
4 2 1 3;
3 5 5 3])
# Same index map as `xs`, expressed as 1-tuples to exercise the
# tuple-index scatter methods.
xs_tup = CuArray([(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)])


# End-to-end GPU tests for the scatter_*! kernels and the `scatter!(op, ...)`
# dispatcher, over both integer-index (`xs`) and tuple-index (`xs_tup`) maps.
# Expected matrices are hand-computed from the fixtures above: each
# destination column receives one `us` contribution (value 1) per occurrence
# of its index in `xs` (counts per column: 2, 2, 4, 2, 2).
@testset "cuda/scatter" begin
# Integer element types — only the ops with integer atomics are exercised.
for T = [UInt32, UInt64, Int32, Int64]
@testset "$(T)" begin
# add: ys plus the per-column hit count.
@testset "add" begin
ys_ = cu([5 5 8 6 7;
7 7 10 8 9])
@test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_add!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:add, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

# sub: ys minus the per-column hit count.
@testset "sub" begin
ys_ = cu([1 1 0 2 3;
3 3 2 4 5])
@test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_sub!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:sub, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

# max: every ys entry is >= 1 and all contributions are 1, so ys is unchanged.
@testset "max" begin
ys_ = cu([3 3 4 4 5;
5 5 6 6 7])
@test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_max!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:max, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

# min: every column is hit at least once, so all entries drop to 1.
@testset "min" begin
ys_ = cu([1 1 1 1 1;
1 1 1 1 1])
@test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_min!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:min, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end
end
end


# Floating-point element types additionally support mul, div, and mean.
for T = [Float32, Float64]
@testset "$(T)" begin
@testset "add" begin
ys_ = cu([5 5 8 6 7;
7 7 10 8 9])
@test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_add!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:add, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

@testset "sub" begin
ys_ = cu([1 1 0 2 3;
3 3 2 4 5])
@test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_sub!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:sub, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

@testset "max" begin
ys_ = cu([3 3 4 4 5;
5 5 6 6 7])
@test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_max!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:max, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

@testset "min" begin
ys_ = cu([1 1 1 1 1;
1 1 1 1 1])
@test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_min!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:min, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

# mul: all contributions are 1, so repeated multiplication leaves ys unchanged.
@testset "mul" begin
ys_ = cu([3 3 4 4 5;
5 5 6 6 7])
@test scatter_mul!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:mul, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_mul!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:mul, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

# div: contributions of 2 per hit, so each entry is ys / 2^count.
@testset "div" begin
us_div = us .* 2
ys_ = cu([0.75 0.75 0.25 1. 1.25;
1.25 1.25 0.375 1.5 1.75])
@test scatter_div!(T.(copy(ys)), T.(us_div), xs) == T.(ys_)
@test scatter!(:div, T.(copy(ys)), T.(us_div), xs) == T.(ys_)

@test scatter_div!(T.(copy(ys)), T.(us_div), xs_tup) == T.(ys_)
@test scatter!(:div, T.(copy(ys)), T.(us_div), xs_tup) == T.(ys_)
end

# mean: every contribution is 1, so the per-column mean is 1 and is
# added onto ys.
@testset "mean" begin
ys_ = cu([4 4 5 5 6;
6 6 7 7 8])
@test scatter_mean!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:mean, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_mean!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:mean, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end
end
end
end

0 comments on commit 38bbc23

Please sign in to comment.