diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index d9dd1b2cb..97a797025 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -40,9 +40,13 @@ _pullback(f, args...) = _pullback(Context(), f, args...) tailmemaybe(::Nothing) = nothing tailmemaybe(x::Tuple) = Base.tail(x) +replacezero(t) = ntuple(i -> t[i] isa AbstractZero ? nothing : t[i], length(t)) +replacezero(::Nothing) = nothing +replacezero(::AbstractZero) = nothing + function pullback(f, args...) y, back = _pullback(f, args...) - y, Δ -> tailmemaybe(back(Δ)) + y, Δ -> replacezero(tailmemaybe(back(Δ))) end sensitivity(y::Number) = one(y)