Skip to content

Commit ae19646

Browse files
committed
Add basic code for binding partition revalidation
This adds the binding partition revalidation code from #54654. This is the last piece of that PR that hasn't been merged yet - however the TODO in that PR still stands for future work. This PR itself adds a callback that gets triggered by deleting a binding. It will then walk all code in the system and invalidate code instances of Methods whose lowered source referenced the given global. This walk is quite slow. Future work will add backedges and optimizations to make this faster, but the basic functionality should be in place with this PR.
1 parent 582585b commit ae19646

File tree

5 files changed

+154
-2
lines changed

5 files changed

+154
-2
lines changed

base/Base_compiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ include("ordering.jl")
253253
using .Order
254254

255255
include("coreir.jl")
256+
include("invalidation.jl")
256257

257258
# For OS specific stuff
258259
# We need to strcat things here, before strings are really defined

base/invalidation.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
struct GlobalRefIterator
4+
mod::Module
5+
end
6+
IteratorSize(::Type{GlobalRefIterator}) = SizeUnknown()
7+
globalrefs(mod::Module) = GlobalRefIterator(mod)
8+
9+
function iterate(gri::GlobalRefIterator, i = 1)
10+
m = gri.mod
11+
table = ccall(:jl_module_get_bindings, Ref{SimpleVector}, (Any,), m)
12+
i == length(table) && return nothing
13+
b = table[i]
14+
b === nothing && return iterate(gri, i+1)
15+
return ((b::Core.Binding).globalref, i+1)
16+
end
17+
18+
const TYPE_TYPE_MT = Type.body.name.mt
19+
const NONFUNCTION_MT = Core.MethodTable.name.mt
20+
function foreach_module_mtable(visit, m::Module, world::UInt)
21+
for gb in globalrefs(m)
22+
binding = gb.binding
23+
bpart = lookup_binding_partition(world, binding)
24+
if is_some_const_binding(binding_kind(bpart))
25+
isdefined(bpart, :restriction) || continue
26+
v = partition_restriction(bpart)
27+
uw = unwrap_unionall(v)
28+
name = gb.name
29+
if isa(uw, DataType)
30+
tn = uw.name
31+
if tn.module === m && tn.name === name && tn.wrapper === v && isdefined(tn, :mt)
32+
# this is the original/primary binding for the type (name/wrapper)
33+
mt = tn.mt
34+
if mt !== nothing && mt !== TYPE_TYPE_MT && mt !== NONFUNCTION_MT
35+
@assert mt.module === m
36+
visit(mt) || return false
37+
end
38+
end
39+
elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name
40+
# this is the original/primary binding for the submodule
41+
foreach_module_mtable(visit, v, world) || return false
42+
elseif isa(v, Core.MethodTable) && v.module === m && v.name === name
43+
# this is probably an external method table here, so let's
44+
# assume so as there is no way to precisely distinguish them
45+
visit(v) || return false
46+
end
47+
end
48+
end
49+
return true
50+
end
51+
52+
function foreach_reachable_mtable(visit, world::UInt)
53+
visit(TYPE_TYPE_MT) || return
54+
visit(NONFUNCTION_MT) || return
55+
for mod in loaded_modules_array()
56+
foreach_module_mtable(visit, mod, world)
57+
end
58+
end
59+
60+
function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
61+
found_any = false
62+
labelchangemap = nothing
63+
stmts = src.code
64+
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
65+
isgr(g) = false
66+
for i = 1:length(stmts)
67+
stmt = stmts[i]
68+
if isgr(stmt)
69+
found_any = true
70+
continue
71+
end
72+
for ur in Compiler.userefs(stmt)
73+
arg = ur[]
74+
# If any of the GlobalRefs in this stmt match the one that
75+
# we are about, we need to move out all GlobalRefs to preserve
76+
# effect order, in case we later invalidate a different GR
77+
if isa(arg, GlobalRef)
78+
if isgr(arg)
79+
@assert !isa(stmt, PhiNode)
80+
found_any = true
81+
break
82+
end
83+
end
84+
end
85+
end
86+
return found_any
87+
end
88+
89+
function invalidate_code_for_globalref!(gr::GlobalRef, new_max_world::UInt)
90+
valid_in_valuepos = false
91+
foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable
92+
for method in MethodList(mt)
93+
if isdefined(method, :source)
94+
src = _uncompressed_ir(method)
95+
old_stmts = src.code
96+
if should_invalidate_code_for_globalref(gr, src)
97+
for mi in specializations(method)
98+
ci = mi.cache
99+
while true
100+
if ci.max_world > new_max_world
101+
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
102+
end
103+
isdefined(ci, :next) || break
104+
ci = ci.next
105+
end
106+
end
107+
end
108+
end
109+
end
110+
return true
111+
end
112+
end

src/gf.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1785,6 +1785,11 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo
17851785
JL_UNLOCK(&replaced->def->def.method->writelock);
17861786
}
17871787

1788+
JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world)
1789+
{
1790+
invalidate_code_instance(replaced, max_world, 1);
1791+
}
1792+
17881793
static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) {
17891794
jl_array_t *backedges = replaced_mi->backedges;
17901795
if (backedges) {

src/module.c

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,21 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var
10251025
jl_gc_wb(bpart, val);
10261026
}
10271027

1028+
void jl_invalidate_binding_refs(jl_globalref_t *ref, size_t new_world)
1029+
{
1030+
static jl_value_t *invalidate_code_for_globalref = NULL;
1031+
if (invalidate_code_for_globalref == NULL && jl_base_module != NULL)
1032+
invalidate_code_for_globalref = jl_get_global(jl_base_module, jl_symbol("invalidate_code_for_globalref!"));
1033+
if (!invalidate_code_for_globalref)
1034+
jl_error("Binding invalidation is not permitted during bootstrap.");
1035+
if (jl_generating_output())
1036+
jl_error("Binding invalidation is not permitted during image generation.");
1037+
jl_value_t *boxed_world = jl_box_ulong(new_world);
1038+
JL_GC_PUSH1(&boxed_world);
1039+
jl_call2((jl_function_t*)invalidate_code_for_globalref, (jl_value_t*)ref, boxed_world);
1040+
JL_GC_POP();
1041+
}
1042+
10281043
extern jl_mutex_t world_counter_lock;
10291044
JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
10301045
{
@@ -1039,9 +1054,16 @@ JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
10391054

10401055
JL_LOCK(&world_counter_lock);
10411056
jl_task_t *ct = jl_current_task;
1057+
size_t last_world = ct->world_age;
10421058
size_t new_max_world = jl_atomic_load_acquire(&jl_world_counter);
1043-
// TODO: Trigger invalidation here
1044-
(void)ct;
1059+
JL_TRY {
1060+
ct->world_age = jl_typeinf_world;
1061+
jl_invalidate_binding_refs(gr, new_max_world);
1062+
} JL_CATCH {
1063+
JL_UNLOCK(&world_counter_lock);
1064+
jl_rethrow();
1065+
}
1066+
ct->world_age = last_world;
10451067
jl_atomic_store_release(&bpart->max_world, new_max_world);
10461068
jl_atomic_store_release(&jl_world_counter, new_max_world + 1);
10471069
JL_UNLOCK(&world_counter_lock);
@@ -1327,6 +1349,11 @@ JL_DLLEXPORT void jl_add_to_module_init_list(jl_value_t *mod)
13271349
jl_array_ptr_1d_push(jl_module_init_order, mod);
13281350
}
13291351

1352+
JL_DLLEXPORT jl_svec_t *jl_module_get_bindings(jl_module_t *m)
1353+
{
1354+
return jl_atomic_load_relaxed(&m->bindings);
1355+
}
1356+
13301357
JL_DLLEXPORT void jl_init_restored_module(jl_value_t *mod)
13311358
{
13321359
if (!jl_generating_output() || jl_options.incremental) {

test/rebinding.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,11 @@ module Rebinding
3333
@test Base.@world(Foo, defined_world_age) == typeof(x)
3434
@test Base.@world(Rebinding.Foo, defined_world_age) == typeof(x)
3535
@test Base.@world((@__MODULE__).Foo, defined_world_age) == typeof(x)
36+
37+
# Test invalidation (const -> undefined)
38+
const delete_me = 1
39+
f_return_delete_me() = delete_me
40+
@test f_return_delete_me() == 1
41+
Base.delete_binding(@__MODULE__, :delete_me)
42+
@test_throws UndefVarError f_return_delete_me()
3643
end

0 commit comments

Comments
 (0)