Skip to content

Commit 171c4bf

Browse files
committed
Make EnterNode save/restore dynamic scope
As discussed in #51352, this gives `EnterNode` the ability to set (and restore on leave or catch edge) jl_current_task->scope. Manual modifications of the task field after the task has started are considered undefined behavior. In addition, we gain a new intrinsic to access current_task->scope and both inference and the optimizer will forward scopes from EnterNodes to this intrinsic (non-interprocedurally). Together with #51993 this is sufficient to fully optimize ScopedValues (non-interprocedurally at least).
1 parent c30d45d commit 171c4bf

21 files changed

+208
-61
lines changed

base/boot.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ eval(Core, quote
460460
ReturnNode() = $(Expr(:new, :ReturnNode)) # unassigned val indicates unreachable
461461
GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest))
462462
EnterNode(dest::Int) = $(Expr(:new, :EnterNode, :dest))
463+
EnterNode(dest::Int, @nospecialize(scope)) = $(Expr(:new, :EnterNode, :dest, :scope))
463464
LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing))
464465
function LineNumberNode(l::Int, @nospecialize(f))
465466
isa(f, String) && (f = Symbol(f))
@@ -967,6 +968,7 @@ arraysize(a::Array, i::Int) = sle_int(i, nfields(a.size)) ? getfield(a.size, i)
967968
export arrayref, arrayset, arraysize, const_arrayref
968969

969970
# For convenience
970-
EnterNode(old::EnterNode, new_dest::Int) = EnterNode(new_dest)
971+
EnterNode(old::EnterNode, new_dest::Int) = isdefined(old, :scope) ?
972+
EnterNode(new_dest, old.scope) : EnterNode(new_dest)
971973

972974
ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Core, true)

base/compiler/abstractinterpretation.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3259,6 +3259,19 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
32593259
elseif isa(stmt, EnterNode)
32603260
ssavaluetypes[currpc] = Any
32613261
add_curr_ssaflag!(frame, IR_FLAG_NOTHROW)
3262+
if isdefined(stmt, :scope)
3263+
scopet = abstract_eval_value(interp, stmt.scope, currstate, frame)
3264+
handler = frame.handlers[frame.handler_at[frame.currpc+1][1]]
3265+
@assert handler.scopet !== nothing
3266+
if !(𝕃ᵢ, scopet, handler.scopet)
3267+
handler.scopet = tmerge(𝕃ᵢ, scopet, handler.scopet)
3268+
if isdefined(handler, :scope_uses)
3269+
for bb in handler.scope_uses
3270+
push!(W, bb)
3271+
end
3272+
end
3273+
end
3274+
end
32623275
@goto fallthrough
32633276
elseif isexpr(stmt, :leave)
32643277
ssavaluetypes[currpc] = Any

base/compiler/inferencestate.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,10 @@ const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed
205205

206206
mutable struct TryCatchFrame
207207
exct
208+
scopet
208209
const enter_idx::Int
209-
TryCatchFrame(@nospecialize(exct), enter_idx::Int) = new(exct, enter_idx)
210+
scope_uses::Vector{Int}
211+
TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx)
210212
end
211213

212214
mutable struct InferenceState
@@ -364,7 +366,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
364366
stmt = code[pc]
365367
if isa(stmt, EnterNode)
366368
l = stmt.catch_dest
367-
push!(handlers, TryCatchFrame(Bottom, pc))
369+
push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc))
368370
handler_id = length(handlers)
369371
handler_at[pc + 1] = (handler_id, 0)
370372
push!(ip, pc + 1)

base/compiler/ssair/ir.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
13871387
result_idx += 1
13881388
end
13891389
elseif cfg_transforms_enabled && isa(stmt, EnterNode)
1390+
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa, mark_refined!)::EnterNode
13901391
label = bb_rename_succ[stmt.catch_dest]
13911392
@assert label > 0
13921393
ssa_rename[idx] = SSAValue(result_idx)

base/compiler/ssair/passes.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,29 @@ function fold_ifelse!(compact::IncrementalCompact, idx::Int, stmt::Expr)
965965
return false
966966
end
967967

968+
function fold_current_scope!(compact::IncrementalCompact, idx::Int, stmt::Expr, lazydomtree)
969+
domtree = get!(lazydomtree)
970+
971+
# The frontend enforces the invariant that any :enter dominates its active
972+
# region, so all we have to do here is walk the domtree to find it.
973+
dombb = block_for_inst(compact, SSAValue(idx))
974+
975+
local bbterminator
976+
while true
977+
dombb = domtree.idoms_bb[dombb]
978+
979+
# Did not find any dominating :enter - scope is inherited from the outside
980+
dombb == 0 && return nothing
981+
982+
bbterminator = compact[SSAValue(last(compact.cfg_transform.result_bbs[dombb].stmts))][:stmt]
983+
isa(bbterminator, EnterNode) || continue
984+
isdefined(bbterminator, :scope) || continue
985+
compact[idx] = bbterminator.scope
986+
return nothing
987+
end
988+
end
989+
990+
968991
# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
969992
# which can be very large sometimes, and program counters in question are often very sparse
970993
const SPCSet = IdSet{Int}
@@ -1094,6 +1117,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
10941117
lift_comparison!(isa, compact, idx, stmt, 𝕃ₒ)
10951118
elseif is_known_call(stmt, Core.ifelse, compact)
10961119
fold_ifelse!(compact, idx, stmt)
1120+
elseif is_known_call(stmt, Core.current_scope, compact)
1121+
fold_current_scope!(compact, idx, stmt, lazydomtree)
10971122
elseif isexpr(stmt, :new)
10981123
refine_new_effects!(𝕃ₒ, compact, idx, stmt)
10991124
end

base/compiler/ssair/show.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), used::BitSet, maxleng
6969
# given control flow information, we prefer to print these with the basic block #, instead of the ssa %
7070
elseif isa(stmt, EnterNode)
7171
print(io, "enter #", stmt.catch_dest, "")
72+
if isdefined(stmt, :scope)
73+
print(io, " with scope ")
74+
show_unquoted(io, stmt.scope, indent)
75+
end
7276
elseif stmt isa GotoNode
7377
print(io, "goto #", stmt.label)
7478
elseif stmt isa PhiNode

base/compiler/ssair/verify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
function maybe_show_ir(ir::IRCode)
44
if isdefined(Core, :Main)
5-
Core.Main.Base.display(ir)
5+
invokelatest(Core.Main.Base.display, ir)
66
end
77
end
88

base/compiler/tfuncs.jl

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,6 +2488,12 @@ function builtin_effects(𝕃::AbstractLattice, @nospecialize(f::Builtin), argty
24882488
return Effects(EFFECTS_TOTAL;
24892489
consistent = (isa(setting, Const) && setting.val === :conditional) ? ALWAYS_TRUE : ALWAYS_FALSE,
24902490
nothrow = compilerbarrier_nothrow(setting, nothing))
2491+
elseif f === Core.current_scope
2492+
length(argtypes) == 0 || return Effects(EFFECTS_THROWS; consistent=ALWAYS_FALSE)
2493+
return Effects(EFFECTS_TOTAL;
2494+
consistent = ALWAYS_FALSE,
2495+
notaskstate = false,
2496+
)
24912497
else
24922498
if contains_is(_CONSISTENT_BUILTINS, f)
24932499
consistent = ALWAYS_TRUE
@@ -2554,6 +2560,32 @@ function memoryop_noub(@nospecialize(f), argtypes::Vector{Any})
25542560
return false
25552561
end
25562562

2563+
function current_scope_tfunc(interp::AbstractInterpreter, sv::InferenceState)
2564+
pc = sv.currpc
2565+
while true
2566+
handleridx = sv.handler_at[pc][2]
2567+
if handleridx == 0
2568+
# No local scope available - inherited from the outside
2569+
return Any
2570+
end
2571+
pchandler = sv.handlers[handleridx]
2572+
# Remember that we looked at this handler, so we get re-scheduled
2573+
# if the scope information changes
2574+
isdefined(pchandler, :scope_uses) || (pchandler.scope_uses = Int[])
2575+
pcbb = block_for_inst(sv.cfg, pc)
2576+
if findfirst(pchandler.scope_uses, pcbb) === nothing
2577+
push!(pchandler.scope_uses, pcbb)
2578+
end
2579+
scope = pchandler.scopet
2580+
if scope !== nothing
2581+
# Found the scope - forward it
2582+
return scope
2583+
end
2584+
pc = pchandler.enter_idx
2585+
end
2586+
end
2587+
current_scope_tfunc(interp::AbstractInterpreter, sv) = Any
2588+
25572589
"""
25582590
builtin_nothrow(𝕃::AbstractLattice, f::Builtin, argtypes::Vector{Any}, rt) -> Bool
25592591
@@ -2568,9 +2600,6 @@ end
25682600
function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any},
25692601
sv::Union{AbsIntState, Nothing})
25702602
𝕃ᵢ = typeinf_lattice(interp)
2571-
if f === tuple
2572-
return tuple_tfunc(𝕃ᵢ, argtypes)
2573-
end
25742603
if isa(f, IntrinsicFunction)
25752604
if is_pure_intrinsic_infer(f) && all(@nospecialize(a) -> isa(a, Const), argtypes)
25762605
argvals = anymap(@nospecialize(a) -> (a::Const).val, argtypes)
@@ -2596,6 +2625,12 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
25962625
end
25972626
tf = T_IFUNC[iidx]
25982627
else
2628+
if f === tuple
2629+
return tuple_tfunc(𝕃ᵢ, argtypes)
2630+
elseif f === Core.current_scope
2631+
length(argtypes) == 0 || return Bottom
2632+
return current_scope_tfunc(interp, sv)
2633+
end
25992634
fidx = find_tfunc(f)
26002635
if fidx === nothing
26012636
# unknown/unhandled builtin function

base/compiler/validation.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}(
1313
:new => 1:typemax(Int),
1414
:splatnew => 2:2,
1515
:the_exception => 0:0,
16-
:enter => 1:1,
16+
:enter => 1:2,
1717
:leave => 1:typemax(Int),
1818
:pop_exception => 1:1,
1919
:inbounds => 1:1,
@@ -160,6 +160,13 @@ function validate_code!(errors::Vector{InvalidCodeError}, c::CodeInfo, is_top_le
160160
push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.cond))
161161
end
162162
validate_val!(x.cond)
163+
elseif isa(x, EnterNode)
164+
if isdefined(x, :scope)
165+
if !is_valid_argument(x.scope)
166+
push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.scope))
167+
end
168+
validate_val!(x.scope)
169+
end
163170
elseif isa(x, ReturnNode)
164171
if isdefined(x, :val)
165172
if !is_valid_return(x.val)

base/scopedvalues.jl

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,6 @@ function Scope(scope, pairs::Pair{<:ScopedValue}...)
7676
end
7777
Scope(::Nothing) = nothing
7878

79-
"""
80-
current_scope()::Union{Nothing, Scope}
81-
82-
Return the current dynamic scope.
83-
"""
84-
current_scope() = current_task().scope::Union{Nothing, Scope}
85-
8679
function Base.show(io::IO, scope::Scope)
8780
print(io, Scope, "(")
8881
first = true
@@ -111,8 +104,7 @@ return `nothing`. Otherwise returns `Some{T}` with the current
111104
value.
112105
"""
113106
function get(val::ScopedValue{T}) where {T}
114-
# Inline current_scope to avoid doing the type assertion twice.
115-
scope = current_task().scope
107+
scope = Core.current_scope()::Union{Scope, Nothing}
116108
if scope === nothing
117109
isassigned(val) && return Some{T}(val.default)
118110
return nothing
@@ -146,25 +138,6 @@ function Base.show(io::IO, val::ScopedValue)
146138
print(io, ')')
147139
end
148140

149-
"""
150-
with(f, (var::ScopedValue{T} => val::T)...)
151-
152-
Execute `f` in a new scope with `var` set to `val`.
153-
"""
154-
function with(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...)
155-
@nospecialize
156-
ct = Base.current_task()
157-
current_scope = ct.scope::Union{Nothing, Scope}
158-
ct.scope = Scope(current_scope, pair, rest...)
159-
try
160-
return f()
161-
finally
162-
ct.scope = current_scope
163-
end
164-
end
165-
166-
with(@nospecialize(f)) = f()
167-
168141
"""
169142
@with vars... expr
170143
@@ -182,18 +155,18 @@ macro with(exprs...)
182155
else
183156
error("@with expects at least one argument")
184157
end
185-
for expr in exprs
186-
if expr.head !== :call || first(expr.args) !== :(=>)
187-
error("@with expects arguments of the form `A => 2` got $expr")
188-
end
189-
end
190158
exprs = map(esc, exprs)
191-
quote
192-
ct = $(Base.current_task)()
193-
current_scope = ct.scope::$(Union{Nothing, Scope})
194-
ct.scope = $(Scope)(current_scope, $(exprs...))
195-
$(Expr(:tryfinally, esc(ex), :(ct.scope = current_scope)))
196-
end
159+
Expr(:tryfinally, esc(ex), :(), :($(Scope)($(Core.current_scope)()::Union{Nothing, Scope}, $(exprs...))))
197160
end
198161

162+
"""
163+
with(f, (var::ScopedValue{T} => val::T)...)
164+
165+
Execute `f` in a new scope with `var` set to `val`.
166+
"""
167+
function with(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...)
168+
@with(pair, rest..., f())
169+
end
170+
with(@nospecialize(f)) = f()
171+
199172
end # module ScopedValues

0 commit comments

Comments
 (0)