Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ProjectTo{<:Tangent} for tuples & Ref #488

Merged
merged 6 commits into from
Oct 15, 2021
Merged

Conversation

mcabbott
Copy link
Member

Replaces #457

@codecov-commenter
Copy link

codecov-commenter commented Oct 11, 2021

Codecov Report

Merging #488 (3e5cda1) into main (d0c3599) will decrease coverage by 0.06%.
The diff coverage is 92.30%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #488      +/-   ##
==========================================
- Coverage   93.00%   92.94%   -0.07%     
==========================================
  Files          15       15              
  Lines         801      822      +21     
==========================================
+ Hits          745      764      +19     
- Misses         56       58       +2     
Impacted Files Coverage Δ
src/projection.jl 97.47% <92.30%> (-0.68%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d0c3599...3e5cda1. Read the comment docs.

# Since this works like a zero-array in broadcasting, it should also accept a number:
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx))

(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that projection on the output of this projector will disassemble the Tangent and re-process the Tuple inside. I'm not sure that's ideal. Maybe it's safe to pass on all Tangents without further investigation?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should constant-fold out for most cases we care about.
Because it will see the ProjectTo for the backing elements as already being the right type, which is known from the type in the project.elements

Do you want to check some with @code_typed / Cthulu ?

It is not safe to pass on all Tangents, because the tangent could be wrapping Complex Number/ Dense array that we need to fix.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I wonder is whether I can think of the "mathematical" steps involving arrays etc. as being separate from the "structural" steps involving Tangents. If the first project, and then the backward flow assembles and de-assembles a Tangent, can this Tangent have "crossed a boundary" such that it belongs to a different argument type and hence may need further projection? I mostly think it can't; it would have to get un-packaed and those pieces operated on. But I'm not very sure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.
Yeah, isn't this kind of true for many many operations that project?
They wouldn't need to project if were sure that the "mathematical" step before had projected and so only given them something good?
(also applies for if the mathematical step was a human constructing the tangents)

But it gets fuzzy around the edges?
How would sum(sum, ((Diagonal[1f0,2f0]), (Diagonal[1f0,2f0]))) go down?

@oxinabox
Copy link
Member

The formatter is unhappy see: https://github.com/JuliaDiff/ChainRulesCore.jl/pull/488/checks?check_run_id=3854350212#step:5:187

I think it wasn't allowed to post suggestions as it was made from a fork.
(See #489)

That should be fixed, or the formatter will complain in other PRs.
(and also it is i guess right)

src/projection.jl Outdated Show resolved Hide resolved
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
return project_type(project)(dz...)
end
Copy link
Member

@oxinabox oxinabox Oct 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above Tuple and AbstractArray cases are just optimizations of a general iterator one:

function (project::ProjectTo{<:Tangent{<:Tuple}})(dxs)  # iterator fallback
    dzs = (f(dx) for (f, dx) in zip(project.elements, dxs))
    return project_type(project)(dzs...)
end

Should we have that as well?
And then we can note the others as just being optimizations?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could do. Do you have something in mind which might produce some weird type?

If some NamedTuple leaks from Zygote, I think this will produce stranger error messages, since it may make a Tangent of the wrong length?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really have it in mind, more is that that is the general case we are handling.
It is weird that we only actually handle the two optimizable versions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm not entirely sure we need the array version at all. I was thinking about things like broadcasting, although that handles it explicitly... but map doesn't:

julia> Zygote.pullback(x -> sum(map(+, x, [1,2])), (1,2))[2](1)
([1, 1],)

julia> gradient(x -> sum(map(+, x, [1,2])), (1,2))  # uses projection
((1.0, 1.0),)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be down with seeing it removed til we know we need it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My vote is to keep this, although I can't think of another example besides map right now which uses this. It looks like gradient(x -> sum([i^2 for i in x]), (1,2)) does not.

src/projection.jl Outdated Show resolved Hide resolved
Comment on lines +301 to +304
if length(dx) != len
str = "tuple with length(x) == $len cannot have a gradient with length(dx) == $(length(dx))"
throw(DimensionMismatch(str))
end
Copy link
Member

@oxinabox oxinabox Oct 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this not be caught by map?

If we removed this check then this would basically be the general iterator fallback case.
https://github.com/JuliaDiff/ChainRulesCore.jl/pull/488/files#r726169970

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will, but the error is much less friendly... and might be a bug, JuliaLang/julia#42216

src/projection.jl Outdated Show resolved Hide resolved
src/projection.jl Outdated Show resolved Hide resolved
Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically LGTM.
A few things to address.
Once sorted merge when happy

mcabbott and others added 3 commits October 11, 2021 10:38
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
@mcabbott
Copy link
Member Author

Ref is I think the only mutable type for which we have defined projection rules. I think the lesson of JuliaDiff/ChainRules.jl#539 (comment) is that Zygote's structural tangents are weird for such things.

Does Tangent make any such distinction, which needs handling here?

If not, then it is I think entirely a matter of fixing the Zygote/ChainRules interface. Maybe wrap_chainrules_input should handle this; it had a method for NamedTuples but not for Ref. There appear to be zero tests. In fact I'm pretty surprised I didn't stumble across this while fiddling with projection rules. Makes me wonder a bit if we ought to have more explicit downstream testing, i.e. don't just trust the package's tests, but add tests here (or in ChainRules) using packages explicitly, while thinking about any given feature.

@oxinabox
Copy link
Member

Does Tangent make any such distinction, which needs handling here?

It does not.
There is an open issue for it, as it can come up in forward-mode.
#105

If not, then it is I think entirely a matter of fixing the Zygote/ChainRules interface. Maybe wrap_chainrules_input should handle this; it had a method for NamedTuples but not for Ref.

Correct.
When I wrote that I did not know Zygote did this.

There appear to be zero tests. In fact I'm pretty surprised I didn't stumble across this while fiddling with projection rules.

You might be right, there are not tests of it directly.
I think in most cases though there are tests that hit it implictly.
But not for the mutable struct case I guess.
(The direct tests are mostly focused around making sure that the rules are picked up, and around things that were wrong before like functions that return tuples and functions that take keyword arguments)

Makes me wonder a bit if we ought to have more explicit downstream testing, i.e. don't just trust the package's tests, but add tests here (or in ChainRules) using packages explicitly, while thinking about any given feature.

I have worked on packages that have direct downstream tests.
It is basically hell, due to the near circular dependency.
It causes all kinds of problems with version bounds and things get fragile and start to display "intimate knowledge" of each other.
Much better is the way we are doing it here, which includes the downstream tests automatically passing if compat is not permitted.
We can just fix the tests in Zygote etc.

We should test features in ways that are like how they are intended to be used by downstream packages.
But we should not have tests that can't be fixed by changing this package.

@mcabbott
Copy link
Member Author

Ok, thanks!

It looks like we need this PR first. Then the fix is one line:

julia> gradient(x -> (x.x)^2, Ref(3))
ERROR: MethodError: no method matching (::ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::NamedTuple{(:x,), Tuple{Int64}})
  ...
Stacktrace:
 [1] (::ProjectTo{Tangent{Base.RefValue{Int64}}, NamedTuple{(:x,), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}})(dx::Base.RefValue{Any})
   @ ChainRulesCore ~/.julia/dev/ChainRulesCore/src/projection.jl:282
 [2] _project
...

julia> @eval Zygote wrap_chainrules_input(x::Ref) = wrap_chainrules_input(x[]);

julia> gradient(x -> (x.x)^2, Ref(3))
((x = 6.0,),)

What that one line won't do is allow wrap_chainrules_output to restore this Ref(NamedTuple) thing. I don't see a way to do that without #105. Maybe I need to think up a way to test if that matters to Zygote.

@mcabbott
Copy link
Member Author

hell, due to the near circular dependency.

Point taken. It might be worthwhile developing a habit of writing a few Zygote tests while working on something here, to contribute there. Perhaps all gathered in one file, ideally to be coped to Diffractor later?

@devmotion
Copy link
Member

@mcabbott Did you on purpose not add any definitions for ProjectTo(x::NamedTuple)? They would be needed in JuliaDiff/ChainRulesTestUtils.jl#224, I tried to explain it a bit in JuliaDiff/ChainRulesTestUtils.jl#224 (comment).

@mcabbott
Copy link
Member Author

mcabbott commented Nov 9, 2021

No, I planned to but cut it out of this PR in the end. It ought to exist though, someone just has to write it.

@mcabbott
Copy link
Member Author

mcabbott commented Nov 9, 2021

The simplest version is something like:

function ProjectTo(x::NamedTuple)
    elements = map(ProjectTo, x)
    if Tuple(elements) isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
        return ProjectTo{NoTangent}()
    else
        return ProjectTo{Tangent{typeof(x)}}(; elements...)
    end
end
(project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::Tangent) = project(backing(dx))
function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
    dy = map((f, x) -> f(x), backing(project), dx)
    return project_type(project)(; dy...)
end

That demands exact equality of the names, which I think is what you want.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants