Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve inference for getproperty #909

Merged
merged 4 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ include("compiler/show.jl")

include("lib/grad.jl")
include("lib/lib.jl")
include("lib/literal_getproperty.jl")
include("lib/number.jl")
include("lib/base.jl")
include("lib/array.jl")
Expand Down
82 changes: 82 additions & 0 deletions src/lib/literal_getproperty.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Adapted (mostly verbatim) from Cassette's `src/overdub.jl`.
#
# Look up the method whose match signature is identical to the call signature `sigtypes`
# (callee type first, then the argument types) in world `world`. Returns the tuple
# `(method_instance, method_signature, static_params)`, or `nothing` when the callee type
# is defined in `Core.Compiler` or when no match has exactly that signature.
function reflect(@nospecialize(sigtypes::Tuple), world::UInt = typemax(UInt))
    # A signature of the form `invoke(f, Tuple{T...}, args...)` is replaced by the
    # signature `(f, T...)` of the call that `invoke` performs.
    if length(sigtypes) > 2 && sigtypes[1] === typeof(invoke)
        @assert sigtypes[3] <: Type{<:Tuple}
        sigtypes = (sigtypes[2], sigtypes[3].parameters[1].parameters...)
    end
    # Workaround for a subtyping bug: callers can take upstream `UnionAll` types apart in
    # a way that leaves a type with free type variables, and subtyping can break on those.
    # Such entries are therefore replaced by `typeof` of their first parameter.
    #
    # Do not try to express this with a method type parameter instead (e.g.
    # `::Type{S} where S<:Tuple`): the compiler may rewrite `S` into anything it considers
    # "type equal" to the value that was passed, so for `f(::Type{S}) where S` a call
    # `f(T)` does NOT guarantee `S === T`.
    cleaned = map(sigtypes) do t
        Core.Compiler.has_free_typevars(t) ? typeof(t.parameters[1]) : t
    end
    S = Tuple{cleaned...}
    callee_type = S.parameters[1]::DataType
    if callee_type.name.module === Core.Compiler
        return nothing
    end
    candidates = Base._methods_by_ftype(S, -1, world)
    # Only a candidate whose matched signature is exactly `S` is accepted.
    hit = findfirst(i -> candidates[i][1] === S, 1:length(candidates))
    hit === nothing && return nothing
    type_signature, raw_static_params, method = candidates[hit]
    method_instance = Core.Compiler.specialize_method(
        method, type_signature, raw_static_params, false,
    )
    return method_instance, method.sig, Any[raw_static_params...]
end


# ugly hack to make differentiating `getproperty` infer a lot better
#
# Generated `_pullback` method for calls of `literal_getproperty` with a `Val{f}` argument.
# Inside the generator body `cx` and `x` stand for the argument *types*, and `f` is the
# static parameter taken from `Val{f}`. If none of `getproperty`, `_pullback` and `rrule`
# has a method for `x` that differs from the one selected for `Any`, the code of the
# `literal_getfield` pullback is reused; otherwise the call is forwarded to
# `_pullback(cx, getproperty, x, f)`.
@generated function _pullback(cx::AContext, ::typeof(literal_getproperty), x, ::Val{f}) where f
# argument-type tuples for the three `which` lookups below
sig(x) = Tuple{x, typeof(f)}
rrule_sig(x) = Tuple{typeof(getproperty), x, typeof(f)}
pb_sig(x) = Tuple{cx, typeof(getproperty), x, typeof(f)}

# `is_getfield_fallback` is false when either `getproperty` has a custom implementation
# or `_pullback(cx, getproperty, x, f)` / `rrule(getproperty, x, f)` is overloaded
# directly; each check compares the method found for `x` with the one found for `Any`
is_getfield_fallback = which(getproperty, sig(x)) == which(getproperty, sig(Any)) &&
which(_pullback, pb_sig(x)) == which(_pullback, pb_sig(Any)) &&
which(rrule, rrule_sig(x)) == which(rrule, rrule_sig(Any))

#ccall(:jl_safe_printf, Cvoid, (Cstring,), "$is_getfield_fallback: $x\n")

if is_getfield_fallback
# just copy pullback of `literal_getfield`
# NOTE(review): `reflect` returns `nothing` when it finds no exact match, in which case
# this destructuring would throw -- confirm a match always exists here
mi, _sig, sparams = reflect((typeof(_pullback), cx, typeof(literal_getfield), x, Val{f}))
ci = copy(Core.Compiler.retrieve_code_info(mi))

# we need to change the second arg to `_pullback` from `literal_getproperty` to
# `literal_getfield`
Meta.partially_inline!(
ci.code, Any[_pullback, Core.SlotNumber(2), literal_getfield],
_sig, sparams, 0, 0, :propagate,
)
ci.inlineable = true

# backedge for `_pullback`, see https://docs.julialang.org/en/v1/devdocs/ast/#MethodInstance
# this will cause a backedge to this particular MethodInstance to be attached to
# `_pullback(cx, getproperty, x, f)`
# (the instances for `getproperty` and `rrule` are added to `ci.edges` as well)
mi_pb_getproperty, _, _ = reflect((typeof(_pullback), pb_sig(x).parameters...))
mi_getproperty, _, _ = reflect((typeof(getproperty), sig(x).parameters...))
mi_rrule, _, _ = reflect((typeof(rrule), rrule_sig(x).parameters...))
ci.edges = Core.MethodInstance[mi, mi_pb_getproperty, mi_getproperty, mi_rrule]

return ci
else
# nothing to optimize here, need to recurse into `getproperty`
# (`f` is spliced into the returned expression as a quoted literal)
return quote
Base.@_inline_meta
_pullback(cx, getproperty, x, $(QuoteNode(f)))
end
end
end
47 changes: 47 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,52 @@ end
@test Zygote.gradient(sumall, ms) == ((a = 2, b = 2),)
end

using ChainRulesCore

# Define a fresh two-field parametric struct (fields `m` and `P`) under a `gensym`-ed name
# derived from `Gaussian_<suffix>`, evaluated in this module, and return the new type.
# Every call creates a distinct type, so methods attached to one do not affect another.
function _Gaussian(suffix::Symbol)
    typename = gensym(Symbol(:Gaussian_, suffix))
    definition = quote
        struct $typename{Tm, TP}
            m::Tm
            P::TP
        end
        $typename
    end
    return Core.eval(@__MODULE__, definition)
end

# Checks the generated `_pullback` method for `literal_getproperty`: property access on a
# plain struct must be inferrable, and overloads of `getproperty`, `_pullback` or `rrule`
# that are defined *after* a first differentiation must still be picked up afterwards.
# Each scenario uses its own type from `_Gaussian` so the overloads do not interact.
@testset "inference for `getproperty`" begin
    # 1. `Base.getproperty` overloaded after the first pullback was taken
    Gaussian = _Gaussian(:getproperty)
    g = Gaussian(randn(3), randn(3, 3))
    y, back = @inferred pullback(x -> x.m, g)
    @test y == getfield(g, :m)
    @test Base.return_types(back, Tuple{Vector{Float64}}) ==
        Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
    @test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)

    Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
    y, back = pullback(x -> x.m, g)
    @test y == 2getfield(g, :m)
    @test back([1., 0, 0]) == ((m = [2.0, 0.0, 0.0], P = nothing),)


    # 2. `Zygote._pullback` overloaded directly
    Gaussian = _Gaussian(:pullback)
    g = Gaussian(randn(3), randn(3, 3))
    y, back = @inferred pullback(x -> x.m, g)

    # NOTE(review): this overload takes no context argument, whereas the generated method
    # looks up `_pullback(cx, getproperty, x, f)` -- possibly the reason the two
    # assertions below are marked broken; confirm
    Zygote._pullback(::typeof(getproperty), g::Gaussian, s::Symbol) =
        3getfield(g, s), Δ -> (nothing, (; ((:m, :P) .=> nothing)..., s => 3Δ), nothing)
    y, back = pullback(x -> x.m, g)
    @test_broken y == 3getfield(g, :m)
    @test_broken back([1., 0, 0]) == ((m = [3.0, 0.0, 0.0], P = nothing),)


    # 3. `ChainRulesCore.rrule` overloaded
    Gaussian = _Gaussian(:rrule)
    g = Gaussian(randn(3), randn(3, 3))
    y, back = @inferred pullback(x -> x.m, g)

    ChainRulesCore.rrule(::typeof(getproperty), g::Gaussian, s::Symbol) =
        4getfield(g, s), Δ -> (NoTangent(), Tangent{typeof(g)}(; s => 4Δ), NoTangent())
    y, back = pullback(x -> x.m, g)
    @test y == 4getfield(g, :m)
    @test back([1., 0, 0]) == ((m = [4.0, 0.0, 0.0], P = nothing),)
end

# issue 897: gradient of the summed column norms of an all-ones 3×400 matrix
let input = ones(3, 400)
    grad = gradient(m -> sum(norm, collect(eachcol(m))), input)[1]
    @test grad ≈ fill(0.5773502691896258, 3, 400)
end