@@ -5,25 +5,25 @@ struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}}
5
5
@Base . constprop :aggressive accum (a:: Tuple , b:: Tuple ) = map (accum, a, b)
6
6
@Base . constprop :aggressive @generated function accum (x:: NamedTuple , y:: NamedTuple )
7
7
fnames = union (fieldnames (x), fieldnames (y))
8
+ isempty (fnames) && return :((;)) # code below makes () instead
8
9
gradx (f) = f in fieldnames (x) ? :(getfield (x, $ (quot (f)))) : :(ZeroTangent ())
9
10
grady (f) = f in fieldnames (y) ? :(getfield (y, $ (quot (f)))) : :(ZeroTangent ())
10
11
Expr (:tuple , [:($ f= accum ($ (gradx (f)), $ (grady (f)))) for f in fnames]. .. )
11
12
end
12
13
@Base . constprop :aggressive accum (a, b, c, args... ) = accum (accum (a, b), c, args... )
13
- @Base . constprop :aggressive accum (a:: NoTangent , b) = b
14
- @Base . constprop :aggressive accum (a, b:: NoTangent ) = a
15
- @Base . constprop :aggressive accum (a:: NoTangent , b:: NoTangent ) = NoTangent ()
14
+ @Base . constprop :aggressive accum (a:: AbstractZero , b) = b
15
+ @Base . constprop :aggressive accum (a, b:: AbstractZero ) = a
16
+ @Base . constprop :aggressive accum (a:: AbstractZero , b:: AbstractZero ) = NoTangent ()
16
17
17
18
using ChainRulesCore: Tangent, backing
18
19
19
20
function accum (x:: Tangent{T} , y:: NamedTuple ) where T
20
21
# @warn "gradient is both a Tangent and a NamedTuple" x y
21
- z = accum (backing (x), y)
22
- Tangent {T,typeof(z)} (z)
22
+ _tangent (T, accum (backing (x), y))
23
23
end
24
24
accum (x:: NamedTuple , y:: Tangent ) = accum (y, x)
25
+ # This solves an ambiguity, but also avoids Tangent{ZeroTangent}() which + does not:
26
+ accum (x:: Tangent{T} , y:: Tangent ) where T = _tangent (T, accum (backing (x), backing (y)))
25
27
26
- function accum (x:: Tangent{T} , y:: Tangent ) where T
27
- z = accum (backing (x), backing (y))
28
- Tangent {T,typeof(z)} (z)
29
- end
28
+ _tangent (:: Type{T} , z) where T = Tangent {T,typeof(z)} (z)
29
+ _tangent (:: Type , :: NamedTuple{()} ) = NoTangent ()
0 commit comments