Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ProjectTo(::NamedTuple) #515

Merged
merged 2 commits into from
Dec 5, 2021
Merged

Add ProjectTo(::NamedTuple) #515

merged 2 commits into from
Dec 5, 2021

Conversation

devmotion
Copy link
Member

This PR adds ProjectTo(::NamedTuple) according to the suggestion by @mcabbott (I added you as a co-author to give credit). I only added two additional more descriptive error messages and some tests, similar to the implementation for Tuples.

Fixes #511.

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
@devmotion
Copy link
Member Author

StatsFuns and Diffractor test errors seem unrelated.

@mcabbott
Copy link
Member

Zygote seems to always make the whole tuple, and in the right order:

julia> gradient(x -> x.a /x.b , (a=1, b=2, c=3))
((a = 0.5, b = -0.25, c = nothing),)

But Diffractor seems to just keep the nonzero ones:

julia> g = gradient(x -> x.a /x.b , (a=1, b=2, c=3))
(Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(b = -0.25, a = 0.5),)

julia> g[1].c
ChainRulesCore.ZeroTangent()

julia> ChainRulesCore.canonicalize(g[1])
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(a = 0.5, b = -0.25, c = ZeroTangent())

I suspect this means it can't just call backing and map, but will have to at least accept subsets of the keys. Whether it should also produce them or not I don't know.

@mzgubic
Copy link
Member

mzgubic commented Nov 11, 2021

Yeah, Zygote didn't use to do that and that led to issues, see e.g.
FluxML/Zygote.jl#922 (comment)
which was fixed by FluxML/Zygote.jl#926

Zygote has to do that because it uses NamedTuples, not Tangents (where the elements that are not there are ZeroTangent automatically). Diffractor uses Tangents internally so it should be fine.

I guess we just have to do backing(canonicalize(tangent)). I don't see why we can't use map? That's a NamedTuple projected onto a NamedTuple, nothing to do with Tangent right?

@devmotion
Copy link
Member Author

map requires that the names are identical. It is quite straightforward to write a map alternative that would just project the existing derivatives. Would this be preferred over using canonicalize? And is it OK if the order of the keys is different in the resulting Tangent than in the projector, or should they be reordered?

@mzgubic
Copy link
Member

mzgubic commented Nov 11, 2021

The order does not matter as far as I know.

I think what we want is to allow

julia> nt = (;a=1.0, b=2.0);

julia> project = ProjectTo(nt);

julia> t = Tangent{typeof(nt)}(;a=2.3)

julia> project(t)

but not

julia> project((;a=1.0))

The reason being that it is a feature of Tangents that they are implicitly ZeroTangent for all non explicitly specified fields.
On the other hand, ChainRules doesn't think a NamedTuple is a valid tangent type, so we should probably throw up a fuss if there is anything remotely suspicious going on?

Do we ever see a NamedTuple being directly projected? (As opposed to from it being the backing of a Tangent?). Should we even allow it at all?

@devmotion
Copy link
Member Author

Do we ever see a NamedTuple being directly projected? (As opposed to from it being the backing of a Tangent?). Should we even allow it at all?

This is the case that errors in the CRTestUtils PR: FiniteDifferences returns a NamedTuple and we have to project it to the Tangent representation. Currently this is done with a custom method but it only covers Tuple and NamedTuple but eg not nested Tuples and NamedTuples. @oxinabox suggested to use ProjectTo instead of adding more dispatches to the custom method.

@mcabbott
Copy link
Member

Allowing project((;a=1.0)) seems like a good idea to me. IIRC one reason the equivalent Tuple rule does accepts tuples is that it makes writing rules easier: you don't have to construct the Tangent by hand, and can treat Union{Tuple, Vector} the same way.

Perhaps it should just iterate over the keys of this input? If any of them are not found in the projector, that's an error.

@devmotion
Copy link
Member Author

OK, I'll update the PR this evening (have a draft on my computer but are busy ATM) 🙂

@codecov-commenter
Copy link

codecov-commenter commented Nov 11, 2021

Codecov Report

Merging #515 (d1bb3b2) into main (f9304b3) will increase coverage by 0.17%.
The diff coverage is 95.45%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #515      +/-   ##
==========================================
+ Coverage   92.91%   93.09%   +0.17%     
==========================================
  Files          15       15              
  Lines         819      854      +35     
==========================================
+ Hits          761      795      +34     
- Misses         58       59       +1     
Impacted Files Coverage Δ
src/projection.jl 97.29% <95.45%> (-0.11%) ⬇️
src/rule_definition_tools.jl 96.27% <0.00%> (+0.02%) ⬆️
src/accumulation.jl 97.22% <0.00%> (+0.07%) ⬆️
src/tangent_types/thunks.jl 95.00% <0.00%> (+0.10%) ⬆️
src/tangent_types/tangent.jl 85.50% <0.00%> (+0.32%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f9304b3...d1bb3b2. Read the comment docs.

return project_type(project)(; dy...)
end

# Diffractor returns not necessarily a named tuple with all keys and of the same order as
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does Diffractor have to do with anything, and why does it return a namedtuple?
It should be a Tangent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It refers to #515 (comment). The Tangents are already unpacked at this stage.

# Diffractor returns not necessarily a named tuple with all keys and of the same order as
# the projector
# Thus we can't use `map`
function _project_namedtuple(f::NamedTuple{fn,ft}, x::NamedTuple{xn,xt}) where {fn,ft,xn,xt}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this function for?

Can't we just stick the thing into a Tangent{typeof(f), typeof(x)}(x) ?
which should robustly handly non-present keys and keys in different orders.
And if for some reason we can't handle that then add a canonicalize ?

Copy link
Member Author

@devmotion devmotion Nov 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is our custom projection map. Initially, in the first commit I just used map with the named tuple of projectors and named tuple of derivatives, as suggested by @mcabbott. However, map requires that the names of both named tuples are exactly identical, i.e., all derivatives are present and in the same order as the projectors. This function here just maps all existing derivatives and throws a more descriptive error if a derivative is present without corresponding projector.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it could take short namedtuples and route them through Tangent -> canonise -> backing -> map -> Tangent, to re-use more stuff:

julia> using ChainRulesCore

julia> x = (a=1, b=2, c=3); dx = (b=400,);

julia> Tangent{typeof(x)}(; dx...)
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(b = 400,)

julia> ChainRulesCore.canonicalize(ans)
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(a = ZeroTangent(), b = 400, c = ZeroTangent())

julia> ChainRulesCore.backing(ans)
(a = ZeroTangent(), b = 400, c = ZeroTangent())

My slight reservation about all approaches really is whether we can insert enough complication to confuse Diffractor when it wants to take a 3rd derivative or something.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need the backing step?
why not
Tangent -> map, which already returns a Tangent?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, like this it doesn't:

julia> tang
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(a = ZeroTangent(), b = 400, c = ZeroTangent())

julia> projs = map(ProjectTo, x);

julia> map((f,x) -> f(x), projs, tang)
3-element Vector{Any}:
    ZeroTangent()
 400.0
    ZeroTangent()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't dug through all the functions closely, recently, but my reservation here is that this seems close to being a second use of canonicalize, just with a different carefully optimised generated implementation. It seems that if ever something breaks one, we'll have to fix both.

Is there a precedent anywhere else here about whether filling in all fields with NoTangent is preferable / not compared to leaving omitted ones out?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would preferably not fill in all fields.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was my understanding as well - and hence I don't think one should use canonicalize here since we don't want to fill all fields.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mcabbott Are you OK with merging the PR as is and improving the implementation later, if e.g. there is a clear need for a two argument map?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, do it. I think this is the right behaviour. I do wish it could be shorter but that's not the end of the world. Sorry about dragging this out so long.

@devmotion
Copy link
Member Author

As an explanation for the current state of this PR and how ProjectTo(::NamedTuple) acts:

  • ProjectTo(::NamedTuple) is an analogue of of ProjectTo(::Tuple), just with Tangent{<:NamedTuple} instead of Tangent{<:Tuple} and a named tuple of projectors instead of a tuple of projectors

  • ProjectTo{<:Tangent{<:NamedTuple}} allows to project Tangent{<:NamedTuple} and NamedTuple to the desired space

  • Since it is allowed (it seems?) to use Tangent types in rrules and from AD backends (such as Diffractor, see the example above) that do

    • only contain derivatives of a subset of elements/entries in the primal input (rest assumed to be ZeroTangent)
    • contain derivatives of elements/entries in the primal input in any order

    we do not use map((f, x) -> f(x), backing(projector), dx) in the intermediate projection step (see @mcabbott's initial suggestion) but instead just project every entry in dx (map(f, ::NamedTuple, ::NamedTuple) requires that the names of both named tuples are identical, i.e., no subsets and only in the same order)

The projection step is implemented with a function that contains an if @generated block to ensure that the output can be inferred and it can be computed efficiently. It will display a hopefully helpful error message if any keys are present in dx for which no projector is defined.

A code example:

julia> using ChainRulesCore

julia> x = (a=1.0, b=[1.0, 3.0, 4.0], c=false);

julia> pt1 = ProjectTo(x)
ProjectTo{Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}},T} where T}(a = ProjectTo{Float64}(), b = ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(3),)), c = ProjectTo{NoTangent}())

# subsets are OK
julia> pt1(Tangent{typeof(x)}(; b = [3.0 + 0*im, 2.0, 1.0]))
Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}}}(b = [3.0, 2.0, 1.0],)

julia> pt1((b = [3.0 + 0*im, 2.0, 1.0],))
Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}}}(b = [3.0, 2.0, 1.0],)

# order does not matter
julia> pt1(Tangent{typeof(x)}(; b = [3.0 + 0*im, 2.0, 1.0], a=big(π)))
Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}}}(b = [3.0, 2.0, 1.0], a = 3.141592653589793)

julia> pt1((b = [3.0 + 0*im, 2.0, 1.0], a=big(π)))
Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}}}(b = [3.0, 2.0, 1.0], a = 3.141592653589793)

# error if derivative without projecto
julia> pt1(Tangent{typeof(x)}(; d=42, b = [3.0 + 0*im, 2.0, 1.0]))
ERROR: LoadError: ArgumentError: named tuple with keys(x) == (:a, :b, :c) cannot have a gradient with key d
...

julia> pt1((d=42, b = [3.0 + 0*im, 2.0, 1.0]))
ERROR: LoadError: ArgumentError: named tuple with keys(x) == (:a, :b, :c) cannot have a gradient with key d
...

One could "complete" the Tangent or named tuple with canonicalize, as @mcabbott suggested, since it would allow to project the derivatives more easily (no "missing" derivatives and of the same order as in the primal input). I don't know anything about the design choices behind ProjectTo but to me it seems a bit surprising to use canonicalize here if it is just fine and supported to use Tangents with a subset of keys and different order currently in rrule definitions. As long as the canonical version is not always enforced it seems a bit inconsistent to use it in ProjectTo.

Also I think the design rationale and the implementation of this PR are reasonably simple and hence I don't think canonicalize would reduce code complexity much: _project_namedtuple(projectors::NamedTuple, derivatives::NamedTuple) just projects every derivative in derivatives with the corresponding projector in projectors, or displays an error message if not existent.

@mcabbott mcabbott added the ProjectTo related to the projection functionality label Nov 24, 2021
@devmotion
Copy link
Member Author

What's the status here, should the design be changed that I tried to explain above? The PR still blocks the PR in ChainRulesTestUtils.

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I was indecisive.
I have been wondering about https://github.com/JuliaDiff/ChainRulesCore.jl/pull/515/files#r759728574

but regardless, if we do that or not we should do this.
This PR is ok, we can always improve it later.

@devmotion devmotion merged commit addf6d9 into main Dec 5, 2021
@devmotion devmotion deleted the dw/projectto_namedtuple branch December 5, 2021 23:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ProjectTo related to the projection functionality
Projects
None yet
Development

Successfully merging this pull request may close these issues.

implement ProjectTo(::NamedTuple)
5 participants