Description
(This is related to https://discourse.julialang.org/t/zygote-jl-adjoint-mutating-inplace-adjoints/78241)
Inspecting the Zygote code, I can see that aside from @adjoint there is also @adjoint!, which is used to declare the adjoints of some mutating functions (push! etc.). I can't find any docstrings or documentation on when and how it can be used. I suspect there are some limitations, as Zygote generally forbids mutation. Additionally, the ChainRules documentation says (https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/which_functions_need_rules.html#Functions-which-mutate-arrays):
Rules for functions which mutate its arguments, e.g. sort!, should not be written at the moment. While technically they are supported, they would break Zygote.jl such that it would sometimes quietly return the wrong answer. This may be resolved in the future by allowing AD systems to opt-in or opt-out of certain types of rules.
It then goes on to demonstrate how to write such a rule for Zygote nonetheless, using the example of a function that adds to its input array in place. This seems to work (and is probably translated to an @adjoint! rule?).
(just copied from the ChainRules doc)
using ChainRules, Zygote

function addone!(array)
    array .+= 1
    return sum(array)
end

function ChainRules.rrule(::typeof(addone!), a)
    y = addone!(a)
    function addone!_pullback(ȳ)
        return NoTangent(), ones(length(a))
    end
    return y, addone!_pullback
end

julia> a = [1.0, 2.0, 3.0];  # example input (assumed; any 3-element vector gives the output below)

julia> gradient(addone!, a)
([1.0, 1.0, 1.0],)
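To make the concern concrete, here is a minimal hand-written sketch of how I understand mutation can quietly corrupt a reverse pass. It uses no AD package at all: sumsq_with_pullback is a made-up helper for illustration, and the assumption that Zygote's recorded pullbacks can go stale in the same way is mine, not something the Zygote or ChainRules docs spell out.

```julia
# Hand-written reverse-mode sketch; no Zygote or ChainRules involved.
# `sumsq_with_pullback` is a hypothetical helper, not a library API.
function sumsq_with_pullback(x)
    y = sum(abs2, x)
    # The closure captures `x` by reference, not by value.
    pullback(dy) = 2 .* x .* dy
    return y, pullback
end

function addone!(array)
    array .+= 1
    return sum(array)
end

a = [1.0, 2.0, 3.0]
y, back = sumsq_with_pullback(a)  # forward pass sees [1, 2, 3]
addone!(a)                        # mutation after the forward pass
stale = back(1.0)                 # evaluated on the mutated array

b = [1.0, 2.0, 3.0]
_, back_copy = sumsq_with_pullback(copy(b))  # pullback holds a private copy
addone!(b)
safe = back_copy(1.0)             # unaffected by the later mutation
```

Here stale comes out as [4.0, 6.0, 8.0] rather than the correct [2.0, 4.0, 6.0] that safe gives, and nothing errors. If this is the failure mode the docs mean, then a mutating rule would only be safe when no earlier pullback still reads the mutated array.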
So what requirements must be met for rules defined through ChainRules or @adjoint! to work correctly? The linked ChainRulesCore issue does not pin these requirements down either. At the very least there should be some documentation on this, and ideally Zygote would emit a warning when such a rule is prone to fail.