-
-
Notifications
You must be signed in to change notification settings - Fork 213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Insert _project
into getproperty
's gradient, and then improve z2d
etc. to restore stability
#1104
Conversation
Seems reasonable, since |
Ok, good. Yes I need to stare a bit more at the mutable case, and think up some tests. |
Should we need this here? There's nothing in the default |
No. These examples aren't doing "mathematical" projection at all. They just change the representation, from a NamedTuple to something more like the original. As I tried to say above, this is a second meaning to overload onto
There's nothing wrong with the default, except that it returns a NamedTuple. The examples above have steps before |
28c05a1
to
34f6af6
Compare
Ahh I see, thanks for the explanation. The only basis on which I think I take issue with doing this is on aspects of how we're doing AD that I know we already disagree about (when to use naturals vs structurals etc), so I won't get in the way of this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
90% of the PR is upgrades & bug fixes to the map to & from ChainRules types. They aren't perfect yet, but better.
_project
now uses zygote2differential
and hence does not produce Tangent{Any}
, so we can remove some CRC piracy which handled that.
- {user: TuringLang, repo: DynamicPPL.jl, group: All} | ||
- {user: TuringLang, repo: DistributionsAD.jl, group: Zygote} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Downstream test of DistributionsAD is trivial right now. When TuringLang/DistributionsAD.jl#203 is merged, it should then test just the Zygote test group.
It looks like that should take < 30 mins. Running all tests takes > 6 hours, times out on github actions.
src/compiler/chainrules.jl
Outdated
primals = NamedTuple{fnames}(getfield(primal, fn) for fn in fnames) | ||
tp::NamedTuple = map(z2d, complete_t, primals) | ||
return canonicalize(Tangent{primal_type, typeof(tp)}(tp)) | ||
z2d(dx::NamedTuple, primal::AbstractDict) = dx # uses a NamedTuple but not for fields! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
z2d(dx::NamedTuple, primal::AbstractDict) = dx # uses a NamedTuple but not for fields! | |
# Dict handling in Zygote is a mess for now leave it alone. | |
# TODO should this become a `Tangent{Dict,Dict}` ? | |
# right now it uses a NamedTuple but not for fields of AbstractDict struict | |
z2d(dx::NamedTuple, primal::AbstractDict) = dx | |
Does this close #660 ? |
That appears to be fixed, but not by this PR. Can add it as a test to another list of old PRs I was making... |
45f5c27
to
a8306bc
Compare
_project
into getproperty
's gradient_project
into getproperty
's gradient, and then improve z2d
etc. to restore stability
I am going to leave this to @DhairyaLGandhi to give final approval |
FYI this PR fixes the remaining Zygote test errors in TuringLang/DistributionsAD.jl#203 (TuringLang/DistributionsAD.jl#203 (comment)). |
Cool, let me take a pass at this in my morning. @devmotion does this address the |
I don't know which error you refer to, it's none of the errors in TuringLang/DistributionsAD.jl#203 (comment). Both Zygote problems in this comment (TuringLang/DistributionsAD.jl#203 (comment)) are fixed with this PR though. |
return canonicalize(Tangent{primal_type, typeof(tp)}(tp)) | ||
# Dict handling in Zygote is a mess... should this become a `Tangent{Dict,Dict}` ? | ||
# Right now it uses a NamedTuple but not for fields of the AbstractDict struct | ||
z2d(dx::NamedTuple, primal::AbstractDict) = dx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why should it be the fields of the AbstractDict
struct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No statement is being made here about should or shouldn't.
Per the comment Dict handling is currently a mess, and is inconsistent.
But most of the time it is a named tuple that is not the fields, but rather is the values.
And that is fine.
if backing isa Tuple{Vararg{AbstractZero}} | ||
return NoTangent() # collapse all-zero case |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There seems to be a lot of "collapsing zero" going on. The accumulation step already takes care of the collapsing for the most part. Why can't we let the zeros cases get handled in the accumulation step? Is this related to zygote sometimes producing Array{Nothing}
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without it the DistributionsAD tests fail. The main reason is that ChainRules expects non-differentiable objects to have a NoTangent
derivative but without this fix we end up with stuff like Tangent{T}(; a = NoTangent(), b = [NoTangent(), NoTangent()], c = Tangent{S}(; d = NoTangent()))
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A stacktrace would be helpful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are more explanations in TuringLang/DistributionsAD.jl#203 (comment). Unfortunately, it seems there are no GH actions logs of the original issue since I fixed it (partially) before adding and pushing a new commit.
Given the extensive breakage in the ecosystem, we would definitely need a hot patch so we can get PPL and SciML up and running again. We should look for a more robust solution than we have currently to handle the zygote <-> CR interaction. |
IMO the integration tests added in this PR are a big improvement and hopefully ensure that Zygote support in Distributions and Turing is not broken accidentally again. I don't think it's the main problem right now but I assume the CR-Zygote integration could be simplified if Zygote would use |
src/compiler/chainrules.jl
Outdated
# Could `reinterpret` instead here... | ||
# One easy case, but can this go wrong? | ||
# @inline wrap_chainrules_input(xs::Base.ReinterpretArray{<:NamedTuple, <:Tangent}) = parent(@show xs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like debugging code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The show you mean? It's in a comment, about further optimisations, for which again 1112 is the issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can get rid of the show either way. In general, it's best to leave reinterpreted arrays be, they can have unexpected behavior if we try to manage them manually. Even in a comment, I still doubt the show was intentionally expected to be part of committed code?
# Could `reinterpret` instead of broadcasting here -- TODO | ||
@inline wrap_chainrules_output(xs::AbstractArray{<:ChainRules.Tangent}) = wrap_chainrules_output.(xs) | ||
wrap_chainrules_output(dxs::AbstractArray{<:Number}) = dxs | ||
wrap_chainrules_output(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there an mwe for this being needed? Is it for RecursiveArrayTools or does it come up elsewhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For all of these, the minimal correct thing is wrap_chainrules_output(dxs::AbstractArray) = map(wrap_chainrules_output, dxs)
, etc.
But for some common types, we know that the representation used by Zygote and by ChainRules agrees, so we can save a copy by doing nothing. That's the case for AbstractArray{<:Number}
and AbstractArray{<:AbstractArray{<:Number}}
. These aren't needed, but are an optimisation for common cases.
1112 is the issue about less trivial optimisations to save more copies.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bors r+
Cryptic error from Bors: https://app.bors.tech/repositories/20723/log |
Indeed. This did pass buildkite tests, BTW, after being rebased on master (a8306bc), and master hasn't changed since. |
FluxML/Zygote.jl#1104 was merged and hence (hopefully) Zygote tests are not broken anymore.
The goal here is to avoid things like this:
While projection was intended as a mechanism to encode & enforce mathematical properties like real vs complex, I think it's also a good way to encode (when desired) a preference for the "natural gradient" of a complex number, instead of a NamedTuple, because this can participate in other stages of the calculation. One example of this found in the wild was fixed in JuliaDiff/ChainRules.jl#509, a made-up example in the same spirit is:
I'm a little concerned this may affect what #909 just fixed. Maybe CI will tell us, but cc @simeonschaub ?