Skip to content

Commit b548b4e

Browse files
committed
accumulate NamedTuple + Tangent
1 parent 9a8a788 commit b548b4e

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

src/runtime.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,17 @@ end
1313
@Base.constprop :aggressive accum(a::NoTangent, b) = b
1414
@Base.constprop :aggressive accum(a, b::NoTangent) = a
1515
@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent()
16+
17+
using ChainRulesCore: Tangent, backing
18+
19+
function accum(x::Tangent{T}, y::NamedTuple) where T
20+
# @warn "gradient is both a Tangent and a NamedTuple" x y
21+
z = accum(backing(x), y)
22+
Tangent{T,typeof(z)}(z)
23+
end
24+
accum(x::NamedTuple, y::Tangent) = accum(y, x)
25+
26+
function accum(x::Tangent{T}, y::Tangent) where T
27+
z = accum(backing(x), backing(y))
28+
Tangent{T,typeof(z)}(z)
29+
end

test/runtests.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,7 @@ end
162162
# Make sure that there's no infinite recursion in kwarg calls
163163
g_kw(;x=1.0) = sin(x)
164164
f_kw(x) = g_kw(;x)
165-
@test bwd(f_kw)(1.0) == bwd(sin)(1.0) broken=true
166-
#=
167-
MethodError: no method matching +(::Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}, ::Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}})
168-
...
169-
[2] elementwise_add(a::NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}, b::NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}})
170-
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/tangent.jl:287
171-
[3] +(a::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}}, b::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}}})
172-
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_arithmetic.jl:130
173-
=#
165+
@test bwd(f_kw)(1.0) == bwd(sin)(1.0)
174166

175167
function f_crit_edge(a, b, c, x)
176168
# A function with two critical edges. This used to trigger an issue where

0 commit comments

Comments
 (0)