@@ -49,13 +49,16 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote u
4949@inline wrap_chainrules_output (x) = conj (unthunk (x)) # For now we are just not going to deal with thunks
5050@inline wrap_chainrules_output (x:: Tuple ) = map (wrap_chainrules_output, x)
5151@inline wrap_chainrules_output (x:: ChainRules.AbstractZero ) = nothing
52- @inline function wrap_chainrules_output (x:: ChainRules.Composite{P, T} ) where {P, T}
53- T_outer = T <: NamedTuple ? NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types
54- xp = map (wrap_chainrules_output, x)
55- convert (T_outer, xp)
52+ for T_outer in (:Tuple , :NamedTuple )
53+ # we create separate methods rather than using a `Union` + an `if` so that we avoid a
54+ # branch that changes output type, because nested AD on that kinda thing makes Zygote less
55+ # than happy.
56+ @eval @inline function wrap_chainrules_output (x:: ChainRules.Composite{P, T} ) where {P, T<: $T_outer }
57+ xp = map (wrap_chainrules_output, x)
58+ convert ($ T_outer, xp)
59+ end
5660end
5761
58-
5962"""
6063 wrap_chainrules_input(x)
6164
@@ -69,24 +72,6 @@ to differentials types ChainRules uses.
6972 ChainRules. Composite {Any, typeof(xp)} (xp)
7073end
7174
72- """
73- wrap_chainrules_pullback(f, args...)
74-
75- Wrap a chainrule's pullback `f`, converting the format of the inputs (`args`),
76- and the outputs.
77- """
78- @inline function wrap_chainrules_pullback (pb, args... )
79- return wrap_chainrules_output (pb (wrap_chainrules_input (args)... ))
80- end
81-
82- # Note we hand-expess the single arg version of this to remove splatting
83- # because splatting breaks constant folding
84- # This can be removed after https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152
85- @inline function wrap_chainrules_pullback (pb, a)
86- return wrap_chainrules_output (pb (wrap_chainrules_input (a)))
87- end
88-
89-
9075"""
9176 ZBack{F}(back) <: Function
9277
@@ -96,10 +81,7 @@ Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conven
9681struct ZBack{F} <: Function
9782 back:: F
9883end
99- @inline (s:: ZBack )(dy) = wrap_chainrules_pullback (s. back, dy)
100- # Dispatch here handles chainrules considing pullbacks to have multiple input if Tuple.
101- # TODO : this could be removed if: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152
102- @inline (s:: ZBack )(dy:: Tuple ) = wrap_chainrules_pullback (s. back, dy... )
84+ @inline (s:: ZBack )(dy) = wrap_chainrules_output (s. back (wrap_chainrules_input (dy)))
10385# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
10486# though it might be worth keeping as a performance optimization (benchmarking pending)
10587@inline (s:: ZBack )(:: Nothing ) = nothing
@@ -127,6 +109,3 @@ As per [`chain_rrule`](@ref) but with support for kwargs.
127109 kw_zpullback (dy) = (nothing , nothing , ZBack (back)(dy). .. ) # first two nothings are for kwfunc and kwargs
128110 return y, kw_zpullback
129111end
130-
131- # Required for nested AD
132- @adjoint ChainRules. Composite {Any, T} (x:: T ) where T = ChainRules. Composite {Any, T} (x), x-> (x,)
0 commit comments