Skip to content

bpart: Track whether any binding replacement has happened in image modules #57433

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions base/client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ function exec_options(opts)
distributed_mode = (opts.worker == 1) || (opts.nprocs > 0) || (opts.machine_file != C_NULL)
if distributed_mode
let Distributed = require(PkgId(UUID((0x8ba89e20_285c_5b6f, 0x9357_94700520ee1b)), "Distributed"))
Core.eval(MainInclude, :(const Distributed = $Distributed))
MainInclude.Distributed = Distributed
Core.eval(Main, :(using Base.MainInclude.Distributed))
invokelatest(Distributed.process_opts, opts)
end
Expand Down Expand Up @@ -400,7 +400,7 @@ function load_InteractiveUtils(mod::Module=Main)
try
# TODO: we have to use require_stdlib here because it is a dependency of REPL, but we would sort of prefer not to
let InteractiveUtils = require_stdlib(PkgId(UUID(0xb77e0a4c_d291_57a0_90e8_8db25a27a240), "InteractiveUtils"))
Core.eval(MainInclude, :(const InteractiveUtils = $InteractiveUtils))
MainInclude.InteractiveUtils = InteractiveUtils
end
catch ex
@warn "Failed to import InteractiveUtils into module $mod" exception=(ex, catch_backtrace())
Expand Down Expand Up @@ -535,6 +535,10 @@ The thrown errors are collected in a stack of exceptions.
"""
global err = nothing

# Used for memoizing require_stdlib of these modules
global InteractiveUtils::Module
global Distributed::Module

# weakly exposes ans and err variables to Main
export ans, err
end
Expand Down
18 changes: 13 additions & 5 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,26 +180,34 @@
b.partitions.min_world > unsafe_load(cglobal(:jl_require_world, UInt))
end

function scan_new_method!(methods_with_invalidated_source::IdSet{Method}, method::Method)
function scan_new_method!(methods_with_invalidated_source::IdSet{Method}, method::Method, image_backedges_only::Bool)
isdefined(method, :source) || return
if image_backedges_only && !has_image_globalref(method)
return
end
src = _uncompressed_ir(method)
mod = method.module
foreachgr(src) do gr::GlobalRef
b = convert(Core.Binding, gr)
binding_was_invalidated(b) && push!(methods_with_invalidated_source, method)
if binding_was_invalidated(b)
# TODO: We could turn this into an additional `if` condition. For now, use it as a reasonably cheap
# additional consistency check

Check warning on line 194 in base/invalidation.jl

View workflow job for this annotation

GitHub Actions / Check for new typos

perhaps "chekc" should be "check".
@assert !image_backedges_only
push!(methods_with_invalidated_source, method)
end
maybe_add_binding_backedge!(b, method)
end
end

function scan_new_methods(extext_methods::Vector{Any}, internal_methods::Vector{Any})
function scan_new_methods(extext_methods::Vector{Any}, internal_methods::Vector{Any}, image_backedges_only::Bool)
methods_with_invalidated_source = IdSet{Method}()
for method in internal_methods
if isa(method, Method)
scan_new_method!(methods_with_invalidated_source, method)
scan_new_method!(methods_with_invalidated_source, method, image_backedges_only)
end
end
for tme::Core.TypeMapEntry in extext_methods
scan_new_method!(methods_with_invalidated_source, tme.func::Method)
scan_new_method!(methods_with_invalidated_source, tme.func::Method, image_backedges_only)
end
return methods_with_invalidated_source
end
2 changes: 2 additions & 0 deletions base/runtime_internals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1656,3 +1656,5 @@ isempty(mt::Core.MethodTable) = (mt.defs === nothing)
uncompressed_ir(m::Method) = isdefined(m, :source) ? _uncompressed_ir(m) :
isdefined(m, :generator) ? error("Method is @generated; try `code_lowered` instead.") :
error("Code for this Method is not available.")

# Query the `has_image_globalref` flag recorded for the source of method `m` by asking the
# runtime (`jl_ir_flag_has_image_globalref`). This reads `m.source` directly, so callers are
# expected to have checked `isdefined(m, :source)` first (as `scan_new_method!` does).
has_image_globalref(m::Method) = ccall(:jl_ir_flag_has_image_globalref, Bool, (Any,), m.source)
3 changes: 2 additions & 1 deletion base/staticdata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ end
function insert_backedges(edges::Vector{Any}, ext_ci_list::Union{Nothing,Vector{Any}}, extext_methods::Vector{Any}, internal_methods::Vector{Any})
# determine which CodeInstance objects are still valid in our image
# to enable any applicable new codes
methods_with_invalidated_source = Base.scan_new_methods(extext_methods, internal_methods)
backedges_only = unsafe_load(cglobal(:jl_first_image_replacement_world, UInt)) == typemax(UInt)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unsafe_load requires an atomic ordering specifier to access mutable data (this can be strong UB here without it, as we define unsafe_load to be a memcpy and not an unordered load)

methods_with_invalidated_source = Base.scan_new_methods(extext_methods, internal_methods, backedges_only)
stack = CodeInstance[]
visiting = IdDict{CodeInstance,Int}()
_insert_backedges(edges, stack, visiting, methods_with_invalidated_source)
Expand Down
16 changes: 14 additions & 2 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -547,14 +547,15 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal)
}
}

static jl_code_info_flags_t code_info_flags(uint8_t propagate_inbounds, uint8_t has_fcall,
static jl_code_info_flags_t code_info_flags(uint8_t propagate_inbounds, uint8_t has_fcall, uint8_t has_image_globalref,
uint8_t nospecializeinfer, uint8_t isva,
uint8_t inlining, uint8_t constprop, uint8_t nargsmatchesmethod,
jl_array_t *ssaflags)
{
jl_code_info_flags_t flags;
flags.bits.propagate_inbounds = propagate_inbounds;
flags.bits.has_fcall = has_fcall;
flags.bits.has_image_globalref = has_image_globalref;
flags.bits.nospecializeinfer = nospecializeinfer;
flags.bits.isva = isva;
flags.bits.inlining = inlining;
Expand Down Expand Up @@ -1036,7 +1037,7 @@ JL_DLLEXPORT jl_string_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
};

uint8_t nargsmatchesmethod = code->nargs == m->nargs;
jl_code_info_flags_t flags = code_info_flags(code->propagate_inbounds, code->has_fcall,
jl_code_info_flags_t flags = code_info_flags(code->propagate_inbounds, code->has_fcall, code->has_image_globalref,
code->nospecializeinfer, code->isva,
code->inlining, code->constprop,
nargsmatchesmethod,
Expand Down Expand Up @@ -1134,6 +1135,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
code->constprop = flags.bits.constprop;
code->propagate_inbounds = flags.bits.propagate_inbounds;
code->has_fcall = flags.bits.has_fcall;
code->has_image_globalref = flags.bits.has_image_globalref;
code->nospecializeinfer = flags.bits.nospecializeinfer;
code->isva = flags.bits.isva;
code->purity.bits = read_uint16(s.s);
Expand Down Expand Up @@ -1228,6 +1230,16 @@ JL_DLLEXPORT uint8_t jl_ir_flag_has_fcall(jl_string_t *data)
return flags.bits.has_fcall;
}

// Return the `has_image_globalref` flag for the method source `data`.
// `data` may be either an uncompressed CodeInfo object, in which case the flag is read
// straight from the struct field, or a compressed IR string, in which case the flag is
// decoded from the packed flags stored at `ir_offset_flags`.
JL_DLLEXPORT uint8_t jl_ir_flag_has_image_globalref(jl_string_t *data)
{
// Uncompressed source: no decoding needed.
if (jl_is_code_info(data))
return ((jl_code_info_t*)data)->has_image_globalref;
// Anything else must be the compressed (string) representation.
assert(jl_is_string(data));
jl_code_info_flags_t flags;
flags.packed = jl_string_data(data)[ir_offset_flags];
return flags.bits.has_image_globalref;
}

JL_DLLEXPORT uint16_t jl_ir_inlining_cost(jl_string_t *data)
{
if (jl_is_code_info(data))
Expand Down
6 changes: 4 additions & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3485,7 +3485,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(22,
jl_perm_symsvec(23,
"code",
"debuginfo",
"ssavaluetypes",
Expand All @@ -3502,13 +3502,14 @@ void jl_init_types(void) JL_GC_DISABLED
"nargs",
"propagate_inbounds",
"has_fcall",
"has_image_globalref",
"nospecializeinfer",
"isva",
"inlining",
"constprop",
"purity",
"inlining_cost"),
jl_svec(22,
jl_svec(23,
jl_array_any_type,
jl_debuginfo_type,
jl_any_type,
Expand All @@ -3527,6 +3528,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_uint8_type,
jl_uint8_type,
jl_uint16_type,
Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ typedef struct _jl_code_info_t {
// various boolean properties:
uint8_t propagate_inbounds;
uint8_t has_fcall;
uint8_t has_image_globalref;
uint8_t nospecializeinfer;
uint8_t isva;
// uint8 settings
Expand Down Expand Up @@ -2263,6 +2264,7 @@ JL_DLLEXPORT jl_value_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code);
JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t *metadata, jl_value_t *data);
JL_DLLEXPORT uint8_t jl_ir_flag_inlining(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_flag_has_fcall(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_flag_has_image_globalref(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t jl_ir_inlining_cost(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT ssize_t jl_ir_nslots(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_slotflag(jl_value_t *data, size_t i) JL_NOTSAFEPOINT;
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ STATIC_INLINE jl_value_t *undefref_check(jl_datatype_t *dt, jl_value_t *v) JL_NO
typedef struct {
uint16_t propagate_inbounds:1;
uint16_t has_fcall:1;
uint16_t has_image_globalref:1;
uint16_t nospecializeinfer:1;
uint16_t isva:1;
uint16_t nargsmatchesmethod:1;
Expand Down
6 changes: 6 additions & 0 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,11 @@ jl_code_info_t *jl_new_code_info_from_ir(jl_expr_t *ir)
is_flag_stmt = 1;
else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == jl_return_sym)
jl_array_ptr_set(body, j, jl_new_struct(jl_returnnode_type, jl_exprarg(st, 0)));
else if (jl_is_globalref(st)) {
jl_globalref_t *gr = (jl_globalref_t*)st;
if (jl_object_in_image((jl_value_t*)gr->mod))
li->has_image_globalref = 1;
}
else {
if (jl_is_expr(st) && ((jl_expr_t*)st)->head == jl_assign_sym)
st = jl_exprarg(st, 1);
Expand Down Expand Up @@ -593,6 +598,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
src->max_world = ~(size_t)0;
src->propagate_inbounds = 0;
src->has_fcall = 0;
src->has_image_globalref = 0;
src->nospecializeinfer = 0;
src->constprop = 0;
src->inlining = 0;
Expand Down
9 changes: 9 additions & 0 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -1294,10 +1294,19 @@ JL_DLLEXPORT jl_binding_partition_t *jl_replace_binding_locked(jl_binding_t *b,
new_world);
}

extern JL_DLLEXPORT _Atomic(size_t) jl_first_image_replacement_world;
JL_DLLEXPORT jl_binding_partition_t *jl_replace_binding_locked2(jl_binding_t *b,
jl_binding_partition_t *old_bpart, jl_value_t *restriction_val, size_t kind, size_t new_world)
{
check_safe_newbinding(b->globalref->mod, b->globalref->name);

// Check if this is replacing a binding in the system or a package image.
// Until the first such replacement, we can fast-path validation.
// For these purposes, we consider the `Main` module to be a non-sysimg module.
// This is legal, because we special-case `Main` in check_safe_import_from.
if (jl_object_in_image((jl_value_t*)b) && b->globalref->mod != jl_main_module && jl_atomic_load_relaxed(&jl_first_image_replacement_world) == ~(size_t)0)
jl_atomic_store_relaxed(&jl_first_image_replacement_world, new_world);

assert(jl_atomic_load_relaxed(&b->partitions) == old_bpart);
jl_atomic_store_release(&old_bpart->max_world, new_world-1);
jl_binding_partition_t *new_bpart = new_binding_partition();
Expand Down
17 changes: 10 additions & 7 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ External links:

static const size_t WORLD_AGE_REVALIDATION_SENTINEL = 0x1;
JL_DLLEXPORT size_t jl_require_world = ~(size_t)0;
JL_DLLEXPORT _Atomic(size_t) jl_first_image_replacement_world = ~(size_t)0;

#include "staticdata_utils.c"
#include "precompile_utils.c"
Expand Down Expand Up @@ -3541,7 +3542,7 @@ extern void export_jl_small_typeof(void);
int IMAGE_NATIVE_CODE_TAINTED = 0;

// TODO: This should possibly be in Julia
static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t *bpart, size_t mod_idx, int unchanged_implicit)
static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t *bpart, size_t mod_idx, int unchanged_implicit, int no_replacement)
{
if (jl_atomic_load_relaxed(&bpart->max_world) != ~(size_t)0)
return 1;
Expand All @@ -3556,10 +3557,13 @@ static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t
if (!jl_bkind_is_some_import(kind))
return 1;
jl_binding_t *imported_binding = (jl_binding_t*)bpart->restriction;
if (no_replacement)
goto add_backedge;
jl_binding_partition_t *latest_imported_bpart = jl_atomic_load_relaxed(&imported_binding->partitions);
if (!latest_imported_bpart)
return 1;
if (latest_imported_bpart->min_world <= bpart->min_world) {
add_backedge:
// Imported binding is still valid
if ((kind == BINDING_KIND_EXPLICIT || kind == BINDING_KIND_IMPORTED) &&
external_blob_index((jl_value_t*)imported_binding) != mod_idx) {
Expand All @@ -3583,7 +3587,7 @@ static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t
jl_binding_t *bedge = (jl_binding_t*)edge;
if (!jl_atomic_load_relaxed(&bedge->partitions))
continue;
jl_validate_binding_partition(bedge, jl_atomic_load_relaxed(&bedge->partitions), mod_idx, 0);
jl_validate_binding_partition(bedge, jl_atomic_load_relaxed(&bedge->partitions), mod_idx, 0, 0);
}
}
if (bpart->kind & BINDING_FLAG_EXPORTED) {
Expand All @@ -3600,7 +3604,7 @@ static int jl_validate_binding_partition(jl_binding_t *b, jl_binding_partition_t
if (!jl_atomic_load_relaxed(&importee->partitions))
continue;
JL_UNLOCK(&mod->lock);
jl_validate_binding_partition(importee, jl_atomic_load_relaxed(&importee->partitions), mod_idx, 0);
jl_validate_binding_partition(importee, jl_atomic_load_relaxed(&importee->partitions), mod_idx, 0, 0);
JL_LOCK(&mod->lock);
}
}
Expand Down Expand Up @@ -4070,22 +4074,21 @@ static void jl_restore_system_image_from_stream_(ios_t *f, jl_image_t *image, jl
}
}
if (s.incremental) {
// This needs to be done in a second pass after the binding partitions
// have the proper ABI again.
int no_replacement = jl_atomic_load_relaxed(&jl_first_image_replacement_world) == ~(size_t)0;
for (size_t i = 0; i < s.fixup_objs.len; i++) {
uintptr_t item = (uintptr_t)s.fixup_objs.items[i];
jl_value_t *obj = (jl_value_t*)(image_base + item);
if (jl_is_module(obj)) {
jl_module_t *mod = (jl_module_t*)obj;
size_t mod_idx = external_blob_index((jl_value_t*)mod);
jl_svec_t *table = jl_atomic_load_relaxed(&mod->bindings);
int unchanged_implicit = all_usings_unchanged_implicit(mod);
int unchanged_implicit = no_replacement || all_usings_unchanged_implicit(mod);
for (size_t i = 0; i < jl_svec_len(table); i++) {
jl_binding_t *b = (jl_binding_t*)jl_svecref(table, i);
if ((jl_value_t*)b == jl_nothing)
continue;
jl_binding_partition_t *bpart = jl_atomic_load_relaxed(&b->partitions);
if (!jl_validate_binding_partition(b, bpart, mod_idx, unchanged_implicit)) {
if (!jl_validate_binding_partition(b, bpart, mod_idx, unchanged_implicit, no_replacement)) {
unchanged_implicit = all_usings_unchanged_implicit(mod);
}
}
Expand Down
5 changes: 4 additions & 1 deletion stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ const TAGS = Any[
const NTAGS = length(TAGS)
@assert NTAGS == 255

const ser_version = 29 # do not make changes without bumping the version #!
const ser_version = 30 # do not make changes without bumping the version #!

format_version(::AbstractSerializer) = ser_version
format_version(s::Serializer) = s.version
Expand Down Expand Up @@ -1268,6 +1268,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo})
if format_version(s) >= 20
ci.has_fcall = deserialize(s)
end
if format_version(s) >= 30
ci.has_image_globalref = deserialize(s)::Bool
end
if format_version(s) >= 24
ci.nospecializeinfer = deserialize(s)::Bool
end
Expand Down
49 changes: 49 additions & 0 deletions test/rebinding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,52 @@ module Regression
end
@test GeoParams57377.B.C.h() == GeoParams57377.B.C.S()
end

# Test that the validation bypass fast path is not defeated by loading InteractiveUtils
@test parse(UInt, readchomp(`$(Base.julia_cmd()) -e 'using InteractiveUtils; show(unsafe_load(cglobal(:jl_first_image_replacement_world, UInt)))'`)) == typemax(UInt)

# Test that imported module binding backedges are still added in a new module that has the fast path active
let test_code =
"""
using Test
@assert unsafe_load(cglobal(:jl_first_image_replacement_world, UInt)) == typemax(UInt)
include("precompile_utils.jl")

precompile_test_harness("rebinding precompile") do load_path
write(joinpath(load_path, "LotsOfBindingsToDelete2.jl"),
"module LotsOfBindingsToDelete2
const delete_me_6 = 6
end")
Base.compilecache(Base.PkgId("LotsOfBindingsToDelete2"))
write(joinpath(load_path, "UseTheBindings2.jl"),
"module UseTheBindings2
import LotsOfBindingsToDelete2: delete_me_6
f_use_bindings6() = delete_me_6
# Code Instances for each of these
@assert (f_use_bindings6(),) == (6,)
end")
Base.compilecache(Base.PkgId("UseTheBindings2"))
@eval using LotsOfBindingsToDelete2
@eval using UseTheBindings2
invokelatest() do
@test UseTheBindings2.f_use_bindings6() == 6
Base.delete_binding(LotsOfBindingsToDelete2, :delete_me_6)
invokelatest() do
@test_throws UndefVarError UseTheBindings2.f_use_bindings6()
end
end
end

finish_precompile_test!()
"""
@test success(pipeline(`$(Base.julia_cmd()) -e $test_code`; stderr))
end

# Image Globalref smoke test
module ImageGlobalRefFlag
using Test
# `fimage`'s body is a literal GlobalRef into `Base`, so its source is expected to carry
# the `has_image_globalref` flag (presumably because `Base` lives in the system image).
@eval fimage() = $(GlobalRef(Base, :sin))
# `fnoimage` only refers to the unqualified name `x`, so it is expected to have no
# reference into an image module and the flag should remain unset.
fnoimage() = x
@test Base.has_image_globalref(first(methods(fimage)))
@test !Base.has_image_globalref(first(methods(fnoimage)))
end