Skip to content

Commit

Permalink
refactor test cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
norci committed Dec 20, 2020
1 parent 6d48ff8 commit 5a9f818
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions test/softmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ using Statistics: mean
@test softmax(Int[0, 0]) == [0.5, 0.5]
end

@testset "softmax dims" begin
xs = rand(fill(2,5)...)
@testset "softmax on different dims" begin
xs = rand(fill(2, 5)...)
out = similar(xs)
for (fn!, fn) in [(softmax!, softmax), (logsoftmax!, logsoftmax)], i = 1:ndims(xs)
@test fn!(out, xs; dims = i) == fn(xs; dims = i)
Expand Down Expand Up @@ -48,8 +48,8 @@ end
-0.930163 0.0519798 0.0549979 0.3799 -0.477112 0.437428
0.69246 0.569494 -0.503191 -0.925947 -0.0870738 -1.0697
]
@test isapprox(∇logsoftmax(ones(size(xs)), xs), ys; rtol = 1e-6)
@test isapprox(∇softmax(ones(size(xs)), xs), zeros(size(xs)); atol = 1e-6)
@test ∇logsoftmax(ones(size(xs)), xs) ys rtol = 1e-6
@test ∇softmax(ones(size(xs)), xs) zeros(size(xs)) atol = 1e-6
end

@testset "mutating softmax" begin
Expand All @@ -63,16 +63,16 @@ end
]) do xs
out = similar(xs)
softmax!(out, xs)
@test isapprox(out, softmax(xs); rtol = 1e-6)
@test out softmax(xs) rtol = 1e-6
logsoftmax!(out, xs)
@test isapprox(out, logsoftmax(xs); rtol = 1e-6)
@test out logsoftmax(xs) rtol = 1e-6

map([zeros, ones]) do fn
Δ = fn(Float64, size(xs))
∇softmax!(out, Δ, xs)
@test isapprox(out, ∇softmax(Δ, xs); rtol = 1e-6)
@test out ∇softmax(Δ, xs) rtol = 1e-6
∇logsoftmax!(out, Δ, xs)
@test isapprox(out, ∇logsoftmax(Δ, xs); rtol = 1e-6)
@test out ∇logsoftmax(Δ, xs) rtol = 1e-6
end
end
end
Expand Down

0 comments on commit 5a9f818

Please sign in to comment.