Skip to content

Commit fc55c22

Browse files
committed
fixup
1 parent b548b4e commit fc55c22

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/runtime.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,25 @@ struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}}
55
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)
66
@Base.constprop :aggressive @generated function accum(x::NamedTuple, y::NamedTuple)
77
fnames = union(fieldnames(x), fieldnames(y))
8+
isempty(fnames) && return :((;)) # code below makes () instead
89
gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent())
910
grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent())
1011
Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...)
1112
end
1213
@Base.constprop :aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...)
13-
@Base.constprop :aggressive accum(a::NoTangent, b) = b
14-
@Base.constprop :aggressive accum(a, b::NoTangent) = a
15-
@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent()
14+
@Base.constprop :aggressive accum(a::AbstractZero, b) = b
15+
@Base.constprop :aggressive accum(a, b::AbstractZero) = a
16+
@Base.constprop :aggressive accum(a::AbstractZero, b::AbstractZero) = NoTangent()
1617

1718
using ChainRulesCore: Tangent, backing
1819

1920
function accum(x::Tangent{T}, y::NamedTuple) where T
2021
# @warn "gradient is both a Tangent and a NamedTuple" x y
21-
z = accum(backing(x), y)
22-
Tangent{T,typeof(z)}(z)
22+
_tangent(T, accum(backing(x), y))
2323
end
2424
accum(x::NamedTuple, y::Tangent) = accum(y, x)
25+
# This solves an ambiguity, but also avoids Tangent{ZeroTangent}() which + does not:
26+
accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing(y)))
2527

26-
function accum(x::Tangent{T}, y::Tangent) where T
27-
z = accum(backing(x), backing(y))
28-
Tangent{T,typeof(z)}(z)
29-
end
28+
_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
29+
_tangent(::Type, ::NamedTuple{()}) = NoTangent()

0 commit comments

Comments
 (0)