Skip to content

Commit 25b9aa6

Browse files
oxinabox and willtebbutt committed
Remove special handling of multiple inputs to pullbacks from ChainRules (#1)
* ChainRules pullbacks always have 1 input JuliaDiff/ChainRulesCore.jl#152 * swap to a version of ChainRules that doesn't use multi-arg pullbacks * update tests * make it so we don't need a custom rule anymore * add comment * Update src/compiler/chainrules.jl Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
1 parent f0c3fb1 commit 25b9aa6

File tree

3 files changed

+18
-41
lines changed

3 files changed

+18
-41
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2222
[compat]
2323
AbstractFFTs = "0.5"
2424
ArrayLayouts = "0.1, 0.2"
25-
ChainRules = "0.5.1"
25+
ChainRules = "0.6.0"
2626
FillArrays = "0.8"
2727
ForwardDiff = "0"
2828
IRTools = "=0.3.1"

src/compiler/chainrules.jl

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,16 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote u
4949
@inline wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks
5050
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
5151
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
52-
@inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T}
53-
T_outer = T <: NamedTuple ? NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types
54-
xp = map(wrap_chainrules_output, x)
55-
convert(T_outer, xp)
52+
for T_outer in (:Tuple, :NamedTuple)
53+
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
54+
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
55+
# than happy.
56+
@eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer}
57+
xp = map(wrap_chainrules_output, x)
58+
convert($T_outer, xp)
59+
end
5660
end
5761

58-
5962
"""
6063
wrap_chainrules_input(x)
6164
@@ -69,24 +72,6 @@ to differentials types ChainRules uses.
6972
ChainRules.Composite{Any, typeof(xp)}(xp)
7073
end
7174

72-
"""
73-
wrap_chainrules_pullback(f, args...)
74-
75-
Wrap a chainrule's pullback `f`, converting the format of the inputs (`args`),
76-
and the outputs.
77-
"""
78-
@inline function wrap_chainrules_pullback(pb, args...)
79-
return wrap_chainrules_output(pb(wrap_chainrules_input(args)...))
80-
end
81-
82-
# Note we hand-express the single arg version of this to remove splatting
83-
# because splatting breaks constant folding
84-
# This can be removed after https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152
85-
@inline function wrap_chainrules_pullback(pb, a)
86-
return wrap_chainrules_output(pb(wrap_chainrules_input(a)))
87-
end
88-
89-
9075
"""
9176
ZBack{F}(back) <: Function
9277
@@ -96,10 +81,7 @@ Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conven
9681
struct ZBack{F} <: Function
9782
back::F
9883
end
99-
@inline (s::ZBack)(dy) = wrap_chainrules_pullback(s.back, dy)
100-
# Dispatch here handles chainrules considering pullbacks to have multiple inputs if Tuple.
101-
# TODO: this could be removed if: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152
102-
@inline (s::ZBack)(dy::Tuple) = wrap_chainrules_pullback(s.back, dy...)
84+
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
10385
# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
10486
# though it might be worth keeping as a performance optimization (benchmarking pending)
10587
@inline (s::ZBack)(::Nothing) = nothing
@@ -127,6 +109,3 @@ As per [`chain_rrule`](@ref) but with support for kwargs.
127109
kw_zpullback(dy) = (nothing, nothing, ZBack(back)(dy)...) # first two nothings are for kwfunc and kwargs
128110
return y, kw_zpullback
129111
end
130-
131-
# Required for nested AD
132-
@adjoint ChainRules.Composite{Any, T}(x::T) where T = ChainRules.Composite{Any, T}(x), x->(x,)

test/chainrules.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ using Zygote, Test, ChainRules
5757
simo(x) = (5x, 7x)
5858
function ChainRules.rrule(::typeof(simo), x)
5959
simo_rrule_hitcount[] += 1
60-
function simo_pullback(Δa, Δb)
60+
function simo_pullback((Δa, Δb))
6161
simo_pullback_hitcount[] += 1
6262
return ChainRules.NO_FIELDS, 5*Δa + 7*Δb
6363
end
@@ -101,7 +101,7 @@ using Zygote, Test, ChainRules
101101
mimo(a, b) = (5a + 7b, 100a, 10b)
102102
function ChainRules.rrule(::typeof(mimo), a, b)
103103
mimo_rrule_hitcount[] += 1
104-
function mimo_pullback(Δx, Δy, Δz)
104+
function mimo_pullback((Δx, Δy, Δz))
105105
mimo_pullback_hitcount[] += 1
106106
return ChainRules.NO_FIELDS, 5Δx + 100Δy , 7Δx + 10Δz
107107
end
@@ -126,9 +126,7 @@ using Zygote, Test, ChainRules
126126

127127
@testset "nested AD hitting identity(::Tuple) pullback" begin
128128
# This is a particularly fiddly case.
129-
# the adjoint of `tuple` is `identity`
130-
# and `identity(::Tuple)`s pullback has multiple inputs
131-
# (since the primal had multiple outputs)
129+
# It's kind of a simplified version of `sin'''(0.5)` but different in some places.
132130

133131
f(x) = tuple(x, 2x, 3x)
134132

@@ -177,8 +175,8 @@ using Zygote, Test, ChainRules
177175
end
178176
end
179177

180-
# ChainRules doesn't have support for FastMath yet, so this fails
181-
# https://github.com/JuliaDiff/ChainRules.jl/issues/174
182-
@test_broken gradient(2.0) do x
183-
@fastmath x^2.0
184-
end == (4.0,)
178+
@testset "FastMath support" begin
179+
@test gradient(2.0) do x
180+
@fastmath x^2.0
181+
end == (4.0,)
182+
end

0 commit comments

Comments
 (0)