Skip to content

Commit

Permalink
Merge #909
Browse files Browse the repository at this point in the history
909: improve inference for getproperty r=willtebbutt a=simeonschaub

This has regressed quite a bit due to #848. With this PR, we should be able to get back the same performance as before, assuming there is no custom implementation or pullback for `getproperty`. Still need to add tests.

Co-authored-by: Simeon David Schaub <schaub@mit.edu>
  • Loading branch information
bors[bot] and simeonschaub authored Oct 14, 2021
2 parents e245dee + 1d189fb commit 5887e46
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.26"
version = "0.6.27"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
1 change: 1 addition & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ include("compiler/show.jl")

include("lib/grad.jl")
include("lib/lib.jl")
include("lib/literal_getproperty.jl")
include("lib/number.jl")
include("lib/base.jl")
include("lib/array.jl")
Expand Down
82 changes: 82 additions & 0 deletions src/lib/literal_getproperty.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Mostly copied over from Cassette in `src/overdub.jl`
# Return `Reflection` for signature `sigtypes` and `world`, if possible. Otherwise, return `nothing`.
function reflect(@nospecialize(sigtypes::Tuple), world::UInt = typemax(UInt))
if length(sigtypes) > 2 && sigtypes[1] === typeof(invoke)
@assert sigtypes[3] <: Type{<:Tuple}
sigtypes = (sigtypes[2], sigtypes[3].parameters[1].parameters...)
end
# This works around a subtyping bug. Basically, callers can deconstruct upstream
# `UnionAll` types in such a way that results in a type with free type variables, in
# which case subtyping can just break.
#
# God help you if you try to use a type parameter here (e.g. `::Type{S} where S<:Tuple`)
# instead of this nutty workaround, because the compiler can just rewrite `S` into
# whatever it thinks is "type equal" to the actual provided value. In other words, if
# `S` is defined as e.g. `f(::Type{S}) where S`, and you call `f(T)`, you should NOT
# assume that `S === T`. If you did, SHAME ON YOU. It doesn't matter that such an
# assumption holds true for essentially all other kinds of values. I haven't counted in
# a while, but I'm pretty sure I have ~40+ hellish years of Julia experience, and this
# still catches me every time. Who even uses this crazy language?
S = Tuple{map(s -> Core.Compiler.has_free_typevars(s) ? typeof(s.parameters[1]) : s, sigtypes)...}
(S.parameters[1]::DataType).name.module === Core.Compiler && return nothing
_methods = Base._methods_by_ftype(S, -1, world)
method_index = 0
for i in 1:length(_methods)
if _methods[i][1] === S
method_index = i
break
end
end
method_index === 0 && return nothing
type_signature, raw_static_params, method = _methods[method_index]
method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
method_signature = method.sig
static_params = Any[raw_static_params...]
return method_instance, method_signature, static_params
end


# ugly hack to make differentiating `getproperty` infer a lot better
@generated function _pullback(cx::AContext, ::typeof(literal_getproperty), x, ::Val{f}) where f
sig(x) = Tuple{x, typeof(f)}
rrule_sig(x) = Tuple{typeof(getproperty), x, typeof(f)}
pb_sig(x) = Tuple{cx, typeof(getproperty), x, typeof(f)}

# either `getproperty` has a custom implementation or `_pullback(cx, getproperty, x, f)`
# / `rrule(getproperty, x, f) is overloaded directly
is_getfield_fallback = which(getproperty, sig(x)) == which(getproperty, sig(Any)) &&
which(_pullback, pb_sig(x)) == which(_pullback, pb_sig(Any)) &&
which(rrule, rrule_sig(x)) == which(rrule, rrule_sig(Any))

#ccall(:jl_safe_printf, Cvoid, (Cstring,), "$is_getfield_fallback: $x\n")

if is_getfield_fallback
# just copy pullback of `literal_getfield`
mi, _sig, sparams = reflect((typeof(_pullback), cx, typeof(literal_getfield), x, Val{f}))
ci = copy(Core.Compiler.retrieve_code_info(mi))

# we need to change the second arg to `_pullback` from `literal_getproperty` to
# `literal_getfield`
Meta.partially_inline!(
ci.code, Any[_pullback, Core.SlotNumber(2), literal_getfield],
_sig, sparams, 0, 0, :propagate,
)
ci.inlineable = true

# backedge for `_pullback`, see https://docs.julialang.org/en/v1/devdocs/ast/#MethodInstance
# this will cause a backedge to this particular MethodInstance to be attached to
# `_pullback(cx, getproperty, x, f)`
mi_pb_getproperty, _, _ = reflect((typeof(_pullback), pb_sig(x).parameters...))
mi_getproperty, _, _ = reflect((typeof(getproperty), sig(x).parameters...))
mi_rrule, _, _ = reflect((typeof(rrule), rrule_sig(x).parameters...))
ci.edges = Core.MethodInstance[mi, mi_pb_getproperty, mi_getproperty, mi_rrule]

return ci
else
# nothing to optimize here, need to recurse into `getproperty`
return quote
Base.@_inline_meta
_pullback(cx, getproperty, x, $(QuoteNode(f)))
end
end
end
55 changes: 55 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,60 @@ end
@test Zygote.gradient(sumall, ms) == ((a = 2, b = 2),)
end

using ChainRulesCore

function _Gaussian(suffix::Symbol)
name = gensym(Symbol(:Gaussian_, suffix))
return @eval begin
struct $name{Tm, TP}
m::Tm
P::TP
end
$name
end
end

@testset "inference for `getproperty`" begin
Gaussian = _Gaussian(:getproperty)
g = Gaussian(randn(3), randn(3, 3))
y, back = @inferred pullback(x -> x.m, g)
@test y == getfield(g, :m)
# This type instability is due to the handling of non-bitstypes in `accum_param`
@test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
@test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)

Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
y, back = pullback(x -> x.m, g)
@test y == 2getfield(g, :m)
@test back([1., 0, 0]) == ((m = [2.0, 0.0, 0.0], P = nothing),)


Gaussian = _Gaussian(:pullback)
g = Gaussian(randn(3), randn(3, 3))
y, back = @inferred pullback(x -> x.m, g)

Zygote._pullback(::Zygote.AContext, ::typeof(getproperty), g::Gaussian, s::Symbol) = 3getfield(g, s), Δ -> (nothing, (; ((:m, :P) .=> nothing)..., s => 3Δ), nothing)
y, back = pullback(x -> x.m, g)
@test y == 3getfield(g, :m)
@test back([1., 0, 0]) == ((m = [3.0, 0.0, 0.0], P = nothing),)


Gaussian = _Gaussian(:rrule)
g = Gaussian(randn(3), randn(3, 3))
y, back = @inferred pullback(x -> x.m, g)

ChainRulesCore.rrule(::typeof(getproperty), g::Gaussian, s::Symbol) = 4getfield(g, s), Δ -> (NoTangent(), Tangent{typeof(g)}(; s => 4Δ), NoTangent())
y, back = pullback(x -> x.m, g)
@test y == 4getfield(g, :m)
@test back([1., 0, 0]) == ((m = [4.0, 0.0, 0.0], P = nothing),)


Gaussian = _Gaussian(:bitstype)
g = Gaussian(randn(), randn())
y, back = @inferred pullback(x -> x.m, g)
@test y == getfield(g, :m)
@test @inferred(back(1.0)) == ((m = 1.0, P = nothing),)
end

# issue 897
@test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] fill(0.5773502691896258, 3, 400)

2 comments on commit 5887e46

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/46760

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.27 -m "<description of version>" 5887e46bf6280e3608dbed2e27f2229fa1456087
git push origin v0.6.27

Please sign in to comment.