Skip to content

Commit 6e90c55

Browse files
committed
Invalidate methods when binding is typed/const-defined
This allows for patterns like: ``` julia> function foo(N) for i = 1:N x = bar(i) end end julia> foo(1_000_000_000) ERROR: UndefVarError: `bar` not defined ``` not to suffer a tremendous performance regression because of the fact that `foo` was inferred with `bar` still undefined. Strictly speaking the original code remains valid, but for performance reasons once the global is defined we'd like to invalidate the code anyway to get an improved inference result. ``` julia> bar(x) = 3x bar (generic function with 1 method) julia> foo(1_000_000_000) # w/o PR: takes > 30 seconds ```
1 parent 9477472 commit 6e90c55

17 files changed

+205
-27
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2602,6 +2602,7 @@ function abstract_eval_isdefined(interp::AbstractInterpreter, e::Expr, vtypes::U
26022602
elseif isdefinedconst_globalref(sym)
26032603
rt = Const(true)
26042604
else
2605+
add_binding_backedge!(sv, sym, :const)
26052606
effects = Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE)
26062607
end
26072608
elseif isexpr(sym, :static_parameter)
@@ -2822,18 +2823,21 @@ end
28222823
isdefined_globalref(g::GlobalRef) = !iszero(ccall(:jl_globalref_boundp, Cint, (Any,), g))
28232824
isdefinedconst_globalref(g::GlobalRef) = isconst(g) && isdefined_globalref(g)
28242825

2825-
function abstract_eval_globalref_type(g::GlobalRef)
2826+
function abstract_eval_globalref_type(g::GlobalRef, sv::Union{AbsIntState,Nothing}=nothing)
28262827
if isdefinedconst_globalref(g)
28272828
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
28282829
end
28292830
ty = ccall(:jl_get_binding_type, Any, (Any, Any), g.mod, g.name)
2830-
ty === nothing && return Any
2831+
if ty === nothing
2832+
sv !== nothing && add_binding_backedge!(sv, g, :type)
2833+
return Any
2834+
end
28312835
return ty
28322836
end
2833-
abstract_eval_global(M::Module, s::Symbol) = abstract_eval_globalref_type(GlobalRef(M, s))
2837+
abstract_eval_global(M::Module, s::Symbol, sv::Union{AbsIntState,Nothing}=nothing) = abstract_eval_globalref_type(GlobalRef(M, s), sv)
28342838

28352839
function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::AbsIntState)
2836-
rt = abstract_eval_globalref_type(g)
2840+
rt = abstract_eval_globalref_type(g, sv)
28372841
consistent = inaccessiblememonly = ALWAYS_FALSE
28382842
nothrow = false
28392843
if isa(rt, Const)

base/compiler/inferencestate.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,14 @@ function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospeci
10381038
return push!(irsv.edges, mt, typ)
10391039
end
10401040

1041+
function add_binding_backedge!(caller::InferenceState, g::GlobalRef, kind::Symbol)
1042+
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
1043+
return push!(get_stmt_edges!(caller), g, kind)
1044+
end
1045+
function add_binding_backedge!(irsv::IRInterpretationState, g::GlobalRef)
1046+
return push!(irsv.edges, g, kind)
1047+
end
1048+
10411049
get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]
10421050
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag]
10431051

base/compiler/typeinfer.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,8 @@ function store_backedges(caller::MethodInstance, edges::Vector{Any})
641641
callee = itr.caller
642642
if isa(callee, MethodInstance)
643643
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
644+
elseif isa(callee, GlobalRef)
645+
ccall(:jl_globalref_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
644646
else
645647
typeassert(callee, MethodTable)
646648
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)

base/compiler/utilities.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,16 +336,17 @@ end
336336
const empty_backedge_iter = BackedgeIterator(Any[])
337337

338338
struct BackedgePair
339-
sig # ::Union{Nothing,Type}
340-
caller::Union{MethodInstance,MethodTable}
341-
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable}) = new(sig, caller)
339+
sig # ::Union{Nothing,Symbol,Type}
340+
caller::Union{MethodInstance,MethodTable,GlobalRef}
341+
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable,GlobalRef}) = new(sig, caller)
342342
end
343343

344344
function iterate(iter::BackedgeIterator, i::Int=1)
345345
backedges = iter.backedges
346346
i > length(backedges) && return nothing
347347
item = backedges[i]
348348
isa(item, MethodInstance) && return BackedgePair(nothing, item), i+1 # regular dispatch
349+
isa(item, GlobalRef) && return BackedgePair(backedges[i+1], item), i+2 # (untyped) binding
349350
isa(item, MethodTable) && return BackedgePair(backedges[i+1], item), i+2 # abstract dispatch
350351
return BackedgePair(item, backedges[i+1]::MethodInstance), i+2 # `invoke` calls
351352
end

src/builtins.c

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,7 +1378,10 @@ JL_CALLABLE(jl_f_get_binding_type)
13781378
if (b2 != b)
13791379
return (jl_value_t*)jl_any_type;
13801380
jl_value_t *old_ty = NULL;
1381-
jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type);
1381+
while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type)) {
1382+
if (old_ty && !jl_is_binding_edges(old_ty))
1383+
break;
1384+
}
13821385
return jl_atomic_load_relaxed(&b->ty);
13831386
}
13841387
return ty;
@@ -1395,8 +1398,15 @@ JL_CALLABLE(jl_f_set_binding_type)
13951398
JL_TYPECHK(set_binding_type!, type, ty);
13961399
jl_binding_t *b = jl_get_binding_wr(m, s);
13971400
jl_value_t *old_ty = NULL;
1398-
if (jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, ty)) {
1401+
while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, ty)) {
1402+
if (old_ty && !jl_is_binding_edges(old_ty))
1403+
break;
1404+
}
1405+
if (!old_ty)
1406+
jl_gc_wb(b, ty);
1407+
else if (jl_is_binding_edges(old_ty)) {
13991408
jl_gc_wb(b, ty);
1409+
jl_binding_invalidate(ty, /* is_const */ 0, (jl_binding_edges_t *)old_ty);
14001410
}
14011411
else if (nargs != 2 && !jl_types_equal(ty, old_ty)) {
14021412
jl_errorf("cannot set type for global %s.%s. It already has a value or is already set to a different type.",
@@ -2525,6 +2535,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
25252535
add_builtin("QuoteNode", (jl_value_t*)jl_quotenode_type);
25262536
add_builtin("NewvarNode", (jl_value_t*)jl_newvarnode_type);
25272537
add_builtin("Binding", (jl_value_t*)jl_binding_type);
2538+
add_builtin("BindingEdges", (jl_value_t*)jl_binding_edges_type);
25282539
add_builtin("GlobalRef", (jl_value_t*)jl_globalref_type);
25292540
add_builtin("NamedTuple", (jl_value_t*)jl_namedtuple_type);
25302541

src/codegen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3201,7 +3201,7 @@ static jl_cgval_t emit_globalref(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *
32013201
return mark_julia_const(ctx, v);
32023202
ty = jl_atomic_load_relaxed(&bnd->ty);
32033203
}
3204-
if (ty == nullptr)
3204+
if (ty == nullptr || jl_is_binding_edges(ty))
32053205
ty = (jl_value_t*)jl_any_type;
32063206
return update_julia_type(ctx, emit_checked_var(ctx, bp, name, (jl_value_t*)mod, false, ctx.tbaa().tbaa_binding), ty);
32073207
}
@@ -3217,7 +3217,7 @@ static jl_cgval_t emit_globalop(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *s
32173217
return jl_cgval_t();
32183218
if (bnd && !bnd->constp) {
32193219
jl_value_t *ty = jl_atomic_load_relaxed(&bnd->ty);
3220-
if (ty != nullptr) {
3220+
if (ty != nullptr && !jl_is_binding_edges(ty)) {
32213221
const std::string fname = issetglobal ? "setglobal!" : isreplaceglobal ? "replaceglobal!" : isswapglobal ? "swapglobal!" : ismodifyglobal ? "modifyglobal!" : "setglobalonce!";
32223222
if (!ismodifyglobal) {
32233223
// TODO: use typeassert in jl_check_binding_wr too

src/gf.c

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,6 +1747,69 @@ static void invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_w
17471747
}
17481748
}
17491749

1750+
/**
1751+
* Invalidate the edges accumulated in `be` - this should be called when a binding has just
1752+
* acquired a type or a const value.
1753+
*
1754+
* ty is the new type of the binding (optional if const), and `is_const` is whether the new
1755+
* binding ended up being const. These will be used to filter the edge invalidations, so that
1756+
* e.g. an `isdefined` edge is not invalidated by a non-const binding
1757+
**/
1758+
JL_DLLEXPORT void jl_binding_invalidate(jl_value_t *ty, int is_const, jl_binding_edges_t *be)
1759+
{
1760+
if (!is_const && ty == (jl_value_t *)jl_any_type)
1761+
return; // no improvement to inference information
1762+
1763+
jl_array_t *edges = be->edges;
1764+
jl_method_instance_t *mi = NULL;
1765+
JL_GC_PUSH2(&edges, mi);
1766+
JL_LOCK(&world_counter_lock);
1767+
// Narrow the world age on the methods to make them uncallable
1768+
size_t world = jl_atomic_load_relaxed(&jl_world_counter);
1769+
for (int i = 0; i < jl_array_len(edges) / 2; i++) {
1770+
mi = (jl_method_instance_t *)jl_array_ptr_ref(edges, 2 * i);
1771+
jl_sym_t *kind = (jl_sym_t *)jl_array_ptr_ref(edges, 2 * i + 1);
1772+
if (!is_const && kind == jl_symbol("const"))
1773+
continue; // this is an `isdefined` edge, which has not improved
1774+
1775+
invalidate_method_instance(mi, world, /* depth */ 0);
1776+
}
1777+
jl_atomic_store_release(&jl_world_counter, world + 1);
1778+
JL_UNLOCK(&world_counter_lock);
1779+
JL_GC_POP();
1780+
}
1781+
1782+
JL_DLLEXPORT void jl_globalref_add_backedge(jl_globalref_t *callee, jl_sym_t *kind, jl_method_instance_t *caller)
1783+
{
1784+
jl_binding_t *b = jl_get_module_binding(callee->mod, callee->name, /* alloc */ 0);
1785+
assert(b != NULL);
1786+
jl_binding_edges_t *edges = (jl_binding_edges_t *)jl_atomic_load_acquire(&b->ty);
1787+
if (edges && !jl_is_binding_edges(edges))
1788+
return; // TODO: Handle case where the invalidation happens before the edge arrives
1789+
1790+
jl_array_t *array = NULL;
1791+
JL_GC_PUSH2(&array, &edges);
1792+
if (edges == NULL) {
1793+
array = jl_alloc_vec_any(0);
1794+
edges = (jl_binding_edges_t *)jl_gc_alloc(
1795+
jl_current_task->ptls, sizeof(jl_binding_edges_t),
1796+
jl_binding_edges_type
1797+
);
1798+
edges->edges = array;
1799+
jl_value_t *old_ty = NULL;
1800+
if (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t *)edges))
1801+
return; // TODO: Handle case where ty was swapped out from under us
1802+
jl_gc_wb(b, edges);
1803+
}
1804+
else {
1805+
array = edges->edges;
1806+
}
1807+
jl_array_ptr_1d_push(array, (jl_value_t *)caller);
1808+
jl_array_ptr_1d_push(array, (jl_value_t *)kind);
1809+
JL_GC_POP();
1810+
return;
1811+
}
1812+
17501813
// add a backedge from callee to caller
17511814
JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee, jl_value_t *invokesig, jl_method_instance_t *caller)
17521815
{

src/jl_exported_data.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
XX(jl_floatingpoint_type) \
5252
XX(jl_function_type) \
5353
XX(jl_binding_type) \
54+
XX(jl_binding_edges_type) \
5455
XX(jl_globalref_type) \
5556
XX(jl_gotoifnot_type) \
5657
XX(jl_enternode_type) \

src/jl_exported_funcs.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
XX(jl_backtrace_from_here) \
4444
XX(jl_base_relative_to) \
4545
XX(jl_binding_resolved_p) \
46+
XX(jl_binding_invalidate) \
4647
XX(jl_bitcast) \
4748
XX(jl_boundp) \
4849
XX(jl_bounds_error) \
@@ -237,6 +238,7 @@
237238
XX(jl_get_world_counter) \
238239
XX(jl_get_zero_subnormals) \
239240
XX(jl_gf_invoke_lookup) \
241+
XX(jl_globalref_add_backedge) \
240242
XX(jl_method_lookup_by_tt) \
241243
XX(jl_method_lookup) \
242244
XX(jl_gf_invoke_lookup_worlds) \

src/jltypes.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3108,6 +3108,11 @@ void jl_init_types(void) JL_GC_DISABLED
31083108
const static uint32_t binding_constfields[] = { 0x0002 }; // Set fields 2 as constant
31093109
jl_binding_type->name->constfields = binding_constfields;
31103110

3111+
jl_binding_edges_type =
3112+
jl_new_datatype(jl_symbol("BindingBackedges"), core, jl_any_type, jl_emptysvec,
3113+
jl_perm_symsvec(1, "edges"), jl_svec(1, jl_any_type),
3114+
jl_emptysvec, 0, 0, 1);
3115+
31113116
jl_globalref_type =
31123117
jl_new_datatype(jl_symbol("GlobalRef"), core, jl_any_type, jl_emptysvec,
31133118
jl_perm_symsvec(3, "mod", "name", "binding"),

src/julia.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,11 @@ typedef struct _jl_binding_t {
642642
uint8_t padding:1;
643643
} jl_binding_t;
644644

645+
typedef struct _jl_binding_edges_t {
646+
JL_DATA_TYPE
647+
jl_array_t *edges;
648+
} jl_binding_edges_t;
649+
645650
typedef struct {
646651
uint64_t hi;
647652
uint64_t lo;
@@ -930,6 +935,7 @@ extern JL_DLLIMPORT jl_value_t *jl_memoryref_uint8_type JL_GLOBALLY_ROOTED;
930935
extern JL_DLLIMPORT jl_value_t *jl_memoryref_any_type JL_GLOBALLY_ROOTED;
931936
extern JL_DLLIMPORT jl_datatype_t *jl_expr_type JL_GLOBALLY_ROOTED;
932937
extern JL_DLLIMPORT jl_datatype_t *jl_binding_type JL_GLOBALLY_ROOTED;
938+
extern JL_DLLIMPORT jl_datatype_t *jl_binding_edges_type JL_GLOBALLY_ROOTED;
933939
extern JL_DLLIMPORT jl_datatype_t *jl_globalref_type JL_GLOBALLY_ROOTED;
934940
extern JL_DLLIMPORT jl_datatype_t *jl_linenumbernode_type JL_GLOBALLY_ROOTED;
935941
extern JL_DLLIMPORT jl_datatype_t *jl_gotonode_type JL_GLOBALLY_ROOTED;
@@ -1503,6 +1509,7 @@ static inline int jl_field_isconst(jl_datatype_t *st, int i) JL_NOTSAFEPOINT
15031509
#define jl_is_slotnumber(v) jl_typetagis(v,jl_slotnumber_type)
15041510
#define jl_is_expr(v) jl_typetagis(v,jl_expr_type)
15051511
#define jl_is_binding(v) jl_typetagis(v,jl_binding_type)
1512+
#define jl_is_binding_edges(v) jl_typetagis(v,jl_binding_edges_type)
15061513
#define jl_is_globalref(v) jl_typetagis(v,jl_globalref_type)
15071514
#define jl_is_gotonode(v) jl_typetagis(v,jl_gotonode_type)
15081515
#define jl_is_gotoifnot(v) jl_typetagis(v,jl_gotoifnot_type)

src/julia_internal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,7 @@ JL_DLLEXPORT jl_value_t *jl_nth_slot_type(jl_value_t *sig JL_PROPAGATES_ROOT, si
835835
void jl_compute_field_offsets(jl_datatype_t *st);
836836
void jl_module_run_initializer(jl_module_t *m);
837837
JL_DLLEXPORT jl_binding_t *jl_get_module_binding(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *var, int alloc);
838+
JL_DLLEXPORT void jl_binding_invalidate(jl_value_t *ty, int is_const, jl_binding_edges_t *be);
838839
JL_DLLEXPORT void jl_binding_deprecation_warning(jl_module_t *m, jl_sym_t *sym, jl_binding_t *b);
839840
extern jl_array_t *jl_module_init_order JL_GLOBALLY_ROOTED;
840841
extern htable_t jl_current_modules JL_GLOBALLY_ROOTED;
@@ -1041,6 +1042,7 @@ JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt JL_PROPAGATES_RO
10411042
JL_DLLEXPORT jl_method_instance_t *jl_specializations_get_linfo(
10421043
jl_method_t *m JL_PROPAGATES_ROOT, jl_value_t *type, jl_svec_t *sparams);
10431044
jl_method_instance_t *jl_specializations_get_or_insert(jl_method_instance_t *mi_ins);
1045+
JL_DLLEXPORT void jl_globalref_add_backedge(jl_globalref_t *callee, jl_sym_t *kind, jl_method_instance_t *caller);
10441046
JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee, jl_value_t *invokesig, jl_method_instance_t *caller);
10451047
JL_DLLEXPORT void jl_method_table_add_backedge(jl_methtable_t *mt, jl_value_t *typ, jl_value_t *caller);
10461048
JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,

src/method.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,13 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name,
11131113
if (gf != NULL) {
11141114
if (!jl_is_datatype_singleton((jl_datatype_t*)jl_typeof(gf)) && !jl_is_type(gf))
11151115
jl_errorf("cannot define function %s; it already has a value", jl_symbol_name(name));
1116+
} else if (bnd) {
1117+
jl_value_t *old_ty = NULL;
1118+
while (!jl_atomic_cmpswap_relaxed(&bnd->ty, &old_ty, (jl_value_t*)jl_any_type)) {
1119+
assert(!old_ty || jl_is_binding_edges(old_ty));
1120+
}
1121+
if (old_ty)
1122+
jl_binding_invalidate((jl_value_t *)jl_any_type, /* is_const */ 1, (jl_binding_edges_t *)old_ty);
11161123
}
11171124
if (bnd)
11181125
bnd->constp = 1; // XXX: use jl_declare_constant and jl_checked_assignment

0 commit comments

Comments
 (0)