diff --git a/test/softmax.jl b/test/softmax.jl index cee8d030e..d969f6eff 100644 --- a/test/softmax.jl +++ b/test/softmax.jl @@ -5,6 +5,14 @@ using Statistics: mean @test softmax(Int[0, 0]) == [0.5, 0.5] end +@testset "softmax dims" begin + xs = rand(fill(2,5)...) + out = similar(xs) + for (fn!, fn) in [(softmax!, softmax), (logsoftmax!, logsoftmax)], i = 1:ndims(xs) + @test fn!(out, xs; dims = i) == fn(xs; dims = i) + end +end + @testset "softmax" begin xs = rand(5, 5) @test all(sum(softmax(xs), dims = 1) .≈ 1)