From 786d129d90ddf8ac2bf4899d4fce2d03cd21cd4d Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 21 Nov 2020 12:22:43 +0100 Subject: [PATCH 1/7] WIP: fix Zygote on 1.6 Relies on ZygoteRules.jl#13. The test I marked as broken doesn't fail if run in an interpreter, so this might be a Julia bug. Other than that, I still see failures related to LoopVectorization, might be due to llvmcall. @chriselrod, do you have any ideas? --- Project.toml | 2 +- src/compiler/reverse.jl | 4 +++- src/forward/lib.jl | 9 ++++++--- src/lib/lib.jl | 15 ++++++++------- test/chainrules.jl | 11 +++++++++-- test/compiler.jl | 6 +++++- 6 files changed, 32 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index bf34d43ea..9aa7edb2d 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ MacroTools = "0.5" NaNMath = "0.3" Requires = "0.5, 1.0" SpecialFunctions = "0.10, 1.0" -ZygoteRules = "0.2" +ZygoteRules = "0.2.1" julia = "1.3" [extras] diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 3c0c61c5e..5d21eec35 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -32,7 +32,9 @@ is_literal_getproperty(ex) = function instrument_getproperty!(ir, v, ex) is_literal_getproperty(ex) ? - (ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])))) : + iscall(ex, Base, :getproperty) ? + (ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])))) : + (ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])), ex.args[1])) : ex end diff --git a/src/forward/lib.jl b/src/forward/lib.jl index ecfc4c511..1bb715f57 100644 --- a/src/forward/lib.jl +++ b/src/forward/lib.jl @@ -67,10 +67,13 @@ using ..Zygote: literal_getproperty, literal_getindex _pushforward(dargs, ::typeof(getproperty), x, f) = _pushforward(dargs, literal_getproperty, x, Val(f)) -@tangent function literal_getproperty(t, ::Val{i}) where i +_pushforward(dargs, ::typeof(getfield), x, f) = + _pushforward(dargs, literal_getproperty, x, Val(f), getfield) + +@tangent function literal_getproperty(t, ::Val{i}, getproperty=getproperty) where i y = getproperty(t, i) - forw(ṫ, _) = getproperty(ṫ, i) - forw(ṫ::Nothing, _) = zerolike(y) + forw(ṫ, _1, _2=nothing) = getproperty(ṫ, i) + forw(ṫ::Nothing, _1, _2=nothing) = zerolike(y) return y, forw end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 6ee471124..827291bcc 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -140,10 +140,10 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, : end # Needed for iteration lowering -@adjoint Core.getfield(xs::NTuple{N,Any}, i::Integer) where N = +@adjoint Core.getfield(xs::NTuple{N,Any}, i::Int) where N = (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) -@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Integer) where {K,N} = +@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} = (xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing)) @adjoint function Base.first(xs::Tuple) @@ -207,14 +207,15 @@ end @generated nt_nothing(x) = Expr(:tuple, [:($f=nothing) for f in fieldnames(x)]...) -@generated pair(::Val{k}, v) where k = :($k = v,) +@generated pair(::Val{k}, v, x=nothing) where k = :($k = v,) +@generated pair(::Val{k}, v, x::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,) -@adjoint function literal_getproperty(x, ::Val{f}) where f +@adjoint function literal_getproperty(x, ::Val{f}, getproperty=getproperty) where f val = getproperty(x, f) function back(Δ) accum_param(__context__, val, Δ) === nothing && return if isimmutable(x) - ((;nt_nothing(x)...,pair(Val(f), Δ)...), nothing) + ((;nt_nothing(x)...,pair(Val(f), Δ, x)...), nothing) else dx = grad_mut(__context__, x) dx[] = (;dx[]...,pair(Val(f),accum(getfield(dx[], f), Δ))...) @@ -228,12 +229,12 @@ _pullback(cx::AContext, ::typeof(getproperty), x, f::Symbol) = _pullback(cx, literal_getproperty, x, Val(f)) _pullback(cx::AContext, ::typeof(getfield), x, f::Symbol) = - _pullback(cx, literal_getproperty, x, Val(f)) + _pullback(cx, literal_getproperty, x, Val(f), getfield) _pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) where f = _pullback(cx, literal_getproperty, x, Val(f)) -_pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f = +_pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, ::Val{f}, _=nothing) where f = _pullback(cx, literal_getindex, x, Val(f)) grad_mut(x) = Ref{Any}(nt_nothing(x)) diff --git a/test/chainrules.jl b/test/chainrules.jl index c70cae556..7008a45e7 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -160,8 +160,15 @@ using Zygote, Test, ChainRules @test (1,) == h(1) - a3, pb3 = Zygote.pullback(h, 1) - @test ((1,),) == pb3(1) + if VERSION >= v"1.6-" + @test_broken begin + a3, pb3 = Zygote.pullback(h, 1) + ((1,),) == pb3(1) + end + else + a3, pb3 = Zygote.pullback(h, 1) + @test ((1,),) == pb3(1) + end end @testset "kwargs" begin diff --git a/test/compiler.jl b/test/compiler.jl index 9bd4d100a..f4edfc1e4 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -31,7 +31,11 @@ y, back = pullback(badly, 2) bt = try back(1) catch e stacktrace(catch_backtrace()) end @test trace_contains(bt, nothing, "compiler.jl", 20) -@test trace_contains(bt, :badly, "compiler.jl", 24) +if VERSION >= v"1.6-" + @test_broken trace_contains(bt, :badly, "compiler.jl", 24) +else + @test trace_contains(bt, :badly, "compiler.jl", 24) +end # Type inference checks From 76b9da22fffe44ee7d0dd8a0cf39dd7c0751fa65 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 12 Dec 2020 12:59:39 +0100 Subject: [PATCH 2/7] disable continue-on-error for nightly --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 82a45f3a6..15a74e774 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,13 +49,13 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 # `allow-failure` not available yet https://github.com/actions/toolkit/issues/399 - continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures + #continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures - uses: julia-actions/julia-runtest@v1 - continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures + #continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures - uses: julia-actions/julia-processcoverage@v1 - continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures + #continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures - uses: codecov/codecov-action@v1 - continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures + #continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures with: file: lcov.info docs: From 77bec8446f4d9db908fc65b0f82fca498aab854b Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 12 Dec 2020 16:18:40 +0100 Subject: [PATCH 3/7] use literal_getfield --- src/Zygote.jl | 3 ++- src/compiler/reverse.jl | 19 +++++++++++++------ src/forward/lib.jl | 17 ++++++++++++----- src/lib/lib.jl | 33 +++++++++++++++++++-------------- 4 files changed, 46 insertions(+), 26 deletions(-) diff --git a/src/Zygote.jl b/src/Zygote.jl index a6ce672ec..b536236c8 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -3,7 +3,8 @@ module Zygote using LinearAlgebra, Statistics using LinearAlgebra: copytri!, AbstractTriangular -import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty +import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, + literal_getproperty, literal_getfield using ChainRules: ChainRules, rrule, unthunk using IRTools diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 5d21eec35..33e27727d 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -26,15 +26,21 @@ end unwrapquote(x) = x unwrapquote(x::QuoteNode) = x.value -is_literal_getproperty(ex) = - (iscall(ex, Base, :getproperty) || iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) && - ex.args[3] isa Union{QuoteNode,Integer} +is_literal_getproperty(ex) = iscall(ex, Base, :getproperty) && ex.args[3] isa Union{QuoteNode,Integer} function instrument_getproperty!(ir, v, ex) is_literal_getproperty(ex) ? - iscall(ex, Base, :getproperty) ? - (ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])))) : - (ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])), ex.args[1])) : + (ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])))) : + ex +end + +is_literal_getfield(ex) = + (iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) && + ex.args[3] isa Union{QuoteNode,Integer} + +function instrument_getfield!(ir, v, ex) + is_literal_getfield(ex) ? + (ir[v] = xcall(Zygote, :literal_getfield, ex.args[2], Val(unwrapquote(ex.args[3])))) : ex end @@ -59,6 +65,7 @@ end function instrument_literals!(ir, v, ex) ex = instrument_getproperty!(ir, v, ex) + ex = instrument_getfield!(ir, v, ex) ex = instrument_getindex!(ir, v, ex) ex = instrument_iterate!(ir, v, ex) end diff --git a/src/forward/lib.jl b/src/forward/lib.jl index 1bb715f57..f4b521c85 100644 --- a/src/forward/lib.jl +++ b/src/forward/lib.jl @@ -62,18 +62,25 @@ if VERSION >= v"1.4.0-DEV.304" _pushforward((first(args), tail(tail(dargs))...), Core._apply, f, args...) end -using ..Zygote: literal_getproperty, literal_getindex +using ..Zygote: literal_getproperty, literal_getfield, literal_getindex _pushforward(dargs, ::typeof(getproperty), x, f) = _pushforward(dargs, literal_getproperty, x, Val(f)) _pushforward(dargs, ::typeof(getfield), x, f) = - _pushforward(dargs, literal_getproperty, x, Val(f), getfield) + _pushforward(dargs, literal_getfield, x, Val(f)) -@tangent function literal_getproperty(t, ::Val{i}, getproperty=getproperty) where i +@tangent function literal_getproperty(t, ::Val{i}) where i y = getproperty(t, i) - forw(ṫ, _1, _2=nothing) = getproperty(ṫ, i) - forw(ṫ::Nothing, _1, _2=nothing) = zerolike(y) + forw(ṫ, _) = getproperty(ṫ, i) + forw(ṫ::Nothing, _) = zerolike(y) + return y, forw +end + +@tangent function literal_getfield(t, ::Val{i}) where i + y = getfield(t, i) + forw(ṫ, _) = getfield(ṫ, i) + forw(ṫ::Nothing, _) = zerolike(y) return y, forw end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 827291bcc..0a2e758ea 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -210,31 +210,36 @@ end @generated pair(::Val{k}, v, x=nothing) where k = :($k = v,) @generated pair(::Val{k}, v, x::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,) -@adjoint function literal_getproperty(x, ::Val{f}, getproperty=getproperty) where f - val = getproperty(x, f) - function back(Δ) - accum_param(__context__, val, Δ) === nothing && return - if isimmutable(x) - ((;nt_nothing(x)...,pair(Val(f), Δ, x)...), nothing) - else - dx = grad_mut(__context__, x) - dx[] = (;dx[]...,pair(Val(f),accum(getfield(dx[], f), Δ))...) - return (dx,nothing) +for getproperty in [:getproperty, :getfield] + @eval @adjoint function $(Symbol(:literal_, getproperty))(x, ::Val{f}) where f + val = $getproperty(x, f) + function back(Δ) + accum_param(__context__, val, Δ) === nothing && return + if isimmutable(x) + ((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing) + else + dx = grad_mut(__context__, x) + dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...) + return (dx,nothing) + end end + unwrap(val), back end - unwrap(val), back end _pullback(cx::AContext, ::typeof(getproperty), x, f::Symbol) = _pullback(cx, literal_getproperty, x, Val(f)) _pullback(cx::AContext, ::typeof(getfield), x, f::Symbol) = - _pullback(cx, literal_getproperty, x, Val(f), getfield) + _pullback(cx, literal_getfield, x, Val(f)) _pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) where f = - _pullback(cx, literal_getproperty, x, Val(f)) + _pullback(cx, literal_getfield, x, Val(f)) + +_pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f = + _pullback(cx, literal_getindex, x, Val(f)) -_pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, ::Val{f}, _=nothing) where f = +_pullback(cx::AContext, ::typeof(literal_getfield), x::Tuple, ::Val{f}) where f = _pullback(cx, literal_getindex, x, Val(f)) grad_mut(x) = Ref{Any}(nt_nothing(x)) From 0c78d95759da189a26dc0ff78624006591190694 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Mon, 14 Dec 2020 13:04:29 +0100 Subject: [PATCH 4/7] fix adjoint for literal_getproperty --- src/compiler/reverse.jl | 13 ++++++++++--- src/forward/lib.jl | 14 +++++--------- src/lib/lib.jl | 36 +++++++++++++++++------------------- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 33e27727d..6cb69e258 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -26,12 +26,19 @@ end unwrapquote(x) = x unwrapquote(x::QuoteNode) = x.value -is_literal_getproperty(ex) = iscall(ex, Base, :getproperty) && ex.args[3] isa Union{QuoteNode,Integer} +is_getproperty(ex) = iscall(ex, Base, :getproperty) function instrument_getproperty!(ir, v, ex) - is_literal_getproperty(ex) ? - (ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])))) : + if is_getproperty(ex) + if ex.args[3] isa Union{QuoteNode,Integer} + ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3]))) + else + f = insert!(ir, v, :(Val($(ex.args[3])))) + ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], f) + end + else ex + end end is_literal_getfield(ex) = diff --git a/src/forward/lib.jl b/src/forward/lib.jl index f4b521c85..1ab7682f1 100644 --- a/src/forward/lib.jl +++ b/src/forward/lib.jl @@ -64,18 +64,14 @@ end using ..Zygote: literal_getproperty, literal_getfield, literal_getindex -_pushforward(dargs, ::typeof(getproperty), x, f) = - _pushforward(dargs, literal_getproperty, x, Val(f)) +_pushforward(dargs, ::typeof(literal_getproperty), x::NamedTuple, ::Val{f}) where {f} = + _pushforward(dargs, literal_getfield, x, Val(f)) -_pushforward(dargs, ::typeof(getfield), x, f) = +_pushforward(dargs, ::typeof(getproperty), x::NamedTuple, f) = _pushforward(dargs, literal_getfield, x, Val(f)) -@tangent function literal_getproperty(t, ::Val{i}) where i - y = getproperty(t, i) - forw(ṫ, _) = getproperty(ṫ, i) - forw(ṫ::Nothing, _) = zerolike(y) - return y, forw -end +_pushforward(dargs, ::typeof(getfield), x, f) = + _pushforward(dargs, literal_getfield, x, Val(f)) @tangent function literal_getfield(t, ::Val{i}) where i y = getfield(t, i) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 0a2e758ea..d61064a60 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -207,32 +207,30 @@ end @generated nt_nothing(x) = Expr(:tuple, [:($f=nothing) for f in fieldnames(x)]...) -@generated pair(::Val{k}, v, x=nothing) where k = :($k = v,) -@generated pair(::Val{k}, v, x::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,) - -for getproperty in [:getproperty, :getfield] - @eval @adjoint function $(Symbol(:literal_, getproperty))(x, ::Val{f}) where f - val = $getproperty(x, f) - function back(Δ) - accum_param(__context__, val, Δ) === nothing && return - if isimmutable(x) - ((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing) - else - dx = grad_mut(__context__, x) - dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...) - return (dx,nothing) - end +@generated pair(::Val{k}, v, _=nothing) where k = :($k = v,) +@generated pair(::Val{k}, v, ::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,) + +@adjoint function literal_getfield(x, ::Val{f}) where f + val = getfield(x, f) + function back(Δ) + accum_param(__context__, val, Δ) === nothing && return + if isimmutable(x) + ((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing) + else + dx = grad_mut(__context__, x) + dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...) + return (dx,nothing) end - unwrap(val), back end + unwrap(val), back end -_pullback(cx::AContext, ::typeof(getproperty), x, f::Symbol) = - _pullback(cx, literal_getproperty, x, Val(f)) - _pullback(cx::AContext, ::typeof(getfield), x, f::Symbol) = _pullback(cx, literal_getfield, x, Val(f)) +_pullback(cx::AContext, ::typeof(literal_getproperty), x::NamedTuple, ::Val{f}) where f = + _pullback(cx, literal_getfield, x, Val(f)) + _pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) where f = _pullback(cx, literal_getfield, x, Val(f)) From ca8cd24143692abc9d0756fe1557116e8256f09e Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Mon, 14 Dec 2020 13:04:51 +0100 Subject: [PATCH 5/7] add test case from #851 --- test/compiler.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/compiler.jl b/test/compiler.jl index f4edfc1e4..19c010fa4 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -85,3 +85,27 @@ buf = IOBuffer() Base.show(buf, methods(Base.show)) str_repr = String(take!(buf)) @test !isempty(str_repr) + +struct Funky + x + y +end + +@testset "issue #851" begin + f = Funky(1, 1); + function Base.getproperty(f::Funky, i::Symbol) + return 2 + end + @test getproperty(f, :x) == 2 + @test getfield(f, :x) == 1 + + y, pb = Zygote._pullback(getproperty, f, :x) + @test y == 2 + @test pb(1) == (nothing, nothing, nothing) + y, pb = Zygote._pullback((f, x) -> getproperty(f, x), f, :x) + @test y == 2 + @test pb(1) == (nothing, nothing, nothing) + y, pb = Zygote._pullback(getfield, f, :x) + @test y == 1 + @test pb(1) == (nothing, (x = 1, y = nothing), nothing) +end From cb8ac7155db9e2d9399e5eb3b8a0b453965386f6 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 18 Dec 2020 13:25:53 +0100 Subject: [PATCH 6/7] address review comments --- src/compiler/reverse.jl | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 6cb69e258..e144be68d 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -28,6 +28,13 @@ unwrapquote(x::QuoteNode) = x.value is_getproperty(ex) = iscall(ex, Base, :getproperty) +# The initial premise of literal_getproperty was in some ways inherently flawed, because for +# getproperty it was intended that _pullback falls back to literal_getproperty, but we actually +# want the opposite to happen, since Zygote should fall back to recursing into the getproperty +# implementation by default. Users still want to define custom adjoints using only +# literal_getproperty, though. We can't really have mutually recursive definitions here, so we +# now always instrument getproperty as literal_getproperty, no matter whether the second +# argument is a literal or not. function instrument_getproperty!(ir, v, ex) if is_getproperty(ex) if ex.args[3] isa Union{QuoteNode,Integer} @@ -45,29 +52,38 @@ is_literal_getfield(ex) = (iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) && ex.args[3] isa Union{QuoteNode,Integer} +# Here, only instrumenting getfield with literals is fine, since users should never have to +# define custom adjoints for literal_getfield function instrument_getfield!(ir, v, ex) - is_literal_getfield(ex) ? - (ir[v] = xcall(Zygote, :literal_getfield, ex.args[2], Val(unwrapquote(ex.args[3])))) : + if is_literal_getfield(ex) + ir[v] = xcall(Zygote, :literal_getfield, ex.args[2], Val(unwrapquote(ex.args[3]))) + else ex + end end is_literal_getindex(ex) = iscall(ex, Base, :getindex) && length(ex.args) == 3 && ex.args[3] isa Union{Integer,QuoteNode} +# TODO: is this always correct for user defined getindex methods? function instrument_getindex!(ir, v, ex) - is_literal_getindex(ex) ? - (ir[v] = xcall(Zygote, :literal_getindex, ex.args[2], Val(unwrapquote(ex.args[3])))) : + if is_literal_getindex(ex) + ir[v] = xcall(Zygote, :literal_getindex, ex.args[2], Val(unwrapquote(ex.args[3]))) + else ex + end end is_literal_iterate(ex) = iscall(ex, Base, :indexed_iterate) && length(ex.args) >= 3 && ex.args[3] isa Union{Integer,QuoteNode} function instrument_iterate!(ir, v, ex) - is_literal_iterate(ex) ? - (ir[v] = xcall(Zygote, :literal_indexed_iterate, ex.args[2], - Val(unwrapquote(ex.args[3])), ex.args[4:end]...)) : + if is_literal_iterate(ex) + ir[v] = xcall(Zygote, :literal_indexed_iterate, ex.args[2], + Val(unwrapquote(ex.args[3])), ex.args[4:end]...) + else ex + end end function instrument_literals!(ir, v, ex) From f4fa03ae7830ff9b8c00d3cdebec6a0da9b6b5ee Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 18 Dec 2020 13:53:22 +0100 Subject: [PATCH 7/7] style improvements in lib.jl --- src/forward/lib.jl | 7 ++++--- src/lib/lib.jl | 30 ++++++++++++++++++------------ 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/forward/lib.jl b/src/forward/lib.jl index 1ab7682f1..b297dab41 100644 --- a/src/forward/lib.jl +++ b/src/forward/lib.jl @@ -64,9 +64,10 @@ end using ..Zygote: literal_getproperty, literal_getfield, literal_getindex -_pushforward(dargs, ::typeof(literal_getproperty), x::NamedTuple, ::Val{f}) where {f} = - _pushforward(dargs, literal_getfield, x, Val(f)) - +function _pushforward(dargs, ::typeof(literal_getproperty), x::NamedTuple, + ::Val{property_name}) where {property_name} + return _pushforward(dargs, literal_getfield, x, Val(property_name)) +end _pushforward(dargs, ::typeof(getproperty), x::NamedTuple, f) = _pushforward(dargs, literal_getfield, x, Val(f)) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index d61064a60..0d52c876d 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -225,20 +225,26 @@ end unwrap(val), back end -_pullback(cx::AContext, ::typeof(getfield), x, f::Symbol) = - _pullback(cx, literal_getfield, x, Val(f)) +_pullback(cx::AContext, ::typeof(getfield), x, field_name::Symbol) = + _pullback(cx, literal_getfield, x, Val(field_name)) -_pullback(cx::AContext, ::typeof(literal_getproperty), x::NamedTuple, ::Val{f}) where f = - _pullback(cx, literal_getfield, x, Val(f)) - -_pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) where f = - _pullback(cx, literal_getfield, x, Val(f)) - -_pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f = - _pullback(cx, literal_getindex, x, Val(f)) +function _pullback(cx::AContext, ::typeof(literal_getproperty), x::NamedTuple, + ::Val{property_name}) where {property_name} + return _pullback(cx, literal_getfield, x, Val(property_name)) +end +function _pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, + ::Val{key}) where {key} + return _pullback(cx, literal_getfield, x, Val(key)) +end -_pullback(cx::AContext, ::typeof(literal_getfield), x::Tuple, ::Val{f}) where f = - _pullback(cx, literal_getindex, x, Val(f)) +function _pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, + ::Val{index}) where {index} + return _pullback(cx, literal_getindex, x, Val(index)) +end +function _pullback(cx::AContext, ::typeof(literal_getfield), x::Tuple, + ::Val{index}) where {index} + return _pullback(cx, literal_getindex, x, Val(index)) +end grad_mut(x) = Ref{Any}(nt_nothing(x))