-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ProjectTo{<:Tangent}
for tuples & Ref
#488
Conversation
Codecov Report
@@ Coverage Diff @@
## main #488 +/- ##
==========================================
- Coverage 93.00% 92.94% -0.07%
==========================================
Files 15 15
Lines 801 822 +21
==========================================
+ Hits 745 764 +19
- Misses 56 58 +2
Continue to review full report at Codecov.
|
# Since this works like a zero-array in broadcasting, it should also accept a number: | ||
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) | ||
|
||
(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This means that projection on the output of this projector will disassemble the Tangent and re-process the Tuple inside. I'm not sure that's ideal. Maybe it's safe to pass on all Tangents without further investigation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should constant-fold out for most cases we care about.
Because it will see the ProjectTo for the backing elements as already being the right type, which is known from the type in the project.elements
Do you want to check some with @code_typed
/ Cthulu ?
It is not safe to pass on all Tangents, because the tangent could be wrapping Complex Number/ Dense array that we need to fix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I wonder is whether I can think of the "mathematical" steps involving arrays etc. as being separate from the "structural" steps involving Tangents. If the first project, and then the backward flow assembles and de-assembles a Tangent, can this Tangent have "crossed a boundary" such that it belongs to a different argument type and hence may need further projection? I mostly think it can't; it would have to get un-packaed and those pieces operated on. But I'm not very sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see.
Yeah, isn't this kind of true for many many operations that project?
They wouldn't need to project if were sure that the "mathematical" step before had projected and so only given them something good?
(also applies for if the mathematical step was a human constructing the tangents)
But it gets fuzzy around the edges?
How would sum(sum, ((Diagonal[1f0,2f0]), (Diagonal[1f0,2f0])))
go down?
The formatter is unhappy see: https://github.com/JuliaDiff/ChainRulesCore.jl/pull/488/checks?check_run_id=3854350212#step:5:187 I think it wasn't allowed to post suggestions as it was made from a fork. That should be fixed, or the formatter will complain in other PRs. |
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray | ||
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) | ||
return project_type(project)(dz...) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The above Tuple and AbstractArray cases are just optimizations of a general iterator one:
function (project::ProjectTo{<:Tangent{<:Tuple}})(dxs) # iterator fallback
dzs = (f(dx) for (f, dx) in zip(project.elements, dxs))
return project_type(project)(dzs...)
end
Should we have that as well?
And then we can note the others as just being optimizations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could do. Do you have something in mind which might produce some weird type?
If some NamedTuple leaks from Zygote, I think this will produce stranger error messages, since it may make a Tangent of the wrong length?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really have it in mind, more is that that is the general case we are handling.
It is weird that we only actually handle the two optimizable versions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess I'm not entirely sure we need the array version at all. I was thinking about things like broadcasting, although that handles it explicitly... but map doesn't:
julia> Zygote.pullback(x -> sum(map(+, x, [1,2])), (1,2))[2](1)
([1, 1],)
julia> gradient(x -> sum(map(+, x, [1,2])), (1,2)) # uses projection
((1.0, 1.0),)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would be down with seeing it removed til we know we need it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My vote is to keep this, although I can't think of another example besides map
right now which uses this. It looks like gradient(x -> sum([i^2 for i in x]), (1,2))
does not.
if length(dx) != len | ||
str = "tuple with length(x) == $len cannot have a gradient with length(dx) == $(length(dx))" | ||
throw(DimensionMismatch(str)) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will this not be caught by map
?
If we removed this check then this would basically be the general iterator fallback case.
https://github.com/JuliaDiff/ChainRulesCore.jl/pull/488/files#r726169970
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will, but the error is much less friendly... and might be a bug, JuliaLang/julia#42216
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Basically LGTM.
A few things to address.
Once sorted merge when happy
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
Does If not, then it is I think entirely a matter of fixing the Zygote/ChainRules interface. Maybe |
It does not.
Correct.
You might be right, there are not tests of it directly.
I have worked on packages that have direct downstream tests. We should test features in ways that are like how they are intended to be used by downstream packages. |
Ok, thanks! It looks like we need this PR first. Then the fix is one line:
What that one line won't do is allow |
Point taken. It might be worthwhile developing a habit of writing a few Zygote tests while working on something here, to contribute there. Perhaps all gathered in one file, ideally to be coped to Diffractor later? |
@mcabbott Did you on purpose not add any definitions for |
No, I planned to but cut it out of this PR in the end. It ought to exist though, someone just has to write it. |
The simplest version is something like: function ProjectTo(x::NamedTuple)
elements = map(ProjectTo, x)
if Tuple(elements) isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
return ProjectTo{NoTangent}()
else
return ProjectTo{Tangent{typeof(x)}}(; elements...)
end
end
(project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::Tangent) = project(backing(dx))
function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
dy = map((f, x) -> f(x), backing(project), dx)
return project_type(project)(; dy...)
end That demands exact equality of the names, which I think is what you want. |
Replaces #457