Skip to content

Commit 6e4f634

Browse files
committed
second derivatives
1 parent 17b57f0 commit 6e4f634

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

src/destructure.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,8 @@ end
128128
_grad!(x, dx::Zero, off, flat::AbstractVector) = nothing
129129
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity
130130

131+
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
132+
println("grad! fwd ", length(flat))
133+
_grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
134+
_grad!(x, dx, off, flat), _grad_back
135+
end

test/destructure.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,15 @@ end
5757
@test g6.a isa Vector{Float64}
5858
@test g6.b == [0+im]
5959

60-
# Second derivative -- no method matching rrule(::typeof(Optimisers._rebuild), ...?
61-
@test_broken gradient([1,2,3]) do v
62-
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6]))[1][1])
63-
end[1] [8,16,24]
60+
@testset "second derivative" begin
61+
@test_broken gradient([1,2,3.0]) do v
62+
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
63+
end[1] [8,16,24]
64+
65+
@test_skip gradient([1,2,3.0]) do v
66+
sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1])
67+
end
68+
end
6469
end
6570

6671
@testset "gradient of rebuild" begin
@@ -85,10 +90,18 @@ end
8590
@test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
8691
@test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]
8792

88-
# Second derivative -- error from _tryaxes(x::Tangent) in Zygote's map rule
89-
@test_broken gradient(collect(1:6)) do y
90-
sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
91-
end[1] [8,16,24,0,0,0]
93+
@testset "second derivative" begin
94+
# ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
95+
@test_broken gradient(collect(1:6.0)) do y
96+
sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
97+
end[1] [8,16,24,0,0,0]
98+
# This fixes it!
99+
# Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)
100+
@test_skip gradient(collect(1:6.0)) do y
101+
sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1])
102+
end[1]
103+
# Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
104+
end
92105
end
93106

94107
@testset "Flux issue 1826" begin

0 commit comments

Comments
 (0)