diff --git a/lib/NNlibCUDA/test/scatter.jl b/lib/NNlibCUDA/test/scatter.jl index f8f56aa9b..c3bd28c98 100644 --- a/lib/NNlibCUDA/test/scatter.jl +++ b/lib/NNlibCUDA/test/scatter.jl @@ -1,55 +1,111 @@ -ys = cu([3 3 4 4 5; - 5 5 6 6 7]) -us = cu(ones(Int, 2, 3, 4)) -xs = CuArray{Int64}([1 2 3 4; - 4 2 1 3; - 3 5 5 3]) -xs_tup = CuArray([(1,) (2,) (3,) (4,); - (4,) (2,) (1,) (3,); - (3,) (5,) (5,) (3,)]) - - -@testset "cuda/scatter" begin - for T = [UInt32, UInt64, Int32, Int64] +dsts = Dict( + 0 => cu([3, 4, 5, 6, 7]), + 1 => cu([3 3 4 4 5; + 5 5 6 6 7]), +) +srcs = Dict( + (0, true) => cu(ones(Int, 3, 4)), + (0, false) => cu(ones(Int, 3) * collect(1:4)'), + (1, true) => cu(ones(Int, 2, 3, 4)), + (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)), +) +idxs = Dict( + :int => cu([1 2 3 4; + 4 2 1 3; + 3 5 5 3]), + :tup => cu([(1,) (2,) (3,) (4,); + (4,) (2,) (1,) (3,); + (3,) (5,) (5,) (3,)]), +) +res = Dict( + (+, 0, true) => cu([5, 6, 9, 8, 9]), + (+, 1, true) => cu([5 5 8 6 7; + 7 7 10 8 9]), + (+, 0, false) => cu([4, 4, 12, 5, 5]), + (+, 1, false) => cu([4 4 12 5 5; + 8 8 24 10 10]), + (-, 0, true) => cu([1, 2, 1, 4, 5]), + (-, 1, true) => cu([1 1 0 2 3; + 3 3 2 4 5]), + (-, 0, false) => cu([-4, -4, -12, -5, -5]), + (-, 1, false) => cu([-4 -4 -12 -5 -5; + -8 -8 -24 -10 -10]), + (max, 0, true) => cu([3, 4, 5, 6, 7]), + (max, 1, true) => cu([3 3 4 4 5; + 5 5 6 6 7]), + (max, 0, false) => cu([3, 2, 4, 4, 3]), + (max, 1, false) => cu([3 2 4 4 3; + 6 4 8 8 6]), + (min, 0, true) => cu([1, 1, 1, 1, 1]), + (min, 1, true) => cu([1 1 1 1 1; + 1 1 1 1 1]), + (min, 0, false) => cu([1, 2, 1, 1, 2]), + (min, 1, false) => cu([1 2 1 1 2; + 2 4 2 2 4]), + (*, 0, true) => cu([3, 4, 5, 6, 7]), + (*, 1, true) => cu([3 3 4 4 5; + 5 5 6 6 7]), + (*, 0, false) => cu([3, 4, 48, 4, 6]), + (*, 1, false) => cu([3 4 48 4 6; + 12 16 768 16 24]), + (/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]), + (/, 1, true) => cu([0.75 0.75 0.25 1. 1.25; + 1.25 1.25 0.375 1.5 1.75]), + (/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]), + (/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6; + 1//12 1//16 1//768 1//16 1//24]), + (mean, 0, true) => cu([4., 5., 6., 7., 8.]), + (mean, 1, true) => cu([4. 4. 5. 5. 6.; + 6. 6. 7. 7. 8.]), + (mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]), + (mean, 1, false) => cu([2. 2. 3. 2.5 2.5; + 4. 4. 6. 5. 5.]), +) + +types = [UInt32, UInt64, Int32, Int64, Float32, Float64] + + +@testset "scatter" begin + for T = types @testset "$(T)" begin - @testset "add" begin - ys_ = cu([5 5 8 6 7; - 7 7 10 8 9]) - @test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_) - - @test scatter_add!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:add, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) + @testset "+" begin + for idx = values(idxs), dims = [0, 1] + mutated = true + @test scatter!(+, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(+, dims, mutated)]) + + mutated = false + # @test scatter(+, srcs[(dims, mutated)], idx) == T.(res[(+, dims, mutated)]) + end end - @testset "sub" begin - ys_ = cu([1 1 0 2 3; - 3 3 2 4 5]) - @test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_) + @testset "-" begin + for idx = values(idxs), dims = [0, 1] + mutated = true + @test scatter!(-, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(-, dims, mutated)]) - @test scatter_sub!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:sub, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) + mutated = false + # @test scatter(-, srcs[(dims, mutated)], idx) == T.(res[(-, dims, mutated)]) + end end @testset "max" begin - ys_ = cu([3 3 4 4 5; - 5 5 6 6 7]) - @test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_) + for idx = values(idxs), dims = [0, 1] + mutated = true + @test scatter!(max, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(max, dims, mutated)]) - @test scatter_max!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:max, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) + mutated = false + # @test scatter(max, srcs[(dims, mutated)], idx) == T.(res[(max, dims, mutated)]) + end end @testset "min" begin - ys_ = cu([1 1 1 1 1; - 1 1 1 1 1]) - @test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_) + for idx = values(idxs), dims = [0, 1] + mutated = true + @test scatter!(min, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(min, dims, mutated)]) - @test scatter_min!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:min, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) + mutated = false + # @test scatter(min, srcs[(dims, mutated)], idx) == T.(res[(min, dims, mutated)]) + end end end end @@ -57,75 +113,34 @@ xs_tup = CuArray([(1,) (2,) (3,) (4,); for T = [Float32, Float64] @testset "$(T)" begin - @testset "add" begin - ys_ = cu([5 5 8 6 7; - 7 7 10 8 9]) - @test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_) - - @test scatter_add!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:add, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) + @testset "*" begin + for idx = values(idxs), dims = [0, 1] + mutated = true + @test scatter!(*, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(*, dims, mutated)]) + + mutated = false + # @test scatter(*, srcs[(dims, mutated)], idx) == T.(res[(*, dims, mutated)]) + end end - @testset "sub" begin - ys_ = cu([1 1 0 2 3; - 3 3 2 4 5]) - @test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_) + @testset "/" begin + for idx = values(idxs), dims = [0, 1] + mutated = true + @test scatter!(/, T.(dsts[dims]), srcs[(dims, mutated)].*2, idx) == T.(res[(/, dims, mutated)]) - @test scatter_sub!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:sub, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - end - - @testset "max" begin - ys_ = cu([3 3 4 4 5; - 5 5 6 6 7]) - @test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_) - - @test scatter_max!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:max, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - end - - @testset "min" begin - ys_ = cu([1 1 1 1 1; - 1 1 1 1 1]) - @test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_) - - @test scatter_min!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:min, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - end - - @testset "mul" begin - ys_ = cu([3 3 4 4 5; - 5 5 6 6 7]) - @test scatter_mul!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:mul, T.(copy(ys)), T.(us), xs) == T.(ys_) - - @test scatter_mul!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:mul, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - end - - @testset "div" begin - us_div = us .* 2 - ys_ = cu([0.75 0.75 0.25 1. 1.25; - 1.25 1.25 0.375 1.5 1.75]) - @test scatter_div!(T.(copy(ys)), T.(us_div), xs) == T.(ys_) - @test scatter!(:div, T.(copy(ys)), T.(us_div), xs) == T.(ys_) - - @test scatter_div!(T.(copy(ys)), T.(us_div), xs_tup) == T.(ys_) - @test scatter!(:div, T.(copy(ys)), T.(us_div), xs_tup) == T.(ys_) + mutated = false + # @test scatter(/, srcs[(dims, mutated)], idx) == T.(res[(/, dims, mutated)]) + end end @testset "mean" begin - ys_ = cu([4 4 5 5 6; - 6 6 7 7 8]) - @test scatter_mean!(T.(copy(ys)), T.(us), xs) == T.(ys_) - @test scatter!(:mean, T.(copy(ys)), T.(us), xs) == T.(ys_) + for idx = values(idxs), dims = [0, 1] + mutated = true + @test scatter!(mean, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(mean, dims, mutated)]) - @test scatter_mean!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_) - @test scatter!(:mean, T.(copy(ys)), T.(us), xs_tup) == T.(ys_) + mutated = false + # @test scatter(mean, srcs[(dims, mutated)], idx) == T.(res[(mean, dims, mutated)]) + end end end end