diff --git a/src/forward/lib.jl b/src/forward/lib.jl index 1ab7682f1..b297dab41 100644 --- a/src/forward/lib.jl +++ b/src/forward/lib.jl @@ -64,9 +64,10 @@ end using ..Zygote: literal_getproperty, literal_getfield, literal_getindex -_pushforward(dargs, ::typeof(literal_getproperty), x::NamedTuple, ::Val{f}) where {f} = - _pushforward(dargs, literal_getfield, x, Val(f)) - +function _pushforward(dargs, ::typeof(literal_getproperty), x::NamedTuple, + ::Val{property_name}) where {property_name} + return _pushforward(dargs, literal_getfield, x, Val(property_name)) +end _pushforward(dargs, ::typeof(getproperty), x::NamedTuple, f) = _pushforward(dargs, literal_getfield, x, Val(f)) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index d61064a60..0d52c876d 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -225,20 +225,26 @@ end unwrap(val), back end -_pullback(cx::AContext, ::typeof(getfield), x, f::Symbol) = - _pullback(cx, literal_getfield, x, Val(f)) +_pullback(cx::AContext, ::typeof(getfield), x, field_name::Symbol) = + _pullback(cx, literal_getfield, x, Val(field_name)) -_pullback(cx::AContext, ::typeof(literal_getproperty), x::NamedTuple, ::Val{f}) where f = - _pullback(cx, literal_getfield, x, Val(f)) - -_pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) where f = - _pullback(cx, literal_getfield, x, Val(f)) - -_pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f = - _pullback(cx, literal_getindex, x, Val(f)) +function _pullback(cx::AContext, ::typeof(literal_getproperty), x::NamedTuple, + ::Val{property_name}) where {property_name} + return _pullback(cx, literal_getfield, x, Val(property_name)) +end +function _pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, + ::Val{key}) where {key} + return _pullback(cx, literal_getfield, x, Val(key)) +end -_pullback(cx::AContext, ::typeof(literal_getfield), x::Tuple, ::Val{f}) where f = - _pullback(cx, literal_getindex, x, Val(f)) +function _pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, + ::Val{index}) where {index} + return _pullback(cx, literal_getindex, x, Val(index)) +end +function _pullback(cx::AContext, ::typeof(literal_getfield), x::Tuple, + ::Val{index}) where {index} + return _pullback(cx, literal_getindex, x, Val(index)) +end grad_mut(x) = Ref{Any}(nt_nothing(x))