Skip to content

Commit

Permalink
Merge pull request #848 from simeonschaub/sds/fix_on_master
Browse files Browse the repository at this point in the history
fix Zygote on 1.6, fix #851
  • Loading branch information
CarloLucibello authored Dec 21, 2020
2 parents 6d882d2 + f4fa03a commit b3edf99
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 41 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
# `allow-failure` not available yet https://github.com/actions/toolkit/issues/399
continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
- uses: julia-actions/julia-runtest@v1
continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
- uses: julia-actions/julia-processcoverage@v1
continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
- uses: codecov/codecov-action@v1
continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
with:
file: lcov.info
docs:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ MacroTools = "0.5"
NaNMath = "0.3"
Requires = "0.5, 1.0"
SpecialFunctions = "0.10, 1.0"
ZygoteRules = "0.2"
ZygoteRules = "0.2.1"
julia = "1.3"

[extras]
Expand Down
3 changes: 2 additions & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module Zygote
using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield

using ChainRules: ChainRules, rrule, unthunk
using IRTools
Expand Down
52 changes: 42 additions & 10 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,69 @@ end
unwrapquote(x) = x
unwrapquote(x::QuoteNode) = x.value

is_literal_getproperty(ex) =
(iscall(ex, Base, :getproperty) || iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) &&
is_getproperty(ex) = iscall(ex, Base, :getproperty)

# The initial premise of literal_getproperty was in some ways inherently flawed, because for
# getproperty it was intended that _pullback falls back to literal_getproperty, but we actually
# want the opposite to happen, since Zygote should fall back to recursing into the getproperty
# implementation by default. Users still want to define custom adjoints using only
# literal_getproperty, though. We can't really have mutually recursive definitions here, so we
# now always instrument getproperty as literal_getproperty, no matter whether the second
# argument is a literal or not.
function instrument_getproperty!(ir, v, ex)
if is_getproperty(ex)
if ex.args[3] isa Union{QuoteNode,Integer}
ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])))
else
f = insert!(ir, v, :(Val($(ex.args[3]))))
ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], f)
end
else
ex
end
end

is_literal_getfield(ex) =
(iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) &&
ex.args[3] isa Union{QuoteNode,Integer}

function instrument_getproperty!(ir, v, ex)
is_literal_getproperty(ex) ?
(ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])))) :
# Here, only instrumenting getfield with literals is fine, since users should never have to
# define custom adjoints for literal_getfield
function instrument_getfield!(ir, v, ex)
if is_literal_getfield(ex)
ir[v] = xcall(Zygote, :literal_getfield, ex.args[2], Val(unwrapquote(ex.args[3])))
else
ex
end
end

is_literal_getindex(ex) =
iscall(ex, Base, :getindex) && length(ex.args) == 3 && ex.args[3] isa Union{Integer,QuoteNode}

# TODO: is this always correct for user defined getindex methods?
function instrument_getindex!(ir, v, ex)
is_literal_getindex(ex) ?
(ir[v] = xcall(Zygote, :literal_getindex, ex.args[2], Val(unwrapquote(ex.args[3])))) :
if is_literal_getindex(ex)
ir[v] = xcall(Zygote, :literal_getindex, ex.args[2], Val(unwrapquote(ex.args[3])))
else
ex
end
end

is_literal_iterate(ex) =
iscall(ex, Base, :indexed_iterate) && length(ex.args) >= 3 && ex.args[3] isa Union{Integer,QuoteNode}

function instrument_iterate!(ir, v, ex)
is_literal_iterate(ex) ?
(ir[v] = xcall(Zygote, :literal_indexed_iterate, ex.args[2],
Val(unwrapquote(ex.args[3])), ex.args[4:end]...)) :
if is_literal_iterate(ex)
ir[v] = xcall(Zygote, :literal_indexed_iterate, ex.args[2],
Val(unwrapquote(ex.args[3])), ex.args[4:end]...)
else
ex
end
end

function instrument_literals!(ir, v, ex)
ex = instrument_getproperty!(ir, v, ex)
ex = instrument_getfield!(ir, v, ex)
ex = instrument_getindex!(ir, v, ex)
ex = instrument_iterate!(ir, v, ex)
end
Expand Down
19 changes: 13 additions & 6 deletions src/forward/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,21 @@ if VERSION >= v"1.4.0-DEV.304"
_pushforward((first(args), tail(tail(dargs))...), Core._apply, f, args...)
end

using ..Zygote: literal_getproperty, literal_getindex
using ..Zygote: literal_getproperty, literal_getfield, literal_getindex

_pushforward(dargs, ::typeof(getproperty), x, f) =
_pushforward(dargs, literal_getproperty, x, Val(f))
function _pushforward(dargs, ::typeof(literal_getproperty), x::NamedTuple,
::Val{property_name}) where {property_name}
return _pushforward(dargs, literal_getfield, x, Val(property_name))
end
_pushforward(dargs, ::typeof(getproperty), x::NamedTuple, f) =
_pushforward(dargs, literal_getfield, x, Val(f))

_pushforward(dargs, ::typeof(getfield), x, f) =
_pushforward(dargs, literal_getfield, x, Val(f))

@tangent function literal_getproperty(t, ::Val{i}) where i
y = getproperty(t, i)
forw(ṫ, _) = getproperty(ṫ, i)
@tangent function literal_getfield(t, ::Val{i}) where i
y = getfield(t, i)
forw(ṫ, _) = getfield(ṫ, i)
forw(ṫ::Nothing, _) = zerolike(y)
return y, forw
end
Expand Down
42 changes: 26 additions & 16 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, :
end

# Needed for iteration lowering
@adjoint Core.getfield(xs::NTuple{N,Any}, i::Integer) where N =
@adjoint Core.getfield(xs::NTuple{N,Any}, i::Int) where N =
(xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing))

@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Integer) where {K,N} =
@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} =
(xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing))

@adjoint function Base.first(xs::Tuple)
Expand Down Expand Up @@ -207,34 +207,44 @@ end

@generated nt_nothing(x) = Expr(:tuple, [:($f=nothing) for f in fieldnames(x)]...)

@generated pair(::Val{k}, v) where k = :($k = v,)
@generated pair(::Val{k}, v, _=nothing) where k = :($k = v,)
@generated pair(::Val{k}, v, ::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,)

@adjoint function literal_getproperty(x, ::Val{f}) where f
val = getproperty(x, f)
@adjoint function literal_getfield(x, ::Val{f}) where f
val = getfield(x, f)
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
if isimmutable(x)
((;nt_nothing(x)...,pair(Val(f), Δ)...), nothing)
((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing)
else
dx = grad_mut(__context__, x)
dx[] = (;dx[]...,pair(Val(f),accum(getfield(dx[], f), Δ))...)
dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...)
return (dx,nothing)
end
end
unwrap(val), back
end

_pullback(cx::AContext, ::typeof(getproperty), x, f::Symbol) =
_pullback(cx, literal_getproperty, x, Val(f))
_pullback(cx::AContext, ::typeof(getfield), x, field_name::Symbol) =
_pullback(cx, literal_getfield, x, Val(field_name))

_pullback(cx::AContext, ::typeof(getfield), x, f::Symbol) =
_pullback(cx, literal_getproperty, x, Val(f))

_pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) where f =
_pullback(cx, literal_getproperty, x, Val(f))
function _pullback(cx::AContext, ::typeof(literal_getproperty), x::NamedTuple,
::Val{property_name}) where {property_name}
return _pullback(cx, literal_getfield, x, Val(property_name))
end
function _pullback(cx::AContext, ::typeof(literal_getindex), x::NamedTuple,
::Val{key}) where {key}
return _pullback(cx, literal_getfield, x, Val(key))
end

_pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f =
_pullback(cx, literal_getindex, x, Val(f))
function _pullback(cx::AContext, ::typeof(literal_getproperty), x::Tuple,
::Val{index}) where {index}
return _pullback(cx, literal_getindex, x, Val(index))
end
function _pullback(cx::AContext, ::typeof(literal_getfield), x::Tuple,
::Val{index}) where {index}
return _pullback(cx, literal_getindex, x, Val(index))
end

grad_mut(x) = Ref{Any}(nt_nothing(x))

Expand Down
11 changes: 9 additions & 2 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,15 @@ using Zygote, Test, ChainRules

@test (1,) == h(1)

a3, pb3 = Zygote.pullback(h, 1)
@test ((1,),) == pb3(1)
if VERSION >= v"1.6-"
@test_broken begin
a3, pb3 = Zygote.pullback(h, 1)
((1,),) == pb3(1)
end
else
a3, pb3 = Zygote.pullback(h, 1)
@test ((1,),) == pb3(1)
end
end

@testset "kwargs" begin
Expand Down
30 changes: 29 additions & 1 deletion test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ y, back = pullback(badly, 2)
bt = try back(1) catch e stacktrace(catch_backtrace()) end

@test trace_contains(bt, nothing, "compiler.jl", 20)
@test trace_contains(bt, :badly, "compiler.jl", 24)
if VERSION >= v"1.6-"
@test_broken trace_contains(bt, :badly, "compiler.jl", 24)
else
@test trace_contains(bt, :badly, "compiler.jl", 24)
end

# Type inference checks

Expand Down Expand Up @@ -81,3 +85,27 @@ buf = IOBuffer()
Base.show(buf, methods(Base.show))
str_repr = String(take!(buf))
@test !isempty(str_repr)

struct Funky
x
y
end

@testset "issue #851" begin
f = Funky(1, 1);
function Base.getproperty(f::Funky, i::Symbol)
return 2
end
@test getproperty(f, :x) == 2
@test getfield(f, :x) == 1

y, pb = Zygote._pullback(getproperty, f, :x)
@test y == 2
@test pb(1) == (nothing, nothing, nothing)
y, pb = Zygote._pullback((f, x) -> getproperty(f, x), f, :x)
@test y == 2
@test pb(1) == (nothing, nothing, nothing)
y, pb = Zygote._pullback(getfield, f, :x)
@test y == 1
@test pb(1) == (nothing, (x = 1, y = nothing), nothing)
end

0 comments on commit b3edf99

Please sign in to comment.