diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl index d277621ab..1c82a44f1 100644 --- a/src/compiler/emit.jl +++ b/src/compiler/emit.jl @@ -95,7 +95,7 @@ end varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing -function _lookup_grad(T) +function _generate_pullback_via_decomposition(T) (m = meta(T)) === nothing && return va = varargs(m.method, length(T.parameters)) forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T) diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index 141d90a77..ac4a5a76a 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -23,7 +23,7 @@ end hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...)) - g = try _lookup_grad(T) catch e e end + g = try _generate_pullback_via_decomposition(T) catch e e end g === nothing && return :(f(args...), Pullback{$T}((f,))) meta, forw, _ = g argnames!(meta, Symbol("#self#"), :ctx, :f, :args) @@ -37,7 +37,8 @@ end @generated function (j::Pullback{T})(Δ) where T ignore_sig(T) && return :nothing - g = try _lookup_grad(T) + g = try + _generate_pullback_via_decomposition(T) catch e rethrow(CompileError(T,e)) end