Insert _project into getproperty's gradient, and then improve z2d etc. to restore stability #1104

Merged: mcabbott merged 29 commits into FluxML:master from project_getprop on Nov 7, 2021

mcabbott
Member

The goal here is to avoid things like this:

julia> pullback(x -> abs2(x.im * ((1-im) * x).re), 4+5im)[2](1.0)
# Zygote v0.6.0:
((re = 450.0 + 450.0im, im = 810.0),)
# Zygote v0.6.28:
ERROR: MethodError: no method matching Complex(::ComplexF64, ::ZeroTangent)
# this PR:
(450.0 + 1260.0im,)
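
As a sanity check on the value reported for this PR, the same function can be differentiated by finite differences, treating the real and imaginary parts of x as two independent real inputs. The sketch below uses plain Julia only (no Zygote); the step size `h` and the names `d_re`, `d_im` and `grad` are choices made for this illustration.

```julia
# f from the example above, written with real/imag instead of .re/.im
f(x) = abs2(imag(x) * real((1 - im) * x))

x = 4 + 5im
h = 1e-6  # finite-difference step, chosen for illustration

# Central differences along the real and imaginary directions
d_re = (f(x + h) - f(x - h)) / (2h)
d_im = (f(x + h * im) - f(x - h * im)) / (2h)

grad = complex(d_re, d_im)  # approximately 450 + 1260im
```

Writing x = a + b*im gives f = b^2 * (a + b)^2, so the partial derivatives at a = 4, b = 5 are 2b^2(a + b) = 450 and 2b(a + b)^2 + 2b^2(a + b) = 1260, which agrees with the complex number returned by this PR.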

While projection was intended as a mechanism to encode & enforce mathematical properties like real vs complex, I think it's also a good way to encode (when desired) a preference for the "natural gradient" of a complex number, instead of a NamedTuple, because the natural form can participate in other stages of the calculation. One example of this found in the wild was fixed in JuliaDiff/ChainRules.jl#509; a made-up example in the same spirit is:

julia> gradient(x -> sqrt(sum(parent(x .^ 2))), Diagonal([1,2,3]))[1]
# latest tagged:
ERROR: ArgumentError: broadcasting over dictionaries and `NamedTuple`s is reserved
# this PR + CRC#446:
3×3 Diagonal{Float64, Vector{Float64}}:
 0.267261   ⋅         ⋅ 
  ⋅        0.534522   ⋅ 
  ⋅         ⋅        0.801784
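
The entries of that result can be checked without any AD package. For f(D) = sqrt(sum of squared diagonal entries), the derivative with respect to the i-th diagonal entry is d[i] / norm(d). The following is a small sketch using only the LinearAlgebra standard library; the names `d` and `expected` are introduced here for illustration.

```julia
using LinearAlgebra

d = [1, 2, 3]

# Gradient of sqrt(sum(abs2, d)) with respect to d, kept in Diagonal form
expected = Diagonal(d ./ norm(d))
```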

I'm a little concerned this may affect what #909 just fixed. Maybe CI will tell us, but cc @simeonschaub ?

@simeonschaub
Member

Seems reasonable. Since _project should be type stable, I don't expect this to cause too many issues. We probably want to do this in the mutable case as well, though, for consistency?

@mcabbott
Member Author

Ok, good. Yes I need to stare a bit more at the mutable case, and think up some tests.

@willtebbutt
Member

Should we need this here? There's nothing in the default getproperty implementation that would warrant it -- if something needs projecting here, presumably it's because a rule somewhere else was incorrect, no?

@mcabbott
Member Author

mcabbott commented Oct 16, 2021

> because a rule somewhere else was incorrect, no?

No. These examples aren't doing "mathematical" projection at all. They just change the representation, from a NamedTuple to something more like the original.

As I tried to say above, this is a second meaning to overload onto _project, and there is some chance that it ought to be called something else. But I think that something else would share many features; in particular, it should be similarly user-overloadable. Combining these seems fairly natural and tidy to me.

> default getproperty implementation

There's nothing wrong with the default, except that it returns a NamedTuple. The examples above have steps before getproperty, which the gradient wants to flow backwards into, and these steps don't work with a NamedTuple -- you get the errors shown or the nonsensical answers shown.
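
To make the "change of representation" concrete, here is a minimal sketch in plain Julia. It is not Zygote's `_project`; the helper name `natural` is hypothetical, and `nothing` stands for a zero gradient as in Zygote. It assumes the usual convention that the pullback of y = c * x with respect to x multiplies the incoming gradient by conj(c).

```julia
# Hypothetical helper for illustration only; not Zygote's `_project`.
# Turns a structural (NamedTuple) gradient of a Complex into a Complex.
natural(::Complex, dx::NamedTuple{(:re, :im)}) =
    complex(something(dx.re, false), something(dx.im, false))
natural(::Any, dx) = dx  # everything else is left untouched

x = 4 + 5im

# Structural gradients produced by reading y.re and x.im, where y = (1 - im) * x
dy = natural((1 - im) * x, (re = 450.0, im = nothing))  # 450.0 + 0.0im
dx_from_im = natural(x, (re = nothing, im = 810.0))      # 0.0 + 810.0im

# A Complex gradient can flow back through the multiplication step,
# which a NamedTuple cannot do.
dx_from_mul = dy * conj(1 - im)

total = dx_from_mul + dx_from_im
```

With the numbers from the first example in this thread, `total` comes out as 450.0 + 1260.0im, the value reported for this PR.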

@mcabbott mcabbott force-pushed the project_getprop branch 2 times, most recently from 28c05a1 to 34f6af6 Compare October 17, 2021 18:36
@willtebbutt
Member

Ahh I see, thanks for the explanation. My only objection concerns aspects of how we're doing AD that I know we already disagree about (when to use naturals vs structurals, etc.), so I won't get in the way of this.

Member Author

@mcabbott left a comment

90% of the PR is upgrades & bug fixes to the map to & from ChainRules types. They aren't perfect yet, but better.

_project now uses zygote2differential and hence does not produce Tangent{Any}, so we can remove some CRC piracy which handled that.

src/compiler/chainrules.jl
src/compiler/chainrules.jl
src/compiler/chainrules.jl (outdated)
src/lib/lib.jl
test/compiler.jl
test/features.jl
Comment on lines +23 to +24
- {user: TuringLang, repo: DynamicPPL.jl, group: All}
- {user: TuringLang, repo: DistributionsAD.jl, group: Zygote}
Member Author

Downstream test of DistributionsAD is trivial right now. When TuringLang/DistributionsAD.jl#203 is merged, it should then test just the Zygote test group.

It looks like that should take < 30 mins. Running all tests takes > 6 hours and times out on GitHub Actions.

primals = NamedTuple{fnames}(getfield(primal, fn) for fn in fnames)
tp::NamedTuple = map(z2d, complete_t, primals)
return canonicalize(Tangent{primal_type, typeof(tp)}(tp))
z2d(dx::NamedTuple, primal::AbstractDict) = dx # uses a NamedTuple but not for fields!
Member

Suggested change
z2d(dx::NamedTuple, primal::AbstractDict) = dx # uses a NamedTuple but not for fields!
# Dict handling in Zygote is a mess; for now, leave it alone.
# TODO: should this become a `Tangent{Dict,Dict}`?
# Right now it uses a NamedTuple, but not for the fields of the AbstractDict struct.
z2d(dx::NamedTuple, primal::AbstractDict) = dx

@oxinabox
Member

Does this close #660 ?

@mcabbott
Member Author

mcabbott commented Oct 26, 2021

> Does this close #660 ?

That appears to be fixed, but not by this PR. Can add it as a test to another list of old PRs I was making...

@mcabbott changed the title from "Insert _project into getproperty's gradient" to "Insert _project into getproperty's gradient, and then improve z2d etc. to restore stability" on Oct 26, 2021
@oxinabox
Member

I am going to leave this to @DhairyaLGandhi to give final approval

@devmotion
Collaborator

FYI this PR fixes the remaining Zygote test errors in TuringLang/DistributionsAD.jl#203 (TuringLang/DistributionsAD.jl#203 (comment)).

@DhairyaLGandhi
Member

Cool, let me take a pass at this in my morning. @devmotion does this address the accum(::NamedTuple, ::Tangent) case for you?

@devmotion
Collaborator

> @devmotion does this address the accum(::NamedTuple, ::Tangent) case for you?

I don't know which error you are referring to; it is none of the errors in TuringLang/DistributionsAD.jl#203 (comment). Both Zygote problems in this comment (TuringLang/DistributionsAD.jl#203 (comment)) are fixed with this PR though.

src/compiler/chainrules.jl
return canonicalize(Tangent{primal_type, typeof(tp)}(tp))
# Dict handling in Zygote is a mess... should this become a `Tangent{Dict,Dict}` ?
# Right now it uses a NamedTuple but not for fields of the AbstractDict struct
z2d(dx::NamedTuple, primal::AbstractDict) = dx
Member

Why should it be the fields of the AbstractDict struct?

Member

No statement is being made here about should or shouldn't.
Per the comment, Dict handling is currently a mess and is inconsistent.
But most of the time it is a NamedTuple that holds not the fields but the values.
And that is fine.

Comment on lines +296 to +297
if backing isa Tuple{Vararg{AbstractZero}}
return NoTangent() # collapse all-zero case
Member

There seems to be a lot of "collapsing zero" going on. The accumulation step already takes care of the collapsing for the most part. Why can't we let the zero cases get handled in the accumulation step? Is this related to Zygote sometimes producing Array{Nothing}?

Collaborator

Without it the DistributionsAD tests fail. The main reason is that ChainRules expects non-differentiable objects to have a NoTangent derivative but without this fix we end up with stuff like Tangent{T}(; a = NoTangent(), b = [NoTangent(), NoTangent()], c = Tangent{S}(; d = NoTangent())).
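
The kind of collapsing described here can be sketched in plain Julia. This is not the code from the PR: `NoTan` is a stand-in for ChainRulesCore's `NoTangent` so that the snippet needs no packages, nested NamedTuples stand in for `Tangent` backings, and `collapse` is a hypothetical name.

```julia
# Stand-in for ChainRulesCore.NoTangent, so the sketch needs no packages.
struct NoTan end

# Hypothetical recursive collapse: a container whose leaves are all
# "no tangent" becomes a single NoTan().
collapse(x) = x
collapse(xs::AbstractArray) =
    all(x -> collapse(x) isa NoTan, xs) ? NoTan() : map(collapse, xs)
function collapse(nt::NamedTuple)
    vals = map(collapse, nt)  # collapse each field first
    return all(v -> v isa NoTan, values(vals)) ? NoTan() : vals
end

# Same shape as the nested example in the comment above
nested = (a = NoTan(), b = [NoTan(), NoTan()], c = (d = NoTan(),))
mixed = (a = 1.0, b = [NoTan(), NoTan()])

collapsed_nested = collapse(nested)
collapsed_mixed = collapse(mixed)
```

Here `collapsed_nested` is a single `NoTan()`, while `collapsed_mixed` keeps its real entry and only collapses the all-zero array field.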

Member

A stacktrace would be helpful.

Collaborator

There are more explanations in TuringLang/DistributionsAD.jl#203 (comment). Unfortunately, it seems there are no GH actions logs of the original issue since I fixed it (partially) before adding and pushing a new commit.

@DhairyaLGandhi
Member

DhairyaLGandhi commented Nov 1, 2021

Given the extensive breakage in the ecosystem, we definitely need a hot patch so we can get PPL and SciML up and running again. We should also look for a more robust solution than the current one to handle the Zygote <-> CR interaction.

@devmotion
Collaborator

IMO the integration tests added in this PR are a big improvement and hopefully ensure that Zygote support in Distributions and Turing is not broken accidentally again.

I don't think it's the main problem right now, but I assume the CR-Zygote integration could be simplified if Zygote used NoTangent and ZeroTangent instead of nothing.

Comment on lines 159 to 161
# Could `reinterpret` instead here...
# One easy case, but can this go wrong?
# @inline wrap_chainrules_input(xs::Base.ReinterpretArray{<:NamedTuple, <:Tangent}) = parent(@show xs)
Member

This seems like debugging code

Member Author

@mcabbott, Nov 3, 2021

The `@show`, you mean? It's in a comment about further optimisations, for which #1112 is again the issue.

Member

We can get rid of the `@show` either way. In general, it's best to leave reinterpreted arrays alone; they can behave unexpectedly if we try to manage them manually. Even in a comment, I doubt the `@show` was meant to be part of committed code.

# Could `reinterpret` instead of broadcasting here -- TODO
@inline wrap_chainrules_output(xs::AbstractArray{<:ChainRules.Tangent}) = wrap_chainrules_output.(xs)
wrap_chainrules_output(dxs::AbstractArray{<:Number}) = dxs
wrap_chainrules_output(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
Member

Is there an MWE for this being needed? Is it for RecursiveArrayTools or does it come up elsewhere?

Member Author

@mcabbott, Nov 3, 2021

For all of these, the minimal correct thing is wrap_chainrules_output(dxs::AbstractArray) = map(wrap_chainrules_output, dxs), etc.

But for some common types, we know that the representation used by Zygote and by ChainRules agrees, so we can save a copy by doing nothing. That's the case for AbstractArray{<:Number} and AbstractArray{<:AbstractArray{<:Number}}. These aren't needed, but are an optimisation for common cases.

#1112 is the issue about less trivial optimisations to save more copies.
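
The dispatch pattern described in this reply can be sketched without Zygote or ChainRules. In the snippet below, `Tan` is a stand-in for `ChainRules.Tangent` and `unwrap` is a hypothetical analogue of `wrap_chainrules_output`; only the shape of the method table is meant to match the discussion.

```julia
# Stand-in for ChainRules.Tangent: a wrapper around a NamedTuple backing.
struct Tan{NT}
    backing::NT
end

unwrap(x) = x
unwrap(t::Tan) = map(unwrap, t.backing)

# Minimal correct fallback: convert element by element (allocates a new array).
unwrap(xs::AbstractArray) = map(unwrap, xs)

# Optimisations for common cases: both representations agree, so return the
# input itself and save a copy.
unwrap(xs::AbstractArray{<:Number}) = xs
unwrap(xs::AbstractArray{<:AbstractArray{<:Number}}) = xs

nums = [1.0, 2.0]
tans = [Tan((a = 1.0,)), Tan((a = 2.0,))]

same_array = unwrap(nums) === nums  # no copy was made
converted = unwrap(tans)            # a Vector of NamedTuples
```

Removing the two specialised methods would leave the results equal in value, but `unwrap(nums)` would then return a fresh copy instead of the input array.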

Member

@DhairyaLGandhi left a comment

bors r+

@mcabbott mcabbott merged commit 4ed3a86 into FluxML:master Nov 7, 2021
@mcabbott mcabbott deleted the project_getprop branch November 7, 2021 21:15
@ToucheSir
Member

Cryptic error from Bors: https://app.bors.tech/repositories/20723/log

@mcabbott
Member Author

mcabbott commented Nov 7, 2021

Indeed.

This did pass buildkite tests, BTW, after being rebased on master (a8306bc), and master hasn't changed since.

bors bot pushed a commit to TuringLang/DynamicPPL.jl that referenced this pull request Nov 8, 2021
FluxML/Zygote.jl#1104 was merged and hence (hopefully) Zygote tests are not broken anymore.