Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ProjectTo{<:Tangent} for tuples & Ref #488

Merged
merged 6 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.8.0"
version = "1.9.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
49 changes: 43 additions & 6 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,18 +272,55 @@ end
#####

# Ref
# Note that Ref is mutable. This causes Zygote to represent its structral tangent not as a NamedTuple,
# but as `Ref{Any}((x=val,))`. Here we use a Tangent, there is at present no mutable version, but see
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105
function ProjectTo(x::Ref)
sub = ProjectTo(x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
if sub isa ProjectTo{<:AbstractZero}
return ProjectTo{Tangent{typeof(x)}}(; x=sub)
end
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Tangent) = project(Ref(first(backing(dx))))
function (project::ProjectTo{<:Tangent{<:Ref}})(dx::Ref)
dy = project.x(dx[])
return project_type(project)(; x=dy)
end
# Since this works like a zero-array in broadcasting, it should also accept a number:
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx))

# Tuple
function ProjectTo(x::Tuple)
elements = map(ProjectTo, x)
if elements isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
return ProjectTo{NoTangent}()
else
return ProjectTo{Ref}(; type=typeof(x), x=sub)
return ProjectTo{Tangent{typeof(x)}}(; elements=elements)
end
end
(project::ProjectTo{Ref})(dx::Tangent{<:Ref}) = Tangent{project.type}(; x=project.x(dx.x))
(project::ProjectTo{Ref})(dx::Ref) = Tangent{project.type}(; x=project.x(dx[]))
# Since this works like a zero-array in broadcasting, it should also accept a number:
(project::ProjectTo{Ref})(dx::Number) = Tangent{project.type}(; x=project.x(dx))
# This method means that projection is re-applied to the contents of a Tangent.
# We're not entirely sure whether this is every necessary; but it should be safe,
# and should often compile away:
(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means that projection on the output of this projector will disassemble the Tangent and re-process the Tuple inside. I'm not sure that's ideal. Maybe it's safe to pass on all Tangents without further investigation?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should constant-fold out for most cases we care about.
Because it will see the ProjectTo for the backing elements as already being the right type, which is known from the type in the project.elements

Do you want to check some with @code_typed / Cthulu ?

It is not safe to pass on all Tangents, because the tangent could be wrapping Complex Number/ Dense array that we need to fix.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I wonder is whether I can think of the "mathematical" steps involving arrays etc. as being separate from the "structural" steps involving Tangents. If the first project, and then the backward flow assembles and de-assembles a Tangent, can this Tangent have "crossed a boundary" such that it belongs to a different argument type and hence may need further projection? I mostly think it can't; it would have to get un-packaed and those pieces operated on. But I'm not very sure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.
Yeah, isn't this kind of true for many many operations that project?
They wouldn't need to project if were sure that the "mathematical" step before had projected and so only given them something good?
(also applies for if the mathematical step was a human constructing the tangents)

But it gets fuzzy around the edges?
How would sum(sum, ((Diagonal[1f0,2f0]), (Diagonal[1f0,2f0]))) go down?

function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
len = length(project.elements)
if length(dx) != len
str = "tuple with length(x) == $len cannot have a gradient with length(dx) == $(length(dx))"
throw(DimensionMismatch(str))
end
Comment on lines +305 to +308
Copy link
Member

@oxinabox oxinabox Oct 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this not be caught by map?

If we removed this check then this would basically be the general iterator fallback case.
https://github.com/JuliaDiff/ChainRulesCore.jl/pull/488/files#r726169970

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will, but the error is much less friendly... and might be a bug, JuliaLang/julia#42216

# Here map will fail if the lengths don't match, but gives a much less helpful error:
dy = map((f, x) -> f(x), project.elements, dx)
return project_type(project)(dy...)
end
function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
for d in 1:ndims(dx)
if size(dx, d) != get(length(project.elements), d, 1)
throw(_projection_mismatch(axes(project.elements), size(dx)))
end
end
dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray
dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements))
return project_type(project)(dz...)
end
Copy link
Member

@oxinabox oxinabox Oct 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above Tuple and AbstractArray cases are just optimizations of a general iterator one:

function (project::ProjectTo{<:Tangent{<:Tuple}})(dxs)  # iterator fallback
    dzs = (f(dx) for (f, dx) in zip(project.elements, dxs))
    return project_type(project)(dzs...)
end

Should we have that as well?
And then we can note the others as just being optimizations?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could do. Do you have something in mind which might produce some weird type?

If some NamedTuple leaks from Zygote, I think this will produce stranger error messages, since it may make a Tangent of the wrong length?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really have it in mind, more is that that is the general case we are handling.
It is weird that we only actually handle the two optimizable versions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm not entirely sure we need the array version at all. I was thinking about things like broadcasting, although that handles it explicitly... but map doesn't:

julia> Zygote.pullback(x -> sum(map(+, x, [1,2])), (1,2))[2](1)
([1, 1],)

julia> gradient(x -> sum(map(+, x, [1,2])), (1,2))  # uses projection
((1.0, 1.0),)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be down with seeing it removed til we know we need it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My vote is to keep this, although I can't think of another example besides map right now which uses this. It looks like gradient(x -> sum([i^2 for i in x]), (1,2)) does not.



#####
##### `LinearAlgebra`
Expand Down
19 changes: 18 additions & 1 deletion test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,29 @@ struct NoSuperType end
prefvec = ProjectTo(Ref([1, 2, 3 + 4im])) # recurses into contents
@test prefvec(Ref(1:3)).x isa Vector{ComplexF64}
@test prefvec(Tangent{Base.RefValue}(; x=1:3)).x isa Vector{ComplexF64}
@test_skip @test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5))
@test_throws DimensionMismatch prefvec(Tangent{Base.RefValue}(; x=1:5))

@test ProjectTo(Ref(true)) isa ProjectTo{NoTangent}
@test ProjectTo(Ref([false]')) isa ProjectTo{NoTangent}
end

@testset "Base: Tuple" begin
pt1 = ProjectTo((1.0,))
@test pt1((1 + im,)) == Tangent{Tuple{Float64}}(1.0,)
@test pt1(pt1((1,))) == pt1(pt1((1,))) # accepts correct Tangent
@test pt1(Tangent{Any}(1)) == pt1((1,)) # accepts Tangent{Any}
@test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector
@test pt1(NoTangent()) === NoTangent()
@test pt1(ZeroTangent()) === ZeroTangent()

@test_throws Exception pt1([1, 2]) # DimensionMismatch, wrong length
@test_throws Exception pt1([])

pt3 = ProjectTo(([1, 2, 3], false, :gamma)) # partly non-differentiable
@test pt3((1:3, 4, 5)) == Tangent{Tuple{Vector{Int}, Bool, Symbol}}([1.0, 2.0, 3.0], NoTangent(), NoTangent())
@test ProjectTo((true, [false])) isa ProjectTo{NoTangent}
end

@testset "Base: non-diff" begin
@test ProjectTo(:a)(1) == NoTangent()
@test ProjectTo('b')(2) == NoTangent()
Expand Down