Skip to content

Commit 480833c

Browse files
committed
second derivative tests
1 parent a74f86d commit 480833c

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

test/layers/basic.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,18 @@ import Flux: activations
269269
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
270270
end
271271
end
272+
273+
@testset "second derivatives" begin
274+
m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
275+
@test Zygote.hessian_dual(summ1, [1,2,3]) Zygote.hessian_reverse(summ1, [1,2,3])
276+
277+
# NNlib's softmax gradient writes in-place
278+
m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax)
279+
@test_broken Zygote.hessian_dual(summ2, [1,2,3]) Zygote.hessian_reverse(summ2, [1,2,3])
280+
281+
# https://github.com/FluxML/NNlib.jl/issues/362
282+
m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2))
283+
x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3)
284+
@test_broken Zygote.hessian_dual(summ3, x3) Zygote.hessian_reverse(summ3, x3)
285+
end
286+

0 commit comments

Comments
 (0)