|
57 | 57 | @test g6.a isa Vector{Float64}
|
58 | 58 | @test g6.b == [0+im]
|
59 | 59 |
|
60 |
| - # Second derivative -- no method matching rrule(::typeof(Optimisers._rebuild), ...? |
61 |
| - @test_broken gradient([1,2,3]) do v |
62 |
| - sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6]))[1][1]) |
63 |
| - end[1] ≈ [8,16,24] |
| 60 | + @testset "second derivative" begin |
| 61 | + @test_broken gradient([1,2,3.0]) do v |
| 62 | + sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1]) |
| 63 | + end[1] ≈ [8,16,24] |
| 64 | + |
| 65 | + @test_skip gradient([1,2,3.0]) do v |
| 66 | + sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1]) |
| 67 | + end |
| 68 | + end |
64 | 69 | end
|
65 | 70 |
|
66 | 71 | @testset "gradient of rebuild" begin
|
|
85 | 90 | @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
|
86 | 91 | @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]
|
87 | 92 |
|
88 |
| - # Second derivative -- error from _tryaxes(x::Tangent) in Zygote's map rule |
89 |
| - @test_broken gradient(collect(1:6)) do y |
90 |
| - sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) |
91 |
| - end[1] ≈ [8,16,24,0,0,0] |
| 93 | + @testset "second derivative" begin |
| 94 | + # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} |
| 95 | + @test_broken gradient(collect(1:6.0)) do y |
| 96 | + sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) |
| 97 | + end[1] ≈ [8,16,24,0,0,0] |
| 98 | + # This fixes it! |
| 99 | + # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,) |
| 100 | + @test_skip gradient(collect(1:6.0)) do y |
| 101 | + sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1]) |
| 102 | + end[1] |
| 103 | + # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,) |
| 104 | + end |
92 | 105 | end
|
93 | 106 |
|
94 | 107 | @testset "Flux issue 1826" begin
|
|
0 commit comments