Skip to content

Commit

Permalink
refactor to current scatter API
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Mar 15, 2021
1 parent 38bbc23 commit 43d82eb
Showing 1 changed file with 19 additions and 104 deletions.
123 changes: 19 additions & 104 deletions lib/NNlibCUDA/src/scatter.jl
Original file line number Diff line number Diff line change
@@ -1,109 +1,24 @@
# Integer
for op = [:add, :sub, :max, :min, :and, :or, :xor]
fn = Symbol("scatter_$(op)!")
atm_op = Symbol("atomic_$(op)!")
@eval function $fn(ys::CuMatrix{T}, us::CuArray{T}, xs::CuArray{Int}) where {T<:Integer}
function kernel!(ys, us, xs)
li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@inbounds if li <= length(xs) && i <= size(ys, 1)
ind = CartesianIndices(xs)[li]
j = Base._to_linear_index(ys, i, xs[li])
CUDA.$atm_op(pointer(ys, j), us[i, ind])
end

return
end

thread_x = min(MAX_THREADS, size(ys, 1))
thread_y = min(MAX_THREADS ÷ thread_x, length(xs))
threads = (thread_x, thread_y)
blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
@cuda blocks=blocks threads=threads kernel!(ys, us, xs)
return ys
end

@eval function $fn(ys::CuArray{T}, us::CuArray{T}, xs::CuArray{<:Tuple}) where {T<:Integer}
function kernel!(ys, us, xs)
li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@inbounds if li <= length(xs) && i <= size(ys, 1)
ind = CartesianIndices(xs)[li]
j = Base._to_linear_index(ys, i, xs[li]...)
CUDA.$atm_op(pointer(ys, j), us[i, ind])
end

return
end

thread_x = min(MAX_THREADS, size(ys, 1))
thread_y = min(MAX_THREADS ÷ thread_x, length(xs))
threads = (thread_x, thread_y)
blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
@cuda blocks=blocks threads=threads kernel!(ys, us, xs)
return ys
end
end


# Floating point
for op = [:add, :sub, :mul, :div, :max, :min]
fn = Symbol("scatter_$(op)!")
atm_op = Symbol("atomic_$(op)!")
@eval function $fn(ys::CuMatrix{T}, us::CuArray{T}, xs::CuArray{Int}) where {T<:AbstractFloat}
function kernel!(ys::CuDeviceArray{T}, us::CuDeviceArray{T}, xs)
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
j = threadIdx().y + (blockIdx().y - 1) * blockDim().y

@inbounds if i <= size(ys, 1) && j <= length(xs)
ind = CartesianIndices(xs)[j]
k = Base._to_linear_index(ys, i, xs[j])
CUDA.$atm_op(pointer(ys, k), us[i, ind])
end

return
end

thread_i = min(MAX_THREADS, size(ys, 1))
thread_j = min(MAX_THREADS ÷ thread_i, length(xs))
threads = (thread_i, thread_j)
blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
@cuda blocks=blocks threads=threads kernel!(ys, us, xs)
return ys
end

@eval function $fn(ys::CuArray{T}, us::CuArray{T}, xs::CuArray{<:Tuple}) where {T<:AbstractFloat}
function kernel!(ys::CuDeviceArray{T}, us::CuDeviceArray{T}, xs)
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
j = threadIdx().y + (blockIdx().y - 1) * blockDim().y

@inbounds if i <= size(ys, 1) && j <= length(xs)
ind = CartesianIndices(xs)[j]
k = Base._to_linear_index(ys, i, xs[j]...)
CUDA.$atm_op(pointer(ys, k), us[i, ind])
end

return
ATM_OPS = Dict((+) => CUDA.atomic_add!, (-) => CUDA.atomic_sub!, (max) => CUDA.atomic_max!, (min) => CUDA.atomic_min!,
(*) => CUDA.atomic_mul!, (/) => CUDA.atomic_div!, (&) => CUDA.atomic_and!, (|) => CUDA.atomic_or!)

function scatter!(op, dst::CuArray, src::CuArray, idx::CuArray{IntOrIntTuple})
function kernel!(atm_op, dst, src, idx)
li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@inbounds if li <= length(idx) && i <= size(dst, 1)
ind = CartesianIndices(idx)[li]
j = Base._to_linear_index(dst, i, idx[li]...)
atm_op(pointer(dst, j), src[i, ind])
end

thread_i = min(MAX_THREADS, size(ys, 1))
thread_j = min(MAX_THREADS ÷ thread_i, length(xs))
threads = (thread_i, thread_j)
blocks = ceil.(Int, (size(ys, 1), length(xs)) ./ threads)
@cuda blocks=blocks threads=threads kernel!(ys, us, xs)
return ys
return
end
end


function scatter_mean!(ys::CuMatrix{T}, us::CuArray{T}, xs::CuArray{<:IntOrTuple}) where {T<:AbstractFloat}
yt = CUDA.zero(ys)
ot = CUDA.zero(ys)
os = CUDA.one.(us)
scatter_add!(ot, os, xs)
scatter_add!(yt, us, xs)
ys .+= save_div.(yt, ot)
return ys
thread_x = min(MAX_THREADS, size(dst, 1))
thread_y = min(MAX_THREADS ÷ thread_x, length(idx))
threads = (thread_x, thread_y)
blocks = ceil.(Int, (size(dst, 1), length(idx)) ./ threads)
@cuda blocks=blocks threads=threads kernel!(ATM_OPS[op], dst, src, idx)
return dst
end

0 comments on commit 43d82eb

Please sign in to comment.