-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ProjectTo{<:Tangent}
for tuples & Ref
#488
Changes from all commits
17a941e
8b8314e
1d75133
ecd318f
15756a4
3e5cda1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -272,18 +272,55 @@ end | |
##### | ||
|
||
# Ref | ||
# Note that Ref is mutable. This causes Zygote to represent its structral tangent not as a NamedTuple, | ||
# but as `Ref{Any}((x=val,))`. Here we use a Tangent, there is at present no mutable version, but see | ||
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105 | ||
function ProjectTo(x::Ref) | ||
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)? | ||
if sub isa ProjectTo{<:AbstractZero} | ||
return ProjectTo{Tangent{typeof(x)}}(; x=sub) | ||
end | ||
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx)))) | ||
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref) | ||
dy = project.x(dx[]) | ||
return project_type(project)(; x=dy) | ||
end | ||
# Since this works like a zero-array in broadcasting, it should also accept a number: | ||
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx)) | ||
|
||
# Tuple | ||
function ProjectTo(x::Tuple) | ||
elements = map(ProjectTo, x) | ||
if elements isa NTuple{<:Any,ProjectTo{<:AbstractZero}} | ||
return ProjectTo{NoTangent}() | ||
else | ||
return ProjectTo{Ref}(; type=typeof(x), x=sub) | ||
return ProjectTo{Tangent{typeof(x)}}(; elements=elements) | ||
end | ||
end | ||
(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x)) | ||
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[])) | ||
# Since this works like a zero-array in broadcasting, it should also accept a number: | ||
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx)) | ||
# This method means that projection is re-applied to the contents of a Tangent. | ||
# We're not entirely sure whether this is every necessary; but it should be safe, | ||
# and should often compile away: | ||
(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx)) | ||
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple) | ||
len = length(project.elements) | ||
if length(dx) != len | ||
str = "tuple with length(x) == $len cannot have a gradient with length(dx) == $(length(dx))" | ||
throw(DimensionMismatch(str)) | ||
end | ||
Comment on lines
+305
to
+308
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will this not be caught by If we removed this check then this would basically be the general iterator fallback case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will, but the error is much less friendly... and might be a bug, JuliaLang/julia#42216 |
||
# Here map will fail if the lengths don't match, but gives a much less helpful error: | ||
dy = map((f, x) -> f(x), project.elements, dx) | ||
return project_type(project)(dy...) | ||
end | ||
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray) | ||
for d in 1:ndims(dx) | ||
if size(dx, d) != get(length(project.elements), d, 1) | ||
throw(_projection_mismatch(axes(project.elements), size(dx))) | ||
end | ||
end | ||
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray | ||
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) | ||
return project_type(project)(dz...) | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The above Tuple and AbstractArray cases are just optimizations of a general iterator one: function (project::ProjectTo{<:Tangent{<:Tuple}})(dxs) # iterator fallback
dzs = (f(dx) for (f, dx) in zip(project.elements, dxs))
return project_type(project)(dzs...)
end Should we have that as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could do. Do you have something in mind which might produce some weird type? If some NamedTuple leaks from Zygote, I think this will produce stranger error messages, since it may make a Tangent of the wrong length? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really have it in mind, more is that that is the general case we are handling. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess I'm not entirely sure we need the array version at all. I was thinking about things like broadcasting, although that handles it explicitly... but map doesn't:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would be down with seeing it removed til we know we need it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My vote is to keep this, although I can't think of another example besides |
||
|
||
|
||
##### | ||
##### `LinearAlgebra` | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This means that projection on the output of this projector will disassemble the Tangent and re-process the Tuple inside. I'm not sure that's ideal. Maybe it's safe to pass on all Tangents without further investigation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should constant-fold out for most cases we care about.
Because it will see the ProjectTo for the backing elements as already being the right type, which is known from the type in the
project.elements
Do you want to check some with
@code_typed
/ Cthulu ?It is not safe to pass on all Tangents, because the tangent could be wrapping Complex Number/ Dense array that we need to fix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I wonder is whether I can think of the "mathematical" steps involving arrays etc. as being separate from the "structural" steps involving Tangents. If the first project, and then the backward flow assembles and de-assembles a Tangent, can this Tangent have "crossed a boundary" such that it belongs to a different argument type and hence may need further projection? I mostly think it can't; it would have to get un-packaed and those pieces operated on. But I'm not very sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see.
Yeah, isn't this kind of true for many many operations that project?
They wouldn't need to project if were sure that the "mathematical" step before had projected and so only given them something good?
(also applies for if the mathematical step was a human constructing the tangents)
But it gets fuzzy around the edges?
How would
sum(sum, ((Diagonal[1f0,2f0]), (Diagonal[1f0,2f0])))
go down?