Be more careful about iterator invalidation during recursive invalida… #57934

Merged · 1 commit · Apr 1, 2025
5 changes: 5 additions & 0 deletions Compiler/test/invalidation.jl
@@ -283,3 +283,8 @@ begin take!(GLOBAL_BUFFER)
@test isnothing(pr48932_caller_inlined(42))
@test "42" == String(take!(GLOBAL_BUFFER))
end

# Issue #57696
# This test checks for invalidation of recursive backedges. Unfortunately, the original failure
# manifested as an unreliable segfault or assertion failure, so we don't have a more compact test.
@test success(`$(Base.julia_cmd()) -e 'Base.typejoin(x, ::Type) = 0; exit()'`)
237 changes: 134 additions & 103 deletions src/gf.c
@@ -1831,43 +1831,133 @@ JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size
}

static void _invalidate_backedges(jl_method_instance_t *replaced_mi, jl_code_instance_t *replaced_ci, size_t max_world, int depth) {
jl_array_t *backedges = replaced_mi->backedges;
if (backedges) {
// invalidate callers (if any)
uint8_t recursion_flags = 0;
jl_array_t *backedges = jl_mi_get_backedges_mutate(replaced_mi, &recursion_flags);
if (!backedges)
return;
// invalidate callers (if any)
if (!replaced_ci) {
// We know all backedges are deleted - clear them eagerly
// Clears both array and flags
replaced_mi->backedges = NULL;
JL_GC_PUSH1(&backedges);
size_t i = 0, l = jl_array_nrows(backedges);
size_t ins = 0;
jl_code_instance_t *replaced;
while (i < l) {
jl_value_t *invokesig = NULL;
i = get_next_edge(backedges, i, &invokesig, &replaced);
JL_GC_PROMISE_ROOTED(replaced); // propagated by get_next_edge from backedges
if (replaced_ci) {
// If we're invalidating a particular codeinstance, only invalidate
// this backedge if it actually has an edge for our codeinstance.
jl_svec_t *edges = jl_atomic_load_relaxed(&replaced->edges);
for (size_t j = 0; j < jl_svec_len(edges); ++j) {
jl_value_t *edge = jl_svecref(edges, j);
if (edge == (jl_value_t*)replaced_mi || edge == (jl_value_t*)replaced_ci)
goto found;
}
// Keep this entry in the backedge list, but compact it
ins = set_next_edge(backedges, ins, invokesig, replaced);
continue;
found:;
jl_atomic_fetch_and_relaxed(&replaced_mi->flags, ~MI_FLAG_BACKEDGES_ALL);
}
JL_GC_PUSH1(&backedges);
size_t i = 0, l = jl_array_nrows(backedges);
size_t ins = 0;
jl_code_instance_t *replaced;
while (i < l) {
jl_value_t *invokesig = NULL;
i = get_next_edge(backedges, i, &invokesig, &replaced);
if (!replaced) {
ins = i;
continue;
}
JL_GC_PROMISE_ROOTED(replaced); // propagated by get_next_edge from backedges
if (replaced_ci) {
// If we're invalidating a particular codeinstance, only invalidate
// this backedge if it actually has an edge for our codeinstance.
jl_svec_t *edges = jl_atomic_load_relaxed(&replaced->edges);
for (size_t j = 0; j < jl_svec_len(edges); ++j) {
jl_value_t *edge = jl_svecref(edges, j);
if (edge == (jl_value_t*)replaced_mi || edge == (jl_value_t*)replaced_ci)
goto found;
}
invalidate_code_instance(replaced, max_world, depth);
ins = set_next_edge(backedges, ins, invokesig, replaced);
continue;
found:;
ins = clear_next_edge(backedges, ins, invokesig, replaced);
jl_atomic_fetch_or(&replaced_mi->flags, MI_FLAG_BACKEDGES_DIRTY);
/* fallthrough */
}
invalidate_code_instance(replaced, max_world, depth);
if (replaced_ci && !replaced_mi->backedges) {
// Fast-path early out. If `invalidate_code_instance` invalidated
// the entire mi via a recursive edge, there's no point to keep
// iterating - they'll already have been invalidated.
break;
}
if (replaced_ci && ins != 0) {
jl_array_del_end(backedges, l - ins);
// If we're only invalidating one ci, we don't know which ci any particular
// backedge was for, so we can't delete them. Put them back.
replaced_mi->backedges = backedges;
jl_gc_wb(replaced_mi, backedges);
}
if (replaced_ci)
jl_mi_done_backedges(replaced_mi, recursion_flags);
JL_GC_POP();
}
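
The new _invalidate_backedges above replaces in-place compaction during iteration with a tombstone-and-compact protocol: while any frame on the stack is walking a backedge list, deletions only null out the caller slot (so indices held by outer frames stay valid under recursive invalidation), and the array is compacted once the outermost walker releases it. Below is a minimal, self-contained C sketch of that pattern; the names (edge_list_t, begin_walk, FLAG_INUSE, FLAG_DIRTY) are illustrative, not Julia's actual API.

#include <stdatomic.h>
#include <stddef.h>
#include <stdio.h>

#define FLAG_INUSE 0x4  /* a walker higher up the stack is iterating */
#define FLAG_DIRTY 0x8  /* entries were tombstoned; compact when released */

typedef struct {
    void *items[16];
    size_t len;
    _Atomic(unsigned char) flags;
} edge_list_t;

/* Begin a walk: snapshot the previous flag state and mark the list in use,
 * so deletions from recursive frames tombstone instead of moving entries. */
static unsigned char begin_walk(edge_list_t *l) {
    unsigned char prev = atomic_load(&l->flags) & (FLAG_INUSE | FLAG_DIRTY);
    atomic_fetch_or(&l->flags, FLAG_INUSE);
    return prev;
}

/* Delete during a walk: never shifts elements, only nulls the slot. */
static void tombstone(edge_list_t *l, size_t i) {
    l->items[i] = NULL;
    atomic_fetch_or(&l->flags, FLAG_DIRTY);
}

/* End a walk: only the outermost walker compacts and clears the flags. */
static void end_walk(edge_list_t *l, unsigned char prev) {
    if (prev & FLAG_INUSE)
        return; /* an outer frame is still iterating; leave tombstones */
    if (atomic_load(&l->flags) & FLAG_DIRTY) {
        size_t ins = 0;
        for (size_t i = 0; i < l->len; i++)
            if (l->items[i] != NULL)
                l->items[ins++] = l->items[i];
        l->len = ins;
    }
    atomic_fetch_and(&l->flags, (unsigned char)~(FLAG_INUSE | FLAG_DIRTY));
}

int main(void) {
    edge_list_t l = { .items = { (void*)1, (void*)2, (void*)3 }, .len = 3 };
    unsigned char outer = begin_walk(&l);
    unsigned char inner = begin_walk(&l); /* e.g. a recursive invalidation */
    tombstone(&l, 1);    /* delete while both walks are live */
    end_walk(&l, inner); /* inner walk ends: indices stay valid, no compaction */
    end_walk(&l, outer); /* outer walk ends: list compacts to two entries */
    printf("%zu entries remain\n", l.len);
    return 0;
}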

enum morespec_options {
morespec_unknown,
morespec_isnot,
morespec_is
};

// check if `type` is replacing `m` with an ambiguity here, given other methods in `d` that already match it
static int is_replacing(char ambig, jl_value_t *type, jl_method_t *m, jl_method_t *const *d, size_t n, jl_value_t *isect, jl_value_t *isect2, char *morespec)
{
size_t k;
for (k = 0; k < n; k++) {
jl_method_t *m2 = d[k];
// see if m2 also fully covered this intersection
if (m == m2 || !(jl_subtype(isect, m2->sig) || (isect2 && jl_subtype(isect2, m2->sig))))
continue;
if (morespec[k] == (char)morespec_unknown)
morespec[k] = (char)(jl_type_morespecific(m2->sig, type) ? morespec_is : morespec_isnot);
if (morespec[k] == (char)morespec_is)
// not actually shadowing this--m2 will still be better
return 0;
// if type is not more specific than m (thus now dominating it)
// then there is a new ambiguity here,
// since m2 was also a previous match over isect,
// see if m was previously dominant over all m2
// or if this was already ambiguous before
if (ambig != morespec_is && !jl_type_morespecific(m->sig, m2->sig)) {
// m and m2 were previously ambiguous over the full intersection of mi with type, and will still be ambiguous with addition of type
return 0;
}
JL_GC_POP();
}
return 1;
}

static int _invalidate_dispatch_backedges(jl_method_instance_t *mi, jl_value_t *type, jl_method_t *m,
jl_method_t *const *d, size_t n, int replaced_dispatch, int ambig,
size_t max_world, char *morespec)
{
uint8_t backedge_recursion_flags = 0;
jl_array_t *backedges = jl_mi_get_backedges_mutate(mi, &backedge_recursion_flags);
if (!backedges)
return 0;
size_t ib = 0, insb = 0, nb = jl_array_nrows(backedges);
jl_value_t *invokeTypes;
jl_code_instance_t *caller;
int invalidated_any = 0;
while (mi->backedges && ib < nb) {
ib = get_next_edge(backedges, ib, &invokeTypes, &caller);
if (!caller) {
insb = ib;
continue;
}
JL_GC_PROMISE_ROOTED(caller); // propagated by get_next_edge from backedges
int replaced_edge;
if (invokeTypes) {
// n.b. normally we must have mi.specTypes <: invokeTypes <: m.sig (though it might not strictly hold), so we only need to check the other subtypes
if (jl_egal(invokeTypes, jl_get_ci_mi(caller)->def.method->sig))
replaced_edge = 0; // if invokeTypes == m.sig, then the only way to change this invoke is to replace the method itself
else
replaced_edge = jl_subtype(invokeTypes, type) && is_replacing(ambig, type, m, d, n, invokeTypes, NULL, morespec);
}
else {
replaced_edge = replaced_dispatch;
}
if (replaced_edge) {
invalidate_code_instance(caller, max_world, 1);
insb = clear_next_edge(backedges, insb, invokeTypes, caller);
jl_atomic_fetch_or(&mi->flags, MI_FLAG_BACKEDGES_DIRTY);
invalidated_any = 1;
}
else {
insb = set_next_edge(backedges, insb, invokeTypes, caller);
}
}
jl_mi_done_backedges(mi, backedge_recursion_flags);
return invalidated_any;
}
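
The loops above never index the backedge array directly because edges are stored flattened: each edge is either a single caller slot or an (invokesig, caller) pair. The sketch below shows one way such an encoding and its get/set/clear helpers can fit together, under the simplifying assumption of a tagged-slot array; the real helpers operate on a jl_array_t of jl_value_t* and discriminate slots by their Julia type.

#include <stddef.h>

/* Illustrative tagged slot: Julia's real lists hold jl_value_t* and decide
 * "invokesig or caller?" by checking the Julia type of each slot. */
typedef struct { int is_sig; void *val; } slot_t;

/* Read one edge starting at index i and return the index just past it.
 * An edge occupies one slot (caller) or two (invokesig, then caller). */
static size_t sketch_get_next_edge(slot_t *list, size_t i,
                                   void **invokesig, void **caller) {
    *invokesig = NULL;
    if (list[i].is_sig)
        *invokesig = list[i++].val;
    *caller = list[i].val; /* NULL if this edge was tombstoned earlier */
    return i + 1;
}

/* Write an edge back at the insertion cursor (in-place compaction). */
static size_t sketch_set_next_edge(slot_t *list, size_t ins,
                                   void *invokesig, void *caller) {
    if (invokesig) {
        list[ins].is_sig = 1;
        list[ins++].val = invokesig;
    }
    list[ins].is_sig = 0;
    list[ins].val = caller;
    return ins + 1;
}

/* Tombstone an edge: keep its width so outer iterators' indices stay valid,
 * but null the caller so later walkers skip it. */
static size_t sketch_clear_next_edge(slot_t *list, size_t ins,
                                     void *invokesig, void *caller) {
    (void)caller;
    if (invokesig) {
        list[ins].is_sig = 1;
        list[ins++].val = invokesig;
    }
    list[ins].is_sig = 0;
    list[ins].val = NULL;
    return ins + 1;
}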

// invalidate cached methods that overlap this definition
@@ -1898,20 +1988,22 @@ JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee,
JL_LOCK(&callee->def.method->writelock);
if (jl_atomic_load_relaxed(&allow_new_worlds)) {
int found = 0;
jl_array_t *backedges = jl_mi_get_backedges(callee);
// TODO: use jl_cache_type_(invokesig) like cache_method does to save memory
if (!callee->backedges) {
if (!backedges) {
// lazy-init the backedges array
callee->backedges = jl_alloc_vec_any(0);
jl_gc_wb(callee, callee->backedges);
backedges = jl_alloc_vec_any(0);
callee->backedges = backedges;
jl_gc_wb(callee, backedges);
}
else {
size_t i = 0, l = jl_array_nrows(callee->backedges);
size_t i = 0, l = jl_array_nrows(backedges);
for (i = 0; i < l; i++) {
// optimized version of while (i < l) i = get_next_edge(callee->backedges, i, &invokeTypes, &mi);
jl_value_t *ciedge = jl_array_ptr_ref(callee->backedges, i);
jl_value_t *ciedge = jl_array_ptr_ref(backedges, i);
if (ciedge != (jl_value_t*)caller)
continue;
jl_value_t *invokeTypes = i > 0 ? jl_array_ptr_ref(callee->backedges, i - 1) : NULL;
jl_value_t *invokeTypes = i > 0 ? jl_array_ptr_ref(backedges, i - 1) : NULL;
if (invokeTypes && jl_is_method_instance(invokeTypes))
invokeTypes = NULL;
if ((invokesig == NULL && invokeTypes == NULL) ||
@@ -1922,7 +2014,7 @@ JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee,
}
}
if (!found)
push_edge(callee->backedges, invokesig, caller);
push_edge(backedges, invokesig, caller);
}
JL_UNLOCK(&callee->def.method->writelock);
}
@@ -2111,13 +2203,13 @@ static int erase_method_backedges(jl_typemap_entry_t *def, void *closure)
for (i = 0; i < l; i++) {
jl_method_instance_t *mi = (jl_method_instance_t*)jl_svecref(specializations, i);
if ((jl_value_t*)mi != jl_nothing) {
mi->backedges = NULL;
mi->backedges = 0;
}
}
}
else {
jl_method_instance_t *mi = (jl_method_instance_t*)specializations;
mi->backedges = NULL;
mi->backedges = 0;
}
JL_UNLOCK(&method->writelock);
return 1;
@@ -2189,39 +2281,6 @@ static int jl_type_intersection2(jl_value_t *t1, jl_value_t *t2, jl_value_t **is
return 1;
}

enum morespec_options {
morespec_unknown,
morespec_isnot,
morespec_is
};

// check if `type` is replacing `m` with an ambiguity here, given other methods in `d` that already match it
static int is_replacing(char ambig, jl_value_t *type, jl_method_t *m, jl_method_t *const *d, size_t n, jl_value_t *isect, jl_value_t *isect2, char *morespec)
{
size_t k;
for (k = 0; k < n; k++) {
jl_method_t *m2 = d[k];
// see if m2 also fully covered this intersection
if (m == m2 || !(jl_subtype(isect, m2->sig) || (isect2 && jl_subtype(isect2, m2->sig))))
continue;
if (morespec[k] == (char)morespec_unknown)
morespec[k] = (char)(jl_type_morespecific(m2->sig, type) ? morespec_is : morespec_isnot);
if (morespec[k] == (char)morespec_is)
// not actually shadowing this--m2 will still be better
return 0;
// if type is not more specific than m (thus now dominating it)
// then there is a new ambiguity here,
// since m2 was also a previous match over isect,
// see if m was previously dominant over all m2
// or if this was already ambiguous before
if (ambig != morespec_is && !jl_type_morespecific(m->sig, m2->sig)) {
// m and m2 were previously ambiguous over the full intersection of mi with type, and will still be ambiguous with addition of type
return 0;
}
}
return 1;
}

jl_typemap_entry_t *jl_method_table_add(jl_methtable_t *mt, jl_method_t *method, jl_tupletype_t *simpletype)
{
JL_TIMING(ADD_METHOD, ADD_METHOD);
@@ -2386,35 +2445,7 @@ void jl_method_table_activate(jl_methtable_t *mt, jl_typemap_entry_t *newentry)
// found that this specialization dispatch got replaced by m
// call invalidate_backedges(mi, max_world, "jl_method_table_insert");
// but ignore invoke-type edges
jl_array_t *backedges = mi->backedges;
if (backedges) {
size_t ib = 0, insb = 0, nb = jl_array_nrows(backedges);
jl_value_t *invokeTypes;
jl_code_instance_t *caller;
while (ib < nb) {
ib = get_next_edge(backedges, ib, &invokeTypes, &caller);
JL_GC_PROMISE_ROOTED(caller); // propagated by get_next_edge from backedges
int replaced_edge;
if (invokeTypes) {
// n.b. normally we must have mi.specTypes <: invokeTypes <: m.sig (though it might not strictly hold), so we only need to check the other subtypes
if (jl_egal(invokeTypes, jl_get_ci_mi(caller)->def.method->sig))
replaced_edge = 0; // if invokeTypes == m.sig, then the only way to change this invoke is to replace the method itself
else
replaced_edge = jl_subtype(invokeTypes, type) && is_replacing(ambig, type, m, d, n, invokeTypes, NULL, morespec);
}
else {
replaced_edge = replaced_dispatch;
}
if (replaced_edge) {
invalidate_code_instance(caller, max_world, 1);
invalidated = 1;
}
else {
insb = set_next_edge(backedges, insb, invokeTypes, caller);
}
}
jl_array_del_end(backedges, nb - insb);
}
invalidated = _invalidate_dispatch_backedges(mi, type, m, d, n, replaced_dispatch, ambig, max_world, morespec);
jl_array_ptr_1d_push(oldmi, (jl_value_t*)mi);
if (_jl_debug_method_invalidation && invalidated) {
jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)mi);
5 changes: 4 additions & 1 deletion src/julia.h
@@ -404,13 +404,16 @@ struct _jl_method_instance_t {
} def; // pointer back to the context for this code
jl_value_t *specTypes; // argument types this was specialized for
jl_svec_t *sparam_vals; // static parameter values, indexed by def.method->sig
jl_array_t *backedges; // list of code-instances which call this method-instance; `invoke` records (invokesig, caller) pairs
// list of code-instances which call this method-instance; `invoke` records (invokesig, caller) pairs
jl_array_t *backedges;
_Atomic(struct _jl_code_instance_t*) cache;
uint8_t cache_with_orig; // !cache_with_specTypes

// flags for this method instance
// bit 0: generated by an explicit `precompile(...)`
// bit 1: dispatched
// bit 2: The ->backedges field is currently being walked higher up the stack - entries may be deleted, but not moved
// bit 3: The ->backedges field was modified and should be compacted when clearing bit 2
_Atomic(uint8_t) flags;
};
#define JL_MI_FLAGS_MASK_PRECOMPILED 0x01
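
For reference, the four flag bits documented above pack into the single atomic byte as follows. Only JL_MI_FLAGS_MASK_PRECOMPILED is visible in this hunk; the backedge masks are the ones introduced in julia_internal.h below, and the bit-1 name here is a placeholder, not the actual identifier.

/* Bit layout of jl_method_instance_t.flags after this change. */
#define JL_MI_FLAGS_MASK_PRECOMPILED 0x01 /* bit 0: from explicit precompile(...) */
#define MI_FLAG_DISPATCHED_SKETCH    0x02 /* bit 1: dispatched (placeholder name) */
#define MI_FLAG_BACKEDGES_INUSE      0x04 /* bit 2: list being walked; tombstone only */
#define MI_FLAG_BACKEDGES_DIRTY      0x08 /* bit 3: tombstones present; compact on release */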
20 changes: 20 additions & 0 deletions src/julia_internal.h
@@ -726,9 +726,29 @@ JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_modu
JL_DLLEXPORT int jl_maybe_add_binding_backedge(jl_binding_t *b, jl_value_t *edge, jl_method_t *in_method);
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge);

static const uint8_t MI_FLAG_BACKEDGES_INUSE = 0b0100;
static const uint8_t MI_FLAG_BACKEDGES_DIRTY = 0b1000;
static const uint8_t MI_FLAG_BACKEDGES_ALL = 0b1100;

STATIC_INLINE jl_array_t *jl_mi_get_backedges_mutate(jl_method_instance_t *mi JL_PROPAGATES_ROOT, uint8_t *flags) {
*flags = jl_atomic_load_relaxed(&mi->flags) & (MI_FLAG_BACKEDGES_ALL);
jl_array_t *ret = mi->backedges;
if (ret)
jl_atomic_fetch_or_relaxed(&mi->flags, MI_FLAG_BACKEDGES_INUSE);
return ret;
}

STATIC_INLINE jl_array_t *jl_mi_get_backedges(jl_method_instance_t *mi JL_PROPAGATES_ROOT) {
assert(!(jl_atomic_load_relaxed(&mi->flags) & MI_FLAG_BACKEDGES_ALL));
jl_array_t *ret = mi->backedges;
return ret;
}

int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT;
int set_next_edge(jl_array_t *list, int i, jl_value_t *invokesig, jl_code_instance_t *caller);
int clear_next_edge(jl_array_t *list, int i, jl_value_t *invokesig, jl_code_instance_t *caller);
void push_edge(jl_array_t *list, jl_value_t *invokesig, jl_code_instance_t *caller);
void jl_mi_done_backedges(jl_method_instance_t *mi JL_PROPAGATES_ROOT, uint8_t old_flags);
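
jl_mi_done_backedges is only declared in this hunk; its definition is elsewhere in the PR. Assuming it mirrors the flag protocol established by jl_mi_get_backedges_mutate — restore only when the outermost walker finishes, compact only if something was tombstoned — a plausible shape is sketched below (not the actual definition; GC write barriers and rooting elided for brevity).

// Sketch only: the real definition lives in src/gf.c and may differ in detail.
void jl_mi_done_backedges_sketch(jl_method_instance_t *mi, uint8_t old_flags)
{
    if (old_flags & MI_FLAG_BACKEDGES_INUSE)
        return; // an outer frame is still walking; it will compact and clear
    uint8_t flags = jl_atomic_fetch_and_relaxed(&mi->flags, ~MI_FLAG_BACKEDGES_ALL);
    if ((flags & MI_FLAG_BACKEDGES_DIRTY) && mi->backedges) {
        // Drop tombstoned edges (NULL caller), preserving the pair encoding.
        int i = 0, ins = 0, l = (int)jl_array_nrows(mi->backedges);
        while (i < l) {
            jl_value_t *invokesig = NULL;
            jl_code_instance_t *caller = NULL;
            i = get_next_edge(mi->backedges, i, &invokesig, &caller);
            if (caller)
                ins = set_next_edge(mi->backedges, ins, invokesig, caller);
        }
        jl_array_del_end(mi->backedges, l - ins);
    }
}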

JL_DLLEXPORT void jl_add_method_root(jl_method_t *m, jl_module_t *mod, jl_value_t* root);
void jl_append_method_roots(jl_method_t *m, uint64_t modid, jl_array_t* roots);