Skip to content

Commit

Permalink
Add ProjectTo(::NamedTuple) (#515)
Browse files Browse the repository at this point in the history
* Add `ProjectTo(::NamedTuple)`

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

* Allow different order and subset of named tuples

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
  • Loading branch information
devmotion and mcabbott authored Dec 5, 2021
1 parent 2dcd44b commit addf6d9
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.11.1"
version = "1.11.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
55 changes: 53 additions & 2 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ end
# Since this works like a zero-array in broadcasting, it should also accept a number:
(project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx))

# Tuple
# Tuple and NamedTuple
function ProjectTo(x::Tuple)
elements = map(ProjectTo, x)
if elements isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
Expand All @@ -296,10 +296,22 @@ function ProjectTo(x::Tuple)
return ProjectTo{Tangent{typeof(x)}}(; elements=elements)
end
end
function ProjectTo(x::NamedTuple)
elements = map(ProjectTo, x)
if Tuple(elements) isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
return ProjectTo{NoTangent}()
else
return ProjectTo{Tangent{typeof(x)}}(; elements...)
end
end

# This method means that projection is re-applied to the contents of a Tangent.
# We're not entirely sure whether this is every necessary; but it should be safe,
# and should often compile away:
(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx))
function (project::ProjectTo{<:Tangent{<:Union{Tuple,NamedTuple}}})(dx::Tangent)
return project(backing(dx))
end

function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
len = length(project.elements)
if length(dx) != len
Expand All @@ -310,6 +322,45 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
dy = map((f, x) -> f(x), project.elements, dx)
return project_type(project)(dy...)
end
function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
dy = _project_namedtuple(backing(project), dx)
return project_type(project)(; dy...)
end

# Diffractor returns not necessarily a named tuple with all keys and of the same order as
# the projector
# Thus we can't use `map`
function _project_namedtuple(f::NamedTuple{fn,ft}, x::NamedTuple{xn,xt}) where {fn,ft,xn,xt}
if @generated
vals = Any[
if xn[i] in fn
:(getfield(f, $(QuoteNode(xn[i])))(getfield(x, $(QuoteNode(xn[i])))))
else
throw(
ArgumentError(
"named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])",
),
)
end for i in 1:length(xn)
]
:(NamedTuple{$xn}(($(vals...),)))
else
vals = ntuple(Val(length(xn))) do i
name = xn[i]
if name in fn
getfield(f, name)(getfield(x, name))
else
throw(
ArgumentError(
"named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])",
),
)
end
end
NamedTuple{xn}(vals)
end
end

function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
for d in 1:ndims(dx)
if size(dx, d) != get(length(project.elements), d, 1)
Expand Down
36 changes: 36 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,42 @@ struct NoSuperType end
@test ProjectTo((true, [false])) isa ProjectTo{NoTangent}
end

@testset "Base: NamedTuple" begin
pt1 = @inferred(ProjectTo((a=1.0,)))
@test @inferred(pt1((a=1 + im,))) ==
Tangent{NamedTuple{(:a,),Tuple{Float64}}}(; a=1.0)
@test @inferred(pt1(pt1((a=1,)))) == @inferred(pt1(pt1((a=1,)))) # accepts correct Tangent
@test @inferred(pt1(Tangent{Any}(; a=1))) == pt1((a=1,)) # accepts Tangent{Any}
@test @inferred(pt1(NoTangent())) === NoTangent()
@test @inferred(pt1(ZeroTangent())) === ZeroTangent()

@test_throws Exception pt1((a=1, b=2)) # no projector for `b`
@test_throws Exception pt1((b=1,)) # no projector for `b`

# subset is allowed (required for Diffractor)
@test @inferred(pt1(NamedTuple())) === Tangent{NamedTuple{(:a,),Tuple{Float64}}}()

pt3 = @inferred(ProjectTo((a=[1, 2, 3], b=false, c=:gamma))) # partly non-differentiable
@test @inferred(pt3((a=1:3, b=4, c=5))) ==
Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
a=[1.0, 2.0, 3.0], b=NoTangent(), c=NoTangent()
)

# different order
@test @inferred(pt3((b=4, a=1:3, c=5))) ==
Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
b=NoTangent(), a=[1.0, 2.0, 3.0], c=NoTangent()
)

# only a subset
@test @inferred(pt3((c=5,))) ==
Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
c=NoTangent()
)

@test @inferred(ProjectTo((a=true, b=[false]))) isa ProjectTo{NoTangent}
end

@testset "Base: non-diff" begin
@test ProjectTo(:a)(1) == NoTangent()
@test ProjectTo('b')(2) == NoTangent()
Expand Down

2 comments on commit addf6d9

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/49998

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.11.2 -m "<description of version>" addf6d97201f0130e895c944dd8091e66a8b9477
git push origin v1.11.2

Please sign in to comment.