Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

invoked calls: record invoke signature in backedges #46010

Merged
merged 11 commits into from
Aug 24, 2022
Prev Previous commit
Teach jl_method_table_insert about invoke backedges
  • Loading branch information
timholy committed Aug 24, 2022
commit 8a9a87c75b5e75092226280c49f2316aced4966f
38 changes: 34 additions & 4 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1851,11 +1851,41 @@ JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method
if (k != n)
continue;
}
jl_array_ptr_1d_push(oldmi, (jl_value_t*)mi);
invalidate_external(mi, max_world);
// Before deciding whether to invalidate `mi`, check each backedge for `invoke`s
if (mi->backedges) {
invalidated = 1;
invalidate_backedges(&do_nothing_with_codeinst, mi, max_world, "jl_method_table_insert");
jl_array_t *backedges = mi->backedges;
size_t ib = 0, insb = 0, nb = jl_array_len(backedges);
jl_value_t *invokeTypes;
jl_method_instance_t *caller;
while (ib < nb) {
ib = get_next_edge(backedges, ib, &invokeTypes, &caller);
if (!invokeTypes) {
// ordinary dispatch, invalidate
invalidate_method_instance(&do_nothing_with_codeinst, caller, max_world, 1);
invalidated = 1;
} else {
// invoke-dispatch, check invokeTypes for validity
struct jl_typemap_assoc search = {invokeTypes, method->primary_world, NULL, 0, ~(size_t)0};
oldentry = jl_typemap_assoc_by_type(jl_atomic_load_relaxed(&mt->defs), &search, /*offs*/0, /*subtype*/0);
assert(oldentry);
if (oldentry->func.method == mi->def.method) {
jl_array_ptr_set(backedges, insb++, invokeTypes);
jl_array_ptr_set(backedges, insb++, caller);
continue;
}
invalidate_method_instance(&do_nothing_with_codeinst, caller, max_world, 1);
invalidated = 1;
}
}
jl_array_del_end(backedges, nb - insb);
}
if (!mi->backedges || jl_array_len(mi->backedges) == 0) {
jl_array_ptr_1d_push(oldmi, (jl_value_t*)mi);
invalidate_external(mi, max_world);
if (mi->backedges) {
invalidated = 1;
invalidate_backedges(&do_nothing_with_codeinst, mi, max_world, "jl_method_table_insert");
}
}
}
}
Expand Down
52 changes: 46 additions & 6 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ precompile_test_harness("invoke") do dir
write(joinpath(dir, "$InvokeModule.jl"),
"""
module $InvokeModule
export f, g, h, fnc, gnc, hnc # nc variants do not infer to a Const
export f, g, h, q, fnc, gnc, hnc, qnc # nc variants do not infer to a Const
# f is for testing invoke that occurs within a dependency
f(x::Real) = 0
f(x::Int) = x < 5 ? 1 : invoke(f, Tuple{Real}, x)
Expand All @@ -905,6 +905,9 @@ precompile_test_harness("invoke") do dir
h(x::Int) = x < 5 ? 1 : invoke(h, Tuple{Integer}, x)
hnc(x::Real) = rand()-1
hnc(x::Int) = x < 5 ? rand()+1 : invoke(hnc, Tuple{Integer}, x)
# q will have some callers invalidated
q(x::Integer) = 0
qnc(x::Integer) = rand()-1
end
""")
write(joinpath(dir, "$CallerModule.jl"),
Expand All @@ -915,9 +918,13 @@ precompile_test_harness("invoke") do dir
callf(x) = f(x)
callg(x) = x < 5 ? g(x) : invoke(g, Tuple{Real}, x)
callh(x) = h(x)
callq(x) = q(x)
callqi(x) = invoke(q, Tuple{Integer}, x)
callfnc(x) = fnc(x)
callgnc(x) = x < 5 ? gnc(x) : invoke(gnc, Tuple{Real}, x)
callhnc(x) = hnc(x)
callqnc(x) = qnc(x)
callqnci(x) = invoke(qnc, Tuple{Integer}, x)

# Purely internal
internal(x::Real) = 0
Expand All @@ -931,42 +938,75 @@ precompile_test_harness("invoke") do dir
callf(3)
callg(3)
callh(3)
callq(3)
callqi(3)
callfnc(3)
callgnc(3)
callhnc(3)
callqnc(3)
callqnci(3)
internal(3)
internalnc(3)
end

# Now that we've precompiled, invalidate with a new method that overrides the `invoke` dispatch
$InvokeModule.h(x::Integer) = -1
$InvokeModule.hnc(x::Integer) = rand() - 20
# ...and for q, override with a more specialized method that should leave only the invoked version still valid
$InvokeModule.q(x::Int) = -1
$InvokeModule.qnc(x::Int) = rand()+1
end
""")
Base.compilecache(Base.PkgId(string(CallerModule)))
@eval using $CallerModule
M = getfield(@__MODULE__, CallerModule)

function get_real_method(func) # return the method func(::Real)
function get_method_for_type(func, @nospecialize(T)) # return the method func(::T)
for m in methods(func)
m.sig.parameters[end] === Real && return m
m.sig.parameters[end] === T && return m
end
error("no ::Real method found for $func")
end
function nvalid(mi::Core.MethodInstance)
isdefined(mi, :cache) || return 0
ci = mi.cache
n = Int(ci.max_world == typemax(UInt))
while isdefined(ci, :next)
ci = ci.next
n += ci.max_world == typemax(UInt)
end
return n
end

for func in (M.f, M.g, M.internal, M.fnc, M.gnc, M.internalnc)
m = get_real_method(func)
m = get_method_for_type(func, Real)
mi = m.specializations[1]
@test length(mi.backedges) == 2
@test mi.backedges[1] === Tuple{typeof(func), Real}
@test isa(mi.backedges[2], Core.MethodInstance)
@test mi.cache.max_world == typemax(mi.cache.max_world)
end
for func in (M.q, M.qnc)
m = get_method_for_type(func, Integer)
mi = m.specializations[1]
@test length(mi.backedges) == 2
@test mi.backedges[1] === Tuple{typeof(func), Integer}
@test isa(mi.backedges[2], Core.MethodInstance)
@test mi.cache.max_world == typemax(mi.cache.max_world)
end

m = get_real_method(M.h)
m = get_method_for_type(M.h, Real)
@test isempty(m.specializations)
m = get_real_method(M.hnc)
m = get_method_for_type(M.hnc, Real)
@test isempty(m.specializations)
m = only(methods(M.callq))
@test isempty(m.specializations) || nvalid(m.specializations[1]) == 0
m = only(methods(M.callqnc))
@test isempty(m.specializations) || nvalid(m.specializations[1]) == 0
m = only(methods(M.callqi))
@test m.specializations[1].specTypes == Tuple{typeof(M.callqi), Int}
m = only(methods(M.callqnci))
@test m.specializations[1].specTypes == Tuple{typeof(M.callqnci), Int}

# Precompile specific methods for arbitrary arg types
invokeme(x) = 1
Expand Down