Add ProjectTo(::NamedTuple)
#515
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
StatsFuns and Diffractor test errors seem unrelated.
Zygote seems to always make the whole tuple, and in the right order:
But Diffractor seems to just keep the nonzero ones:
I suspect this means it can't just call […]
Yeah, Zygote didn't use to do that and that led to issues, see e.g. […] Zygote has to do that because it uses […] I guess we just have to do […]
The order does not matter as far as I know. I think what we want is to allow
julia> nt = (;a=1.0, b=2.0);
julia> project = ProjectTo(nt);
julia> t = Tangent{typeof(nt)}(;a=2.3)
julia> project(t)
but not
julia> project((;a=1.0))
The reason being that it is a feature of […]
Do we ever see a NamedTuple being directly projected (as opposed to it being the backing of a Tangent)? Should we even allow it at all?
This is the case that errors in the CRTestUtils PR: FiniteDifferences returns a […]
Allowing […] Perhaps it should just iterate over the keys of this input? If any of them are not found in the projector, that's an error.
OK, I'll update the PR this evening (I have a draft on my computer but am busy ATM) 🙂
Codecov Report
Coverage diff, main vs. #515:
Coverage: 92.91% -> 93.09% (+0.17%)
Files: 15 -> 15
Lines: 819 -> 854 (+35)
Hits: 761 -> 795 (+34)
Misses: 58 -> 59 (+1)
    return project_type(project)(; dy...)
end

# Diffractor returns not necessarily a named tuple with all keys and of the same order as
What does Diffractor have to do with anything, and why does it return a namedtuple? It should be a `Tangent`.
It refers to #515 (comment). The `Tangent`s are already unpacked at this stage.
# Diffractor returns not necessarily a named tuple with all keys and of the same order as
# the projector
# Thus we can't use `map`
function _project_namedtuple(f::NamedTuple{fn,ft}, x::NamedTuple{xn,xt}) where {fn,ft,xn,xt}
What is this function for?
Can't we just stick the thing into a `Tangent{typeof(f), typeof(x)}(x)`, which should robustly handle non-present keys and keys in different orders?
And if for some reason we can't handle that, then add a `canonicalize`?
It is our custom projection `map`. Initially, in the first commit, I just used `map` with the named tuple of projectors and the named tuple of derivatives, as suggested by @mcabbott. However, `map` requires that the names of both named tuples are exactly identical, i.e., all derivatives are present and in the same order as the projectors. This function here just maps all existing derivatives and throws a more descriptive error if a derivative is present without a corresponding projector.
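To make the difference concrete, here is a minimal sketch in plain Base Julia. It is not the PR's `_project_namedtuple`: the name `project_present_keys` and the ordinary functions standing in for projectors are hypothetical and only illustrate the idea of iterating over the keys of the input instead of calling `map` on both named tuples.

```julia
# Hypothetical stand-ins for projectors: ordinary functions keyed by field name.
projectors = (a = Float64, b = (x -> Float64.(real.(x))), c = identity)

# A derivative that only contains a subset of the keys.
dy = (b = [3.0 + 0im, 2.0, 1.0],)

# `map` over two named tuples requires identical names in identical order,
# so it throws for a partial named tuple like `dy`.
map_failed = try
    map((f, x) -> f(x), projectors, dy)
    false
catch err
    err isa ArgumentError
end

# Sketch (assumption, not the PR's code): apply projectors only to the keys
# that are present in `dy`, keeping their order, and reject unknown keys.
function project_present_keys(fs::NamedTuple, dy::NamedTuple)
    for k in keys(dy)
        haskey(fs, k) || throw(ArgumentError(
            "named tuple with keys(x) == $(keys(fs)) cannot have a gradient with key $k"))
    end
    return NamedTuple{keys(dy)}(map(k -> fs[k](dy[k]), keys(dy)))
end

partial = project_present_keys(projectors, dy)                      # subset of keys
reordered = project_present_keys(projectors, (b = dy.b, a = big(π)))  # different order
```

Keys that are absent from the input are simply left out of the result rather than being filled in, which matches the preference stated later in this thread.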
I guess it could take short namedtuples and route them through Tangent -> canonicalize -> backing -> map -> Tangent, to re-use more stuff:
julia> using ChainRulesCore
julia> x = (a=1, b=2, c=3); dx = (b=400,);
julia> Tangent{typeof(x)}(; dx...)
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(b = 400,)
julia> ChainRulesCore.canonicalize(ans)
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(a = ZeroTangent(), b = 400, c = ZeroTangent())
julia> ChainRulesCore.backing(ans)
(a = ZeroTangent(), b = 400, c = ZeroTangent())
My slight reservation about all approaches really is whether we can insert enough complication to confuse Diffractor when it wants to take a 3rd derivative or something.
Why do we need the `backing` step?
Why not Tangent -> map, which already returns a `Tangent`?
Well, like this it doesn't:
julia> tang
Tangent{NamedTuple{(:a, :b, :c), Tuple{Int64, Int64, Int64}}}(a = ZeroTangent(), b = 400, c = ZeroTangent())
julia> projs = map(ProjectTo, x);
julia> map((f,x) -> f(x), projs, tang)
3-element Vector{Any}:
  ZeroTangent()
 400.0
  ZeroTangent()
I haven't dug through all the functions closely recently, but my reservation here is that this seems close to being a second use of `canonicalize`, just with a different, carefully optimised generated implementation. It seems that if ever something breaks one, we'll have to fix both.
Is there a precedent anywhere else here about whether filling in all fields with NoTangent is preferable or not, compared to leaving omitted ones out?
We would preferably not fill in all fields.
This was my understanding as well, and hence I don't think one should use `canonicalize` here, since we don't want to fill all fields.
@mcabbott Are you OK with merging the PR as is and improving the implementation later, if e.g. there is a clear need for a two-argument `map`?
Yes, do it. I think this is the right behaviour. I do wish it could be shorter but that's not the end of the world. Sorry about dragging this out so long.
As an explanation for the current state of this PR and how […]
The projection step is implemented with a function that contains an […]
A code example:
julia> using ChainRulesCore
julia> x = (a=1.0, b=[1.0, 3.0, 4.0], c=false);
julia> pt1 = ProjectTo(x)
ProjectTo{Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}},T} where T}(a = ProjectTo{Float64}(), b = ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(3),)), c = ProjectTo{NoTangent}())
# subsets are OK
julia> pt1(Tangent{typeof(x)}(; b = [3.0 + 0*im, 2.0, 1.0]))
Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}}}(b = [3.0, 2.0, 1.0],)
julia> pt1((b = [3.0 + 0*im, 2.0, 1.0],))
Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}}}(b = [3.0, 2.0, 1.0],)
# order does not matter
julia> pt1(Tangent{typeof(x)}(; b = [3.0 + 0*im, 2.0, 1.0], a=big(π)))
Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}}}(b = [3.0, 2.0, 1.0], a = 3.141592653589793)
julia> pt1((b = [3.0 + 0*im, 2.0, 1.0], a=big(π)))
Tangent{NamedTuple{(:a, :b, :c),Tuple{Float64,Array{Float64,1},Bool}}}(b = [3.0, 2.0, 1.0], a = 3.141592653589793)
# error if derivative without projector
julia> pt1(Tangent{typeof(x)}(; d=42, b = [3.0 + 0*im, 2.0, 1.0]))
ERROR: LoadError: ArgumentError: named tuple with keys(x) == (:a, :b, :c) cannot have a gradient with key d
...
julia> pt1((d=42, b = [3.0 + 0*im, 2.0, 1.0]))
ERROR: LoadError: ArgumentError: named tuple with keys(x) == (:a, :b, :c) cannot have a gradient with key d
...
One could "complete" the […]
Also, I think the design rationale and the implementation of this PR are reasonably simple, and hence I don't think […]
What's the status here? Should the design that I tried to explain above be changed? This PR still blocks the PR in ChainRulesTestUtils.
Sorry, I was indecisive.
I have been wondering about https://github.com/JuliaDiff/ChainRulesCore.jl/pull/515/files#r759728574
but regardless of whether we do that or not, we should do this.
This PR is OK; we can always improve it later.
This PR adds `ProjectTo(::NamedTuple)` according to the suggestion by @mcabbott (I added you as a co-author to give credit). I only added two additional, more descriptive error messages and some tests, similar to the implementation for `Tuple`s.
Fixes #511.