Skip to content

Commit

Permalink
WIP: improve inference for getproperty
Browse files Browse the repository at this point in the history
This has regressed quite a bit due to #848. With this PR, we should be able to get back the same performance as before, assuming there is no custom implementation or pullback for `getproperty`. Still need to add tests.
  • Loading branch information
simeonschaub committed Oct 5, 2021
1 parent e9fa213 commit a3f8dc4
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ include("compiler/show.jl")

include("lib/grad.jl")
include("lib/lib.jl")
include("lib/literal_getproperty.jl")
include("lib/number.jl")
include("lib/base.jl")
include("lib/array.jl")
Expand Down
82 changes: 82 additions & 0 deletions src/lib/literal_getproperty.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Mostly copied over from Cassette in `src/overdub.jl`
#
# Look up the method whose specialized signature is exactly `Tuple{sigtypes...}` in world
# age `world`.
#
# Returns `(method_instance, method_signature, static_params)` on success, where
# `static_params` is a `Vector{Any}` copy of the method's static parameters. Returns
# `nothing` if reflection is not possible: the callee lives in `Core.Compiler`, the method
# table lookup yields no usable list, or no match has exactly the requested signature.
#
# A leading `typeof(invoke)` entry is unwrapped, so `(typeof(invoke), F, Type{Tuple{A...}})`
# is treated like `(F, A...)`.
function reflect(@nospecialize(sigtypes::Tuple), world::UInt = typemax(UInt))
    if length(sigtypes) > 2 && sigtypes[1] === typeof(invoke)
        @assert sigtypes[3] <: Type{<:Tuple}
        sigtypes = (sigtypes[2], sigtypes[3].parameters[1].parameters...)
    end
    # This works around a subtyping bug. Callers can deconstruct upstream `UnionAll` types
    # in such a way that results in a type with free type variables, in which case
    # subtyping can just break.
    #
    # Do not try to replace this with a type parameter (e.g. `::Type{S} where S<:Tuple`):
    # the compiler may rewrite `S` into whatever it considers "type equal" to the provided
    # value, so for `f(::Type{S}) where S` called as `f(T)` you must NOT assume `S === T`.
    S = Tuple{
        map(s -> Core.Compiler.has_free_typevars(s) ? typeof(s.parameters[1]) : s, sigtypes)...
    }
    (S.parameters[1]::DataType).name.module === Core.Compiler && return nothing
    _methods = Base._methods_by_ftype(S, -1, world)
    # The lookup does not always return a list of matches (it can signal failure with a
    # non-collection value); treat that as "reflection not possible" instead of throwing.
    _methods isa Vector || return nothing
    # Only an exact signature match is acceptable, not merely an applicable method.
    method_index = findfirst(m -> m[1] === S, _methods)
    method_index === nothing && return nothing
    type_signature, raw_static_params, method = _methods[method_index]
    method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
    method_signature = method.sig
    static_params = Any[raw_static_params...]
    return method_instance, method_signature, static_params
end


# Ugly hack to make differentiating `getproperty` infer a lot better.
#
# Generator for the pullback of `literal_getproperty(x, Val(f))`. Inside the generator `cx`
# and `x` are the argument *types* and `f` is the literal property name.
#
# If property access on `x` is known to be plain field access (no custom `getproperty`
# method and no custom `_pullback` / `rrule` for `getproperty`), the generator returns a
# copy of the `CodeInfo` of the `literal_getfield` pullback. Otherwise, or if the required
# methods cannot be reflected on, it returns code that recurses into
# `_pullback(cx, getproperty, x, f)`.
@generated function _pullback(cx::AContext, ::typeof(literal_getproperty), x, ::Val{f}) where f
    sig(T) = Tuple{T, typeof(f)}
    rrule_sig(T) = Tuple{typeof(getproperty), T, typeof(f)}
    pb_sig(T) = Tuple{cx, typeof(getproperty), T, typeof(f)}

    # nothing to optimize here, need to recurse into `getproperty`
    recurse = quote
        Base.@_inline_meta
        _pullback(cx, getproperty, x, $(QuoteNode(f)))
    end

    # either `getproperty` has a custom implementation or `_pullback(cx, getproperty, x, f)`
    # / `rrule(getproperty, x, f)` is overloaded directly
    is_getfield_fallback = which(getproperty, sig(x)) == which(getproperty, sig(Any)) &&
        which(_pullback, pb_sig(x)) == which(_pullback, pb_sig(Any)) &&
        which(rrule, rrule_sig(x)) == which(rrule, rrule_sig(Any))

    is_getfield_fallback || return recurse

    getfield_pb = reflect((typeof(_pullback), cx, typeof(literal_getfield), x, Val{f}))
    getproperty_pb = reflect((typeof(_pullback), pb_sig(x).parameters...))
    getproperty_impl = reflect((typeof(getproperty), sig(x).parameters...))
    getproperty_rrule = reflect((typeof(rrule), rrule_sig(x).parameters...))

    # `reflect` returns `nothing` when it cannot find an exact match. The generic path is
    # always valid, so use it rather than failing while destructuring `nothing`.
    if getfield_pb === nothing || getproperty_pb === nothing ||
       getproperty_impl === nothing || getproperty_rrule === nothing
        return recurse
    end

    # just copy pullback of `literal_getfield`
    mi, _sig, sparams = getfield_pb
    ci = copy(Core.Compiler.retrieve_code_info(mi))

    # The copied code was written for a method whose function argument is
    # `literal_getfield`, whereas this method receives `literal_getproperty` in that
    # position. Substitute the leading slots accordingly: the method itself becomes
    # `_pullback`, the context slot is kept, and the function argument becomes the constant
    # `literal_getfield`.
    Meta.partially_inline!(
        ci.code, Any[_pullback, Core.SlotNumber(2), literal_getfield],
        _sig, sparams, 0, 0, :propagate,
    )
    ci.inlineable = true

    # backedges, see https://docs.julialang.org/en/v1/devdocs/ast/#MethodInstance
    # Attaching these MethodInstances as edges ties this generated code to
    # `_pullback(cx, getproperty, x, f)`, `getproperty(x, f)` and
    # `rrule(getproperty, x, f)`, so that adding a custom method for any of them later
    # invalidates the fast path chosen here.
    ci.edges = Core.MethodInstance[
        mi, getproperty_pb[1], getproperty_impl[1], getproperty_rrule[1],
    ]

    return ci
end
47 changes: 47 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,52 @@ end
@test Zygote.gradient(sumall, ms) == ((a = 2, b = 2),)
end

using ChainRulesCore

# Define a brand-new two-field parametric struct (fields `m` and `P`) in the current module
# and return the type. The struct name is produced by `gensym` from `Gaussian_<suffix>`, so
# every call yields a distinct type.
function _Gaussian(suffix::Symbol)
    T = gensym(Symbol(:Gaussian_, suffix))
    definition = quote
        struct $T{Tm, TP}
            m::Tm
            P::TP
        end
        $T
    end
    return Core.eval(@__MODULE__, definition)
end

# Each scenario uses a fresh struct type from `_Gaussian`, so the methods defined for one
# scenario do not affect the others. Every scenario first differentiates `x -> x.m` with no
# customization and requires that call to be inferred, then adds a customization for that
# type and checks that the next `pullback` call picks it up.
@testset "inference for `getproperty`" begin
    # Scenario 1: custom `getproperty` method.
    Gaussian = _Gaussian(:getproperty)
    g = Gaussian(randn(3), randn(3, 3))
    y, back = @inferred pullback(x -> x.m, g)
    @test y == getfield(g, :m)
    @test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
    @test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)

    Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
    y, back = pullback(x -> x.m, g)
    @test y == 2getfield(g, :m)
    @test back([1., 0, 0]) == ((m = [2.0, 0.0, 0.0], P = nothing),)


    # Scenario 2: custom `Zygote._pullback` for `getproperty`.
    Gaussian = _Gaussian(:pullback)
    g = Gaussian(randn(3), randn(3, 3))
    y, back = @inferred pullback(x -> x.m, g)

    # `_pullback` is called with the AD context as its first argument, so the overload has
    # to accept it; without the context argument this method is never selected.
    Zygote._pullback(::Zygote.AContext, ::typeof(getproperty), g::Gaussian, s::Symbol) =
        3getfield(g, s), Δ -> (nothing, (; ((:m, :P) .=> nothing)..., s => 3Δ), nothing)
    y, back = pullback(x -> x.m, g)
    @test y == 3getfield(g, :m)
    @test back([1., 0, 0]) == ((m = [3.0, 0.0, 0.0], P = nothing),)


    # Scenario 3: custom `ChainRulesCore.rrule` for `getproperty`.
    Gaussian = _Gaussian(:rrule)
    g = Gaussian(randn(3), randn(3, 3))
    y, back = @inferred pullback(x -> x.m, g)

    ChainRulesCore.rrule(::typeof(getproperty), g::Gaussian, s::Symbol) =
        4getfield(g, s), Δ -> (NoTangent(), Tangent{typeof(g)}(; s => 4Δ), NoTangent())
    y, back = pullback(x -> x.m, g)
    @test y == 4getfield(g, :m)
    @test back([1., 0, 0]) == ((m = [4.0, 0.0, 0.0], P = nothing),)
end

# issue 897
# Every column of `ones(3, 400)` has norm √3, so each gradient entry is 1/√3; compare
# approximately since the result is floating point.
@test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] ≈ fill(0.5773502691896258, 3, 400)

0 comments on commit a3f8dc4

Please sign in to comment.