Skip to content

Commit 10a1d6f

Browse files
Kenoaviatesk
andauthored
Teach compiler about partitioned bindings (#56299)
This commit teaches to compiler to update its world bounds whenever it looks at a binding partition, making the compiler sound in the presence of a partitioned binding. The key adjustment is that the compiler is no longer allowed to directly query the binding table without recording the world bounds, so all the various abstract evaluations that look at bindings need to be adjusted and are no longer pure tfuncs. We used to look at bindings a lot more, but thanks to earlier prep work to remove unnecessary binding-dependent code (#55288, #55289 and #55271), these changes become relatively straightforward. Note that as before, we do not create any binding partitions by default, so this commit is mostly preperatory. --------- Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com>
1 parent c3c3cd1 commit 10a1d6f

File tree

16 files changed

+404
-221
lines changed

16 files changed

+404
-221
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 329 additions & 36 deletions
Large diffs are not rendered by default.

base/compiler/cicache.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ WorldRange(r::UnitRange) = WorldRange(first(r), last(r))
3131
first(wr::WorldRange) = wr.min_world
3232
last(wr::WorldRange) = wr.max_world
3333
in(world::UInt, wr::WorldRange) = wr.min_world <= world <= wr.max_world
34+
min_world(wr::WorldRange) = first(wr)
35+
max_world(wr::WorldRange) = last(wr)
3436

3537
function intersect(a::WorldRange, b::WorldRange)
3638
ret = WorldRange(max(a.min_world, b.min_world), min(a.max_world, b.max_world))

base/compiler/optimize.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,10 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
307307
isa(stmt, GotoNode) && return (true, false, true)
308308
isa(stmt, GotoIfNot) && return (true, false, (𝕃ₒ, argextype(stmt.cond, src), Bool))
309309
if isa(stmt, GlobalRef)
310-
nothrow = consistent = isdefinedconst_globalref(stmt)
311-
return (consistent, nothrow, nothrow)
310+
# Modeled more precisely in abstract_eval_globalref. In general, if a
311+
# GlobalRef was moved to statement position, it is probably not `const`,
312+
# so we can't say much about it anyway.
313+
return (false, false, false)
312314
elseif isa(stmt, Expr)
313315
(; head, args) = stmt
314316
if head === :static_parameter
@@ -444,7 +446,7 @@ function argextype(
444446
elseif isa(x, QuoteNode)
445447
return Const(x.value)
446448
elseif isa(x, GlobalRef)
447-
return abstract_eval_globalref_type(x)
449+
return abstract_eval_globalref_type(x, src)
448450
elseif isa(x, PhiNode) || isa(x, PhiCNode) || isa(x, UpsilonNode)
449451
return Any
450452
elseif isa(x, PiNode)
@@ -1277,7 +1279,7 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
12771279
# types of call arguments only once `slot2reg` converts this `IRCode` to the SSA form
12781280
# and eliminates slots (see below)
12791281
argtypes = sv.slottypes
1280-
return IRCode(stmts, sv.cfg, di, argtypes, meta, sv.sptypes)
1282+
return IRCode(stmts, sv.cfg, di, argtypes, meta, sv.sptypes, WorldRange(ci.min_world, ci.max_world))
12811283
end
12821284

12831285
function process_meta!(meta::Vector{Expr}, @nospecialize stmt)

base/compiler/ssair/inlining.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,11 +1694,6 @@ function early_inline_special_case(ir::IRCode, stmt::Expr, flag::UInt32,
16941694
if has_flag(flag, IR_FLAG_NOTHROW)
16951695
return SomeCase(quoted(val))
16961696
end
1697-
elseif f === Core.get_binding_type
1698-
length(argtypes) == 3 || return nothing
1699-
if get_binding_type_effect_free(argtypes[2], argtypes[3])
1700-
return SomeCase(quoted(val))
1701-
end
17021697
end
17031698
end
17041699
if f === compilerbarrier

base/compiler/ssair/ir.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,22 +430,25 @@ struct IRCode
430430
cfg::CFG
431431
new_nodes::NewNodeStream
432432
meta::Vector{Expr}
433+
valid_worlds::WorldRange
433434

434-
function IRCode(stmts::InstructionStream, cfg::CFG, debuginfo::DebugInfoStream, argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{VarState})
435+
function IRCode(stmts::InstructionStream, cfg::CFG, debuginfo::DebugInfoStream,
436+
argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{VarState},
437+
valid_worlds=WorldRange(typemin(UInt), typemax(UInt)))
435438
return new(stmts, argtypes, sptypes, debuginfo, cfg, NewNodeStream(), meta)
436439
end
437440
function IRCode(ir::IRCode, stmts::InstructionStream, cfg::CFG, new_nodes::NewNodeStream)
438441
di = ir.debuginfo
439442
@assert di.codelocs === stmts.line
440-
return new(stmts, ir.argtypes, ir.sptypes, di, cfg, new_nodes, ir.meta)
443+
return new(stmts, ir.argtypes, ir.sptypes, di, cfg, new_nodes, ir.meta, ir.valid_worlds)
441444
end
442445
global function copy(ir::IRCode)
443446
di = ir.debuginfo
444447
stmts = copy(ir.stmts)
445448
di = copy(di)
446449
di.edges = copy(di.edges)
447450
di.codelocs = stmts.line
448-
return new(stmts, copy(ir.argtypes), copy(ir.sptypes), di, copy(ir.cfg), copy(ir.new_nodes), copy(ir.meta))
451+
return new(stmts, copy(ir.argtypes), copy(ir.sptypes), di, copy(ir.cfg), copy(ir.new_nodes), copy(ir.meta), ir.valid_worlds)
449452
end
450453
end
451454

base/compiler/ssair/legacy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function inflate_ir!(ci::CodeInfo, sptypes::Vector{VarState}, argtypes::Vector{A
4444
di = DebugInfoStream(nothing, ci.debuginfo, nstmts)
4545
stmts = InstructionStream(code, ssavaluetypes, info, di.codelocs, ci.ssaflags)
4646
meta = Expr[]
47-
return IRCode(stmts, cfg, di, argtypes, meta, sptypes)
47+
return IRCode(stmts, cfg, di, argtypes, meta, sptypes, WorldRange(ci.min_world, ci.max_world))
4848
end
4949

5050
"""

base/compiler/ssair/passes.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,9 @@ function lift_leaves(compact::IncrementalCompact, field::Int,
474474
elseif isa(leaf, QuoteNode)
475475
leaf = leaf.value
476476
elseif isa(leaf, GlobalRef)
477-
mod, name = leaf.mod, leaf.name
478-
if isdefined(mod, name) && isconst(mod, name)
479-
leaf = getglobal(mod, name)
477+
typ = argextype(leaf, compact)
478+
if isa(typ, Const)
479+
leaf = typ.val
480480
else
481481
return nothing
482482
end

base/compiler/ssair/slot2ssa.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ function typ_for_val(@nospecialize(x), ci::CodeInfo, ir::IRCode, idx::Int, slott
176176
end
177177
return (ci.ssavaluetypes::Vector{Any})[idx]
178178
end
179-
isa(x, GlobalRef) && return abstract_eval_globalref_type(x)
179+
isa(x, GlobalRef) && return abstract_eval_globalref_type(x, ci)
180180
isa(x, SSAValue) && return (ci.ssavaluetypes::Vector{Any})[x.id]
181181
isa(x, Argument) && return slottypes[x.n]
182182
isa(x, NewSSAValue) && return types(ir)[new_to_regular(x, length(ir.stmts))]

base/compiler/tfuncs.jl

Lines changed: 31 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,7 @@ end
407407
if isa(a1, DataType) && !isabstracttype(a1)
408408
if a1 === Module
409409
hasintersect(widenconst(sym), Symbol) || return Bottom
410-
if isa(sym, Const) && isa(sym.val, Symbol) && isa(arg1, Const) &&
411-
isdefinedconst_globalref(GlobalRef(arg1.val::Module, sym.val::Symbol))
412-
return Const(true)
413-
end
410+
# isa(sym, Const) case intercepted in abstract interpretation
414411
elseif isa(sym, Const)
415412
val = sym.val
416413
if isa(val, Symbol)
@@ -1160,7 +1157,9 @@ end
11601157
if isa(sv, Module)
11611158
setfield && return Bottom
11621159
if isa(nv, Symbol)
1163-
return abstract_eval_global(sv, nv)
1160+
# In ordinary inference, this case is intercepted early and
1161+
# re-routed to `getglobal`.
1162+
return Any
11641163
end
11651164
return Bottom
11661165
end
@@ -1407,8 +1406,9 @@ end
14071406
elseif ff === Core.modifyglobal!
14081407
o = unwrapva(argtypes[2])
14091408
f = unwrapva(argtypes[3])
1410-
RT = modifyglobal!_tfunc(𝕃ᵢ, o, f, Any, Any, Symbol)
1411-
TF = getglobal_tfunc(𝕃ᵢ, o, f, Symbol)
1409+
GT = abstract_eval_get_binding_type(interp, sv, o, f).rt
1410+
RT = isa(GT, Const) ? Pair{GT.val, GT.val} : Pair
1411+
TF = isa(GT, Const) ? GT.val : Any
14121412
elseif ff === Core.memoryrefmodify!
14131413
o = unwrapva(argtypes[2])
14141414
RT = memoryrefmodify!_tfunc(𝕃ᵢ, o, Any, Any, Symbol, Bool)
@@ -2277,20 +2277,6 @@ function _builtin_nothrow(𝕃::AbstractLattice, @nospecialize(f::Builtin), argt
22772277
elseif f === typeassert
22782278
na == 2 || return false
22792279
return typeassert_nothrow(𝕃, argtypes[1], argtypes[2])
2280-
elseif f === getglobal
2281-
if na == 2
2282-
return getglobal_nothrow(argtypes[1], argtypes[2])
2283-
elseif na == 3
2284-
return getglobal_nothrow(argtypes[1], argtypes[2], argtypes[3])
2285-
end
2286-
return false
2287-
elseif f === setglobal!
2288-
if na == 3
2289-
return setglobal!_nothrow(argtypes[1], argtypes[2], argtypes[3])
2290-
elseif na == 4
2291-
return setglobal!_nothrow(argtypes[1], argtypes[2], argtypes[3], argtypes[4])
2292-
end
2293-
return false
22942280
elseif f === Core.get_binding_type
22952281
na == 2 || return false
22962282
return get_binding_type_nothrow(𝕃, argtypes[1], argtypes[2])
@@ -2473,7 +2459,8 @@ function getfield_effects(𝕃::AbstractLattice, argtypes::Vector{Any}, @nospeci
24732459
end
24742460
end
24752461
if hasintersect(widenconst(obj), Module)
2476-
inaccessiblememonly = getglobal_effects(argtypes, rt).inaccessiblememonly
2462+
# Modeled more precisely in abstract_eval_getglobal
2463+
inaccessiblememonly = ALWAYS_FALSE
24772464
elseif is_mutation_free_argtype(obj)
24782465
inaccessiblememonly = ALWAYS_TRUE
24792466
else
@@ -2482,24 +2469,7 @@ function getfield_effects(𝕃::AbstractLattice, argtypes::Vector{Any}, @nospeci
24822469
return Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly, noub)
24832470
end
24842471

2485-
function getglobal_effects(argtypes::Vector{Any}, @nospecialize(rt))
2486-
2 length(argtypes) 3 || return EFFECTS_THROWS
2487-
consistent = inaccessiblememonly = ALWAYS_FALSE
2488-
nothrow = false
2489-
M, s = argtypes[1], argtypes[2]
2490-
if (length(argtypes) == 3 ? getglobal_nothrow(M, s, argtypes[3]) : getglobal_nothrow(M, s))
2491-
nothrow = true
2492-
# typeasserts below are already checked in `getglobal_nothrow`
2493-
Mval, sval = (M::Const).val::Module, (s::Const).val::Symbol
2494-
if isconst(Mval, sval)
2495-
consistent = ALWAYS_TRUE
2496-
if is_mutation_free_argtype(rt)
2497-
inaccessiblememonly = ALWAYS_TRUE
2498-
end
2499-
end
2500-
end
2501-
return Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly)
2502-
end
2472+
25032473

25042474
"""
25052475
builtin_effects(𝕃::AbstractLattice, f::Builtin, argtypes::Vector{Any}, rt) -> Effects
@@ -2525,11 +2495,13 @@ function builtin_effects(𝕃::AbstractLattice, @nospecialize(f::Builtin), argty
25252495
if f === isdefined
25262496
return isdefined_effects(𝕃, argtypes)
25272497
elseif f === getglobal
2528-
return getglobal_effects(argtypes, rt)
2498+
2 length(argtypes) 3 || return EFFECTS_THROWS
2499+
# Modeled more precisely in abstract_eval_getglobal
2500+
return Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE, nothrow=false, inaccessiblememonly=ALWAYS_FALSE)
25292501
elseif f === Core.get_binding_type
25302502
length(argtypes) == 2 || return EFFECTS_THROWS
2531-
effect_free = get_binding_type_effect_free(argtypes[1], argtypes[2]) ? ALWAYS_TRUE : ALWAYS_FALSE
2532-
return Effects(EFFECTS_TOTAL; effect_free)
2503+
# Modeled more precisely in abstract_eval_get_binding_type
2504+
return Effects(EFFECTS_TOTAL; effect_free=ALWAYS_FALSE)
25332505
elseif f === compilerbarrier
25342506
length(argtypes) == 2 || return Effects(EFFECTS_THROWS; consistent=ALWAYS_FALSE)
25352507
setting = argtypes[1]
@@ -3065,118 +3037,28 @@ function typename_static(@nospecialize(t))
30653037
return isType(t) ? _typename(t.parameters[1]) : Core.TypeName
30663038
end
30673039

3068-
function global_order_nothrow(@nospecialize(o), loading::Bool, storing::Bool)
3069-
o isa Const || return false
3040+
function global_order_exct(@nospecialize(o), loading::Bool, storing::Bool)
3041+
if !(o isa Const)
3042+
if o === Symbol
3043+
return ConcurrencyViolationError
3044+
elseif !hasintersect(o, Symbol)
3045+
return TypeError
3046+
else
3047+
return Union{ConcurrencyViolationError, TypeError}
3048+
end
3049+
end
30703050
sym = o.val
30713051
if sym isa Symbol
30723052
order = get_atomic_order(sym, loading, storing)
3073-
return order !== MEMORY_ORDER_INVALID && order !== MEMORY_ORDER_NOTATOMIC
3074-
end
3075-
return false
3076-
end
3077-
@nospecs function getglobal_nothrow(M, s, o)
3078-
global_order_nothrow(o, #=loading=#true, #=storing=#false) || return false
3079-
return getglobal_nothrow(M, s)
3080-
end
3081-
@nospecs function getglobal_nothrow(M, s)
3082-
if M isa Const && s isa Const
3083-
M, s = M.val, s.val
3084-
if M isa Module && s isa Symbol
3085-
return isdefinedconst_globalref(GlobalRef(M, s))
3086-
end
3087-
end
3088-
return false
3089-
end
3090-
@nospecs function getglobal_tfunc(𝕃::AbstractLattice, M, s, order=Symbol)
3091-
if M isa Const && s isa Const
3092-
M, s = M.val, s.val
3093-
if M isa Module && s isa Symbol
3094-
return abstract_eval_global(M, s)
3095-
end
3096-
return Bottom
3097-
elseif !(hasintersect(widenconst(M), Module) && hasintersect(widenconst(s), Symbol))
3098-
return Bottom
3099-
end
3100-
T = get_binding_type_tfunc(𝕃, M, s)
3101-
T isa Const && return T.val
3102-
return Any
3103-
end
3104-
@nospecs function setglobal!_tfunc(𝕃::AbstractLattice, M, s, v, order=Symbol)
3105-
if !(hasintersect(widenconst(M), Module) && hasintersect(widenconst(s), Symbol))
3106-
return Bottom
3107-
end
3108-
return v
3109-
end
3110-
@nospecs function swapglobal!_tfunc(𝕃::AbstractLattice, M, s, v, order=Symbol)
3111-
setglobal!_tfunc(𝕃, M, s, v) === Bottom && return Bottom
3112-
return getglobal_tfunc(𝕃, M, s)
3113-
end
3114-
@nospecs function modifyglobal!_tfunc(𝕃::AbstractLattice, M, s, op, v, order=Symbol)
3115-
T = get_binding_type_tfunc(𝕃, M, s)
3116-
T === Bottom && return Bottom
3117-
T isa Const || return Pair
3118-
T = T.val
3119-
return Pair{T, T}
3120-
end
3121-
@nospecs function replaceglobal!_tfunc(𝕃::AbstractLattice, M, s, x, v, success_order=Symbol, failure_order=Symbol)
3122-
v = setglobal!_tfunc(𝕃, M, s, v)
3123-
v === Bottom && return Bottom
3124-
T = get_binding_type_tfunc(𝕃, M, s)
3125-
T === Bottom && return Bottom
3126-
T isa Const || return ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T
3127-
T = T.val
3128-
return ccall(:jl_apply_cmpswap_type, Any, (Any,), T)
3129-
end
3130-
@nospecs function setglobalonce!_tfunc(𝕃::AbstractLattice, M, s, v, success_order=Symbol, failure_order=Symbol)
3131-
setglobal!_tfunc(𝕃, M, s, v) === Bottom && return Bottom
3132-
return Bool
3133-
end
3134-
3135-
add_tfunc(Core.getglobal, 2, 3, getglobal_tfunc, 1)
3136-
add_tfunc(Core.setglobal!, 3, 4, setglobal!_tfunc, 3)
3137-
add_tfunc(Core.swapglobal!, 3, 4, swapglobal!_tfunc, 3)
3138-
add_tfunc(Core.modifyglobal!, 4, 5, modifyglobal!_tfunc, 3)
3139-
add_tfunc(Core.replaceglobal!, 4, 6, replaceglobal!_tfunc, 3)
3140-
add_tfunc(Core.setglobalonce!, 3, 5, setglobalonce!_tfunc, 3)
3141-
3142-
@nospecs function setglobal!_nothrow(M, s, newty, o)
3143-
global_order_nothrow(o, #=loading=#false, #=storing=#true) || return false
3144-
return setglobal!_nothrow(M, s, newty)
3145-
end
3146-
@nospecs function setglobal!_nothrow(M, s, newty)
3147-
if M isa Const && s isa Const
3148-
M, s = M.val, s.val
3149-
if isa(M, Module) && isa(s, Symbol)
3150-
return global_assignment_nothrow(M, s, newty)
3151-
end
3152-
end
3153-
return false
3154-
end
3155-
3156-
function global_assignment_nothrow(M::Module, s::Symbol, @nospecialize(newty))
3157-
if !isconst(M, s)
3158-
ty = ccall(:jl_get_binding_type, Any, (Any, Any), M, s)
3159-
return ty isa Type && widenconst(newty) <: ty
3160-
end
3161-
return false
3162-
end
3163-
3164-
@nospecs function get_binding_type_effect_free(M, s)
3165-
if M isa Const && s isa Const
3166-
M, s = M.val, s.val
3167-
if M isa Module && s isa Symbol
3168-
return ccall(:jl_get_binding_type, Any, (Any, Any), M, s) !== nothing
3053+
if order !== MEMORY_ORDER_INVALID && order !== MEMORY_ORDER_NOTATOMIC
3054+
return Union{}
3055+
else
3056+
return ConcurrencyViolationError
31693057
end
3058+
else
3059+
return TypeError
31703060
end
3171-
return false
3172-
end
3173-
@nospecs function get_binding_type_tfunc(𝕃::AbstractLattice, M, s)
3174-
if get_binding_type_effect_free(M, s)
3175-
return Const(Core.get_binding_type((M::Const).val::Module, (s::Const).val::Symbol))
3176-
end
3177-
return Type
31783061
end
3179-
add_tfunc(Core.get_binding_type, 2, 2, get_binding_type_tfunc, 0)
31803062

31813063
@nospecs function get_binding_type_nothrow(𝕃::AbstractLattice, M, s)
31823064
= partialorder(𝕃)

base/runtime_internals.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,20 @@ const BINDING_KIND_DECLARED = 0x7
230230
const BINDING_KIND_GUARD = 0x8
231231

232232
is_some_const_binding(kind::UInt8) = (kind == BINDING_KIND_CONST || kind == BINDING_KIND_CONST_IMPORT)
233+
is_some_imported(kind::UInt8) = (kind == BINDING_KIND_IMPLICIT || kind == BINDING_KIND_EXPLICIT || kind == BINDING_KIND_IMPORTED)
234+
is_some_guard(kind::UInt8) = (kind == BINDING_KIND_GUARD || kind == BINDING_KIND_DECLARED || kind == BINDING_KIND_FAILED)
233235

234236
function lookup_binding_partition(world::UInt, b::Core.Binding)
235237
ccall(:jl_get_binding_partition, Ref{Core.BindingPartition}, (Any, UInt), b, world)
236238
end
237239

238240
function lookup_binding_partition(world::UInt, gr::Core.GlobalRef)
239-
ccall(:jl_get_globalref_partition, Ref{Core.BindingPartition}, (Any, UInt), gr, world)
241+
if isdefined(gr, :binding)
242+
b = gr.binding
243+
else
244+
b = ccall(:jl_get_module_binding, Ref{Core.Binding}, (Any, Any, Cint), gr.mod, gr.name, true)
245+
end
246+
return lookup_binding_partition(world, b)
240247
end
241248

242249
partition_restriction(bpart::Core.BindingPartition) = ccall(:jl_bpart_get_restriction_value, Any, (Any,), bpart)

0 commit comments

Comments
 (0)