Skip to content

Commit

Permalink
try to unify the scatter API
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Mar 25, 2021
1 parent ce63fb3 commit d36c18e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion lib/NNlibCUDA/src/scatter.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ATM_OPS = Dict((+) => CUDA.atomic_add!, (-) => CUDA.atomic_sub!, (max) => CUDA.atomic_max!, (min) => CUDA.atomic_min!,
(*) => CUDA.atomic_mul!, (/) => CUDA.atomic_div!, (&) => CUDA.atomic_and!, (|) => CUDA.atomic_or!)

function scatter!(op, dst::CuArray, src::CuArray, idx::CuArray{IntOrIntTuple})
function scatter!(op, dst::CuArray{Tdst}, src::CuArray{Tsrc}, idx::CuArray{<:IntOrIntTuple}, dims::Val{N}) where {Tdst,Tsrc,N}
function kernel!(atm_op, dst, src, idx)
li = threadIdx().y + (blockIdx().y - 1) * blockDim().y
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Expand Down
1 change: 1 addition & 0 deletions lib/NNlibCUDA/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using NNlib
using Zygote
using NNlibCUDA
using ForwardDiff: Dual
using Statistics: mean
using CUDA
CUDA.allowscalar(false)

Expand Down

0 comments on commit d36c18e

Please sign in to comment.