diff --git a/Project.toml b/Project.toml
index c94d6c4dc..bc8781e18 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "ChainRulesCore"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "1.11.1"
+version = "1.11.2"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
diff --git a/src/projection.jl b/src/projection.jl
index 78a9b389b..8a9c8ad0a 100644
--- a/src/projection.jl
+++ b/src/projection.jl
@@ -287,7 +287,7 @@ end
 # Since this works like a zero-array in broadcasting, it should also accept a number:
 (project::ProjectTo{<:Tangent{<:Ref}})(dx::Number) = project(Ref(dx))
 
-# Tuple
+# Tuple and NamedTuple
 function ProjectTo(x::Tuple)
     elements = map(ProjectTo, x)
     if elements isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
@@ -296,10 +296,25 @@ function ProjectTo(x::Tuple)
         return ProjectTo{Tangent{typeof(x)}}(; elements=elements)
     end
 end
+# Build the projector for a NamedTuple from the projectors of its fields. If every field
+# projector is a `ProjectTo{<:AbstractZero}`, the result is `ProjectTo{NoTangent}()`;
+# otherwise the field projectors are passed on as keyword arguments, one per field name.
+function ProjectTo(x::NamedTuple)
+    elements = map(ProjectTo, x)
+    if Tuple(elements) isa NTuple{<:Any,ProjectTo{<:AbstractZero}}
+        return ProjectTo{NoTangent}()
+    else
+        return ProjectTo{Tangent{typeof(x)}}(; elements...)
+    end
+end
+
 # This method means that projection is re-applied to the contents of a Tangent.
 # We're not entirely sure whether this is every necessary; but it should be safe,
 # and should often compile away:
-(project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tangent) = project(backing(dx))
+function (project::ProjectTo{<:Tangent{<:Union{Tuple,NamedTuple}}})(dx::Tangent)
+    return project(backing(dx))
+end
+
 function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
     len = length(project.elements)
     if length(dx) != len
@@ -310,6 +325,50 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::Tuple)
     dy = map((f, x) -> f(x), project.elements, dx)
     return project_type(project)(dy...)
 end
+# Project a NamedTuple gradient entry by entry with the projectors held in
+# `backing(project)`, then rebuild the result with `project_type(project)`.
+function (project::ProjectTo{<:Tangent{<:NamedTuple}})(dx::NamedTuple)
+    dy = _project_namedtuple(backing(project), dx)
+    return project_type(project)(; dy...)
+end
+
+# Apply the projectors in `f` to the entries of `x`, matching them by key.
+# Diffractor does not necessarily return a named tuple with all keys, nor in the same
+# order as the projector, so we can't use `map`. The result has exactly the keys of `x`,
+# in the order of `x`; a key of `x` that `f` lacks raises an `ArgumentError`.
+function _project_namedtuple(f::NamedTuple{fn,ft}, x::NamedTuple{xn,xt}) where {fn,ft,xn,xt}
+    if @generated
+        # Generated branch: build one projector-call expression per key of `x`.
+        vals = Any[
+            if xn[i] in fn
+                :(getfield(f, $(QuoteNode(xn[i])))(getfield(x, $(QuoteNode(xn[i])))))
+            else
+                throw(
+                    ArgumentError(
+                        "named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])",
+                    ),
+                )
+            end for i in 1:length(xn)
+        ]
+        :(NamedTuple{$xn}(($(vals...),)))
+    else
+        # Non-generated branch: the same per-key lookup, written as ordinary code.
+        vals = ntuple(Val(length(xn))) do i
+            name = xn[i]
+            if name in fn
+                getfield(f, name)(getfield(x, name))
+            else
+                throw(
+                    ArgumentError(
+                        "named tuple with keys(x) == $fn cannot have a gradient with key $(xn[i])",
+                    ),
+                )
+            end
+        end
+        NamedTuple{xn}(vals)
+    end
+end
+
 function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
     for d in 1:ndims(dx)
         if size(dx, d) != get(length(project.elements), d, 1)
diff --git a/test/projection.jl b/test/projection.jl
index 592b0a84d..bbd8cd7ff 100644
--- a/test/projection.jl
+++ b/test/projection.jl
@@ -160,6 +160,42 @@ struct NoSuperType end
         @test ProjectTo((true, [false])) isa ProjectTo{NoTangent}
     end
 
+    @testset "Base: NamedTuple" begin
+        pt1 = @inferred(ProjectTo((a=1.0,)))
+        @test @inferred(pt1((a=1 + im,))) ==
+            Tangent{NamedTuple{(:a,),Tuple{Float64}}}(; a=1.0)
+        @test @inferred(pt1(pt1((a=1,)))) == pt1((a=1,)) # re-projecting its own Tangent
+        @test @inferred(pt1(Tangent{Any}(; a=1))) == pt1((a=1,)) # accepts Tangent{Any}
+        @test @inferred(pt1(NoTangent())) === NoTangent()
+        @test @inferred(pt1(ZeroTangent())) === ZeroTangent()
+
+        @test_throws Exception pt1((a=1, b=2)) # no projector for `b`
+        @test_throws Exception pt1((b=1,)) # no projector for `b`
+
+        # subset is allowed (required for Diffractor)
+        @test @inferred(pt1(NamedTuple())) === Tangent{NamedTuple{(:a,),Tuple{Float64}}}()
+
+        pt3 = @inferred(ProjectTo((a=[1, 2, 3], b=false, c=:gamma))) # partly non-differentiable
+        @test @inferred(pt3((a=1:3, b=4, c=5))) ==
+            Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
+                a=[1.0, 2.0, 3.0], b=NoTangent(), c=NoTangent()
+            )
+
+        # different order
+        @test @inferred(pt3((b=4, a=1:3, c=5))) ==
+            Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
+                b=NoTangent(), a=[1.0, 2.0, 3.0], c=NoTangent()
+            )
+
+        # only a subset
+        @test @inferred(pt3((c=5,))) ==
+            Tangent{NamedTuple{(:a, :b, :c),Tuple{Vector{Int},Bool,Symbol}}}(;
+                c=NoTangent()
+            )
+
+        @test @inferred(ProjectTo((a=true, b=[false]))) isa ProjectTo{NoTangent}
+    end
+
     @testset "Base: non-diff" begin
         @test ProjectTo(:a)(1) == NoTangent()
         @test ProjectTo('b')(2) == NoTangent()