diff --git a/lib/NNlibCUDA/src/scatter.jl b/lib/NNlibCUDA/src/scatter.jl
index 40c92f8f9..b847802f8 100644
--- a/lib/NNlibCUDA/src/scatter.jl
+++ b/lib/NNlibCUDA/src/scatter.jl
@@ -1,109 +1,24 @@
-# Integer
-for op = [:add, :sub, :max, :min, :and, :or, :xor]
-    fn = Symbol("scatter_$(op)!")
-    atm_op = Symbol("atomic_$(op)!")
-    @eval function $fn(ys::CuMatrix{T}, us::CuArray{T}, xs::CuArray{Int}) where {T<:Integer}
-        function kernel!(ys, us, xs)
-            li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
-            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
-
-            @inbounds if li <= length(xs) && i <= size(ys, 1)
-                ind = CartesianIndices(xs)[li]
-                j = Base._to_linear_index(ys, i, xs[li])
-                CUDA.$atm_op(pointer(ys, j), us[i, ind])
-            end
-
-            return
-        end
-
-        thread_x = min(MAX_THREADS, size(ys, 1))
-        thread_y = min(MAX_THREADS ÷ thread_x, length(xs))
-        threads = (thread_x, thread_y)
-        blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
-        @cuda blocks=blocks threads=threads kernel!(ys, us, xs)
-        return ys
-    end
-
-    @eval function $fn(ys::CuArray{T}, us::CuArray{T}, xs::CuArray{<:Tuple}) where {T<:Integer}
-        function kernel!(ys, us, xs)
-            li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
-            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
-
-            @inbounds if li <= length(xs) && i <= size(ys, 1)
-                ind = CartesianIndices(xs)[li]
-                j = Base._to_linear_index(ys, i, xs[li]...)
-                CUDA.$atm_op(pointer(ys, j), us[i, ind])
-            end
-
-            return
-        end
-
-        thread_x = min(MAX_THREADS, size(ys, 1))
-        thread_y = min(MAX_THREADS ÷ thread_x, length(xs))
-        threads = (thread_x, thread_y)
-        blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
-        @cuda blocks=blocks threads=threads kernel!(ys, us, xs)
-        return ys
-    end
-end
-
-
-# Floating point
-for op = [:add, :sub, :mul, :div, :max, :min]
-    fn = Symbol("scatter_$(op)!")
-    atm_op = Symbol("atomic_$(op)!")
-    @eval function $fn(ys::CuMatrix{T}, us::CuArray{T}, xs::CuArray{Int}) where {T<:AbstractFloat}
-        function kernel!(ys::CuDeviceArray{T}, us::CuDeviceArray{T}, xs)
-            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
-            j = threadIdx().y + (blockIdx().y - 1) * blockDim().y
-
-            @inbounds if i <= size(ys, 1) && j <= length(xs)
-                ind = CartesianIndices(xs)[j]
-                k = Base._to_linear_index(ys, i, xs[j])
-                CUDA.$atm_op(pointer(ys, k), us[i, ind])
-            end
-
-            return
-        end
-
-        thread_i = min(MAX_THREADS, size(ys, 1))
-        thread_j = min(MAX_THREADS ÷ thread_i, length(xs))
-        threads = (thread_i, thread_j)
-        blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
-        @cuda blocks=blocks threads=threads kernel!(ys, us, xs)
-        return ys
-    end
-
-    @eval function $fn(ys::CuArray{T}, us::CuArray{T}, xs::CuArray{<:Tuple}) where {T<:AbstractFloat}
-        function kernel!(ys::CuDeviceArray{T}, us::CuDeviceArray{T}, xs)
-            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
-            j = threadIdx().y + (blockIdx().y - 1) * blockDim().y
-
-            @inbounds if i <= size(ys, 1) && j <= length(xs)
-                ind = CartesianIndices(xs)[j]
-                k = Base._to_linear_index(ys, i, xs[j]...)
-                CUDA.$atm_op(pointer(ys, k), us[i, ind])
-            end
-
-            return
-        end
-
-        thread_i = min(MAX_THREADS, size(ys, 1))
-        thread_j = min(MAX_THREADS ÷ thread_i, length(xs))
-        threads = (thread_i, thread_j)
-        blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
-        @cuda blocks=blocks threads=threads kernel!(ys, us, xs)
-        return ys
-    end
-end
-
-
-function scatter_mean!(ys::CuMatrix{T}, us::CuArray{T}, xs::CuArray{<:IntOrTuple}) where {T<:AbstractFloat}
-    yt = CUDA.zero(ys)
-    ot = CUDA.zero(ys)
-    os = CUDA.one.(us)
-    scatter_add!(ot, os, xs)
-    scatter_add!(yt, us, xs)
-    ys .+= save_div.(yt, ot)
-    return ys
-end
+ATM_OPS = Dict((+) => CUDA.atomic_add!, (-) => CUDA.atomic_sub!, (max) => CUDA.atomic_max!, (min) => CUDA.atomic_min!,
+               (*) => CUDA.atomic_mul!, (/) => CUDA.atomic_div!, (&) => CUDA.atomic_and!, (|) => CUDA.atomic_or!)
+
+function scatter!(op, dst::CuArray, src::CuArray, idx::CuArray{IntOrIntTuple})
+    function kernel!(atm_op, dst, src, idx)
+        li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
+        i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
+
+        @inbounds if li <= length(idx) && i <= size(dst, 1)
+            ind = CartesianIndices(idx)[li]
+            j = Base._to_linear_index(dst, i, idx[li]...)
+            atm_op(pointer(dst, j), src[i, ind])
+        end
+
+        return
+    end
+
+    thread_x = min(MAX_THREADS, size(dst, 1))
+    thread_y = min(MAX_THREADS ÷ thread_x, length(idx))
+    threads = (thread_x, thread_y)
+    blocks = ceil.(Int, (size(dst, 1), length(idx)) ./ threads)
+    @cuda blocks=blocks threads=threads kernel!(ATM_OPS[op], dst, src, idx)
+    return dst
+end
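For reference, a minimal usage sketch of the unified entry point; it is not part of the patch. It assumes scatter! is in scope as defined above, and that the idx argument dispatches for concrete integer index arrays (NNlib's CPU signature uses idx::AbstractArray{<:IntOrIntTuple}); the data below is hypothetical.

using CUDA

dst = CUDA.zeros(Int32, 3, 4)    # 3 features x 4 destination bins
src = CUDA.ones(Int32, 3, 5)     # 3 features x 5 samples
idx = cu([1, 2, 2, 4, 4])        # destination bin for each sample

scatter!(+, dst, src, idx)       # ATM_OPS[+] == CUDA.atomic_add!, so each src column
                                 # is atomically accumulated into dst[:, idx[k]]
# Expected result: Array(dst) == [1 2 0 2; 1 2 0 2; 1 2 0 2]

Dispatching through the ATM_OPS table keeps one kernel for every supported operator, at the cost of erroring at the Dict lookup (rather than at method dispatch) when an unsupported operator is passed.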