
Canonicalize the Composite before changing it to a NamedTuple #926

Merged 8 commits on Mar 25, 2021

Conversation

mzgubic
Collaborator

@mzgubic mzgubic commented Mar 24, 2021

Closes #922.

The issue occurs when the following are all true:

  1. We have a struct with two differentiable fields
  2. We need to accumulate the gradients with respect to both fields
  3. The gradients that need to be accumulated both originate from an rrule

In that case, the two gradients originating from rrules are Composites with disjoint fields. When they are transformed to NamedTuples (Zygote's internal representation of derivatives w.r.t. structs), the two NamedTuples have disjoint sets of keys and are not accumulated correctly.

By adding canonicalize(x), the Composite gets explicit Zero() fields, which means the resulting NamedTuples have the complete set of fields.
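As a minimal sketch of the failure mode (plain Julia, no Zygote or ChainRules; the names `grad_a`, `full_a`, etc. are hypothetical, not identifiers from this PR):

```julia
# Two gradients for the same struct, each produced by a different rrule,
# carry disjoint keys when the implicit zeros are left out:
grad_a = (a = 1.0,)   # one rrule only produced a gradient for field `a`
grad_b = (b = 2.0,)   # the other only produced a gradient for field `b`

# Fieldwise accumulation assumes both NamedTuples share the same keys,
# so with disjoint keys there is no well-defined sum. Once each gradient
# carries the complete set of fields (explicit zeros), accumulation works:
full_a = (a = 1.0, b = 0.0)
full_b = (a = 0.0, b = 2.0)
accumulated = map(+, full_a, full_b)  # (a = 1.0, b = 2.0)
```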

- blocked by JuliaDiff/ChainRules.jl#390 (tests only)
- need to change compat once the above is merged

@mzgubic
Collaborator Author

mzgubic commented Mar 25, 2021

Looks like we either need JuliaDiff/ChainRulesCore.jl#321 or use

    # only canonicalize when the primal type's fields are known,
    # i.e. when P is a concrete type
    maybe_canonicalized = isabstracttype(P) ? x : canonicalize(x)
    xp = map(wrap_chainrules_output, maybe_canonicalized)

instead. Type inference seems to handle it fine.

@DhairyaLGandhi
Member

Seems a bit hairy. Could you give me an idea of what canonicalize does concretely?

@mzgubic
Collaborator Author

mzgubic commented Mar 25, 2021

Yeah, I'd prefer the fix in JuliaDiff/ChainRulesCore.jl#321, but there might be reasons not to do that (I can't think of any, but there might be).

To understand canonicalize, one needs to know that Composites, the differential types for structs, only store non-zero gradients explicitly. For example, if I have a MyStruct with two fields, a and b, then the rrule for getproperty(mystruct, :a) can return Composite{MyStruct}(;a=ȳ) rather than Composite{MyStruct}(;a=ȳ, b=Zero()).

The zeros for other fields are implicit, meaning it is possible to sum Composites easily:

    julia> Composite{MyStruct}(;a=ȳ) + Composite{MyStruct}(;b=ȳ2)
    Composite{MyStruct}(;a=ȳ, b=ȳ2)

canonicalize will simply make these implicit Zero()s explicit, i.e.

    julia> canonicalize(Composite{MyStruct}(;a=ȳ))
    Composite{MyStruct}(;a=ȳ, b=Zero())

This relies on knowing the primal type (MyStruct) that Composite stores the derivative for in order to access the fields. Here, the issue is that in higher-order AD sometimes the primal type is unknown*, and in that case it is not possible to know which fields should explicitly be set to zero.

*I think we can get rid of this once Zygote uses ChainRules differential types internally
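The idea can be sketched in plain Julia as a function that fills an explicit zero for every field of the primal type (`canonicalize_sketch` is a hypothetical stand-in for illustration, not the ChainRulesCore implementation, and `0.0` stands in for `Zero()`):

```julia
struct MyStruct
    a::Float64
    b::Float64
end

# Hypothetical stand-in for ChainRulesCore.canonicalize: given the primal
# type P, make every implicitly-zero field explicit. Note the call to
# fieldnames(P): this only works when the primal type is known.
function canonicalize_sketch(::Type{P}, partial::NamedTuple) where {P}
    names = fieldnames(P)
    NamedTuple{names}(map(n -> get(partial, n, 0.0), names))
end

canonicalize_sketch(MyStruct, (a = 3.0,))  # (a = 3.0, b = 0.0)
```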

@oxinabox
Member

I am pretty sure we should do this.
Otherwise the NamedTuple that comes out of a ChainRules rule will not necessarily have the same fields as one from Zygote working it out on its own.
And that will lead to badness.

I commented on this here: #861 (comment)

@mzgubic
Collaborator Author

mzgubic commented Mar 25, 2021

I've added ChainRulesCore as a direct dependency to be able to specify a compat bound. I don't have merge rights, but there is nothing I want to add, so please go ahead once you think it's ready.

@oxinabox
Member

The GPU tests for Cholesky seem flaky.

See if they work for bors.

bors r+

@mzgubic
Collaborator Author

mzgubic commented Mar 25, 2021

The same test failed in a different PR which only added another test. Not sure what to make of that, but I don't think it's related to either of the PRs.

@bors
Contributor

bors bot commented Mar 25, 2021

Merge conflict.

@mzgubic
Collaborator Author

mzgubic commented Mar 25, 2021

Could you try again please?

@oxinabox
Member

bors r+

@bors
Contributor

bors bot commented Mar 25, 2021

Build succeeded:

@bors bors bot merged commit 56f4118 into FluxML:master Mar 25, 2021
@mzgubic mzgubic deleted the mz/canonicalize branch March 25, 2021 16:16