Skip to content

Commit

Permalink
style improvements in lib.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
simeonschaub committed Dec 18, 2020
1 parent cb8ac71 commit f4fa03a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
7 changes: 4 additions & 3 deletions src/forward/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ end

using ..Zygote: literal_getproperty, literal_getfield, literal_getindex

_pushforward(dargs, ::typeof(literal_getproperty), x::NamedTuple, ::Val{f}) where {f} =
_pushforward(dargs, literal_getfield, x, Val(f))

function _pushforward(dargs, ::typeof(literal_getproperty), x::NamedTuple,
::Val{property_name}) where {property_name}
return _pushforward(dargs, literal_getfield, x, Val(property_name))
end
_pushforward(dargs, ::typeof(getproperty), x::NamedTuple, f) =
_pushforward(dargs, literal_getfield, x, Val(f))

Expand Down
30 changes: 18 additions & 12 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,20 +225,26 @@ end
unwrap(val), back
end

_pullback(cx::AContext, ::typeof(getfield), x, f::Symbol) =
_pullback(cx, literal_getfield, x, Val(f))
_pullback(cx::AContext, ::typeof(getfield), x, field_name::Symbol) =
_pullback(cx, literal_getfield, x, Val(field_name))

_pullback(cx::AContext, ::typeof(literal_getproperty), x::NamedTuple, ::Val{f}) where f =
_pullback(cx, literal_getfield, x, Val(f))

_pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) where f =
_pullback(cx, literal_getfield, x, Val(f))

_pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f =
_pullback(cx, literal_getindex, x, Val(f))
function _pullback(cx::AContext, ::typeof(literal_getproperty), x::NamedTuple,
::Val{property_name}) where {property_name}
return _pullback(cx, literal_getfield, x, Val(property_name))
end
function _pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple,
::Val{key}) where {key}
return _pullback(cx, literal_getfield, x, Val(key))
end

_pullback(cx::AContext, ::typeof(literal_getfield), x::Tuple, ::Val{f}) where f =
_pullback(cx, literal_getindex, x, Val(f))
function _pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple,
::Val{index}) where {index}
return _pullback(cx, literal_getindex, x, Val(index))
end
function _pullback(cx::AContext, ::typeof(literal_getfield), x::Tuple,
::Val{index}) where {index}
return _pullback(cx, literal_getindex, x, Val(index))
end

grad_mut(x) = Ref{Any}(nt_nothing(x))

Expand Down

0 comments on commit f4fa03a

Please sign in to comment.