Skip to content

Commit

Permalink
refactor test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Mar 16, 2021
1 parent 43d82eb commit 29f113d
Showing 1 changed file with 117 additions and 102 deletions.
219 changes: 117 additions & 102 deletions lib/NNlibCUDA/test/scatter.jl
Original file line number Diff line number Diff line change
@@ -1,131 +1,146 @@
ys = cu([3 3 4 4 5;
5 5 6 6 7])
us = cu(ones(Int, 2, 3, 4))
xs = CuArray{Int64}([1 2 3 4;
4 2 1 3;
3 5 5 3])
xs_tup = CuArray([(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)])


@testset "cuda/scatter" begin
for T = [UInt32, UInt64, Int32, Int64]
dsts = Dict(
0 => cu([3, 4, 5, 6, 7]),
1 => cu([3 3 4 4 5;
5 5 6 6 7]),
)
srcs = Dict(
(0, true) => cu(ones(Int, 3, 4)),
(0, false) => cu(ones(Int, 3) * collect(1:4)'),
(1, true) => cu(ones(Int, 2, 3, 4)),
(1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),
)
idxs = Dict(
:int => cu([1 2 3 4;
4 2 1 3;
3 5 5 3]),
:tup => cu([(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)]),
)
res = Dict(
(+, 0, true) => cu([5, 6, 9, 8, 9]),
(+, 1, true) => cu([5 5 8 6 7;
7 7 10 8 9]),
(+, 0, false) => cu([4, 4, 12, 5, 5]),
(+, 1, false) => cu([4 4 12 5 5;
8 8 24 10 10]),
(-, 0, true) => cu([1, 2, 1, 4, 5]),
(-, 1, true) => cu([1 1 0 2 3;
3 3 2 4 5]),
(-, 0, false) => cu([-4, -4, -12, -5, -5]),
(-, 1, false) => cu([-4 -4 -12 -5 -5;
-8 -8 -24 -10 -10]),
(max, 0, true) => cu([3, 4, 5, 6, 7]),
(max, 1, true) => cu([3 3 4 4 5;
5 5 6 6 7]),
(max, 0, false) => cu([3, 2, 4, 4, 3]),
(max, 1, false) => cu([3 2 4 4 3;
6 4 8 8 6]),
(min, 0, true) => cu([1, 1, 1, 1, 1]),
(min, 1, true) => cu([1 1 1 1 1;
1 1 1 1 1]),
(min, 0, false) => cu([1, 2, 1, 1, 2]),
(min, 1, false) => cu([1 2 1 1 2;
2 4 2 2 4]),
(*, 0, true) => cu([3, 4, 5, 6, 7]),
(*, 1, true) => cu([3 3 4 4 5;
5 5 6 6 7]),
(*, 0, false) => cu([3, 4, 48, 4, 6]),
(*, 1, false) => cu([3 4 48 4 6;
12 16 768 16 24]),
(/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
(/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
1.25 1.25 0.375 1.5 1.75]),
(/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
(/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
1//12 1//16 1//768 1//16 1//24]),
(mean, 0, true) => cu([4., 5., 6., 7., 8.]),
(mean, 1, true) => cu([4. 4. 5. 5. 6.;
6. 6. 7. 7. 8.]),
(mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
(mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
4. 4. 6. 5. 5.]),
)

types = [UInt32, UInt64, Int32, Int64, Float32, Float64]


@testset "scatter" begin
for T = types
@testset "$(T)" begin
@testset "add" begin
ys_ = cu([5 5 8 6 7;
7 7 10 8 9])
@test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_add!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:add, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@testset "+" begin
for idx = values(idxs), dims = [0, 1]
mutated = true
@test scatter!(+, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(+, dims, mutated)])

mutated = false
# @test scatter(+, srcs[(dims, mutated)], idx) == T.(res[(+, dims, mutated)])
end
end

@testset "sub" begin
ys_ = cu([1 1 0 2 3;
3 3 2 4 5])
@test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_)
@testset "-" begin
for idx = values(idxs), dims = [0, 1]
mutated = true
@test scatter!(-, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(-, dims, mutated)])

@test scatter_sub!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:sub, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
mutated = false
# @test scatter(-, srcs[(dims, mutated)], idx) == T.(res[(-, dims, mutated)])
end
end

@testset "max" begin
ys_ = cu([3 3 4 4 5;
5 5 6 6 7])
@test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_)
for idx = values(idxs), dims = [0, 1]
mutated = true
@test scatter!(max, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(max, dims, mutated)])

@test scatter_max!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:max, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
mutated = false
# @test scatter(max, srcs[(dims, mutated)], idx) == T.(res[(max, dims, mutated)])
end
end

@testset "min" begin
ys_ = cu([1 1 1 1 1;
1 1 1 1 1])
@test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_)
for idx = values(idxs), dims = [0, 1]
mutated = true
@test scatter!(min, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(min, dims, mutated)])

@test scatter_min!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:min, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
mutated = false
# @test scatter(min, srcs[(dims, mutated)], idx) == T.(res[(min, dims, mutated)])
end
end
end
end


for T = [Float32, Float64]
@testset "$(T)" begin
@testset "add" begin
ys_ = cu([5 5 8 6 7;
7 7 10 8 9])
@test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_add!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:add, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@testset "*" begin
for idx = values(idxs), dims = [0, 1]
mutated = true
@test scatter!(*, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(*, dims, mutated)])

mutated = false
# @test scatter(*, srcs[(dims, mutated)], idx) == T.(res[(*, dims, mutated)])
end
end

@testset "sub" begin
ys_ = cu([1 1 0 2 3;
3 3 2 4 5])
@test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_)
@testset "/" begin
for idx = values(idxs), dims = [0, 1]
mutated = true
@test scatter!(/, T.(dsts[dims]), srcs[(dims, mutated)].*2, idx) == T.(res[(/, dims, mutated)])

@test scatter_sub!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:sub, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

@testset "max" begin
ys_ = cu([3 3 4 4 5;
5 5 6 6 7])
@test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_max!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:max, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

@testset "min" begin
ys_ = cu([1 1 1 1 1;
1 1 1 1 1])
@test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_min!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:min, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

@testset "mul" begin
ys_ = cu([3 3 4 4 5;
5 5 6 6 7])
@test scatter_mul!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:mul, T.(copy(ys)), T.(us), xs) == T.(ys_)

@test scatter_mul!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:mul, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
end

@testset "div" begin
us_div = us .* 2
ys_ = cu([0.75 0.75 0.25 1. 1.25;
1.25 1.25 0.375 1.5 1.75])
@test scatter_div!(T.(copy(ys)), T.(us_div), xs) == T.(ys_)
@test scatter!(:div, T.(copy(ys)), T.(us_div), xs) == T.(ys_)

@test scatter_div!(T.(copy(ys)), T.(us_div), xs_tup) == T.(ys_)
@test scatter!(:div, T.(copy(ys)), T.(us_div), xs_tup) == T.(ys_)
mutated = false
# @test scatter(/, srcs[(dims, mutated)], idx) == T.(res[(/, dims, mutated)])
end
end

@testset "mean" begin
ys_ = cu([4 4 5 5 6;
6 6 7 7 8])
@test scatter_mean!(T.(copy(ys)), T.(us), xs) == T.(ys_)
@test scatter!(:mean, T.(copy(ys)), T.(us), xs) == T.(ys_)
for idx = values(idxs), dims = [0, 1]
mutated = true
@test scatter!(mean, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(mean, dims, mutated)])

@test scatter_mean!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
@test scatter!(:mean, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
mutated = false
# @test scatter(mean, srcs[(dims, mutated)], idx) == T.(res[(mean, dims, mutated)])
end
end
end
end
Expand Down

0 comments on commit 29f113d

Please sign in to comment.