Skip to content

Commit 19e66b3

Browse files
aviateskvtjnash
authored andcommitted
inference: propagate variable changes to all exception frames #42081 (#42110)
cherry-picked from #42081 Co-Authored-By: Jameson Nash <vtjnash+github@gmail.com>
1 parent 9d13e16 commit 19e66b3

File tree

3 files changed

+161
-43
lines changed

3 files changed

+161
-43
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,19 +1346,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13461346
n = frame.nstmts
13471347
while frame.pc´´ <= n
13481348
# make progress on the active ip set
1349-
local pc::Int = frame.pc´´ # current program-counter
1349+
local pc::Int = frame.pc´´
13501350
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
13511351
#print(pc,": ",s[pc],"\n")
13521352
local pc´::Int = pc + 1 # next program-counter (after executing instruction)
13531353
if pc == frame.pc´´
1354-
# need to update pc´´ to point at the new lowest instruction in W
1355-
min_pc = _bits_findnext(W.bits, pc + 1)
1356-
frame.pc´´ = min_pc == -1 ? n + 1 : min_pc
1354+
# want to update pc´´ to point at the new lowest instruction in W
1355+
frame.pc´´ = pc´
13571356
end
13581357
delete!(W, pc)
13591358
frame.currpc = pc
1360-
frame.cur_hand = frame.handler_at[pc]
1361-
frame.stmt_edges[pc] === nothing || empty!(frame.stmt_edges[pc])
1359+
edges = frame.stmt_edges[pc]
1360+
edges === nothing || empty!(edges)
13621361
frame.stmt_info[pc] = nothing
13631362
stmt = frame.src.code[pc]
13641363
changes = s[pc]::VarTable
@@ -1392,7 +1391,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13921391
pc´ = l
13931392
else
13941393
# general case
1395-
frame.handler_at[l] = frame.cur_hand
13961394
changes_else = changes
13971395
if isa(condt, Conditional)
13981396
if condt.elsetype !== Any && condt.elsetype !== changes[slot_id(condt.var)]
@@ -1440,7 +1438,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14401438
end
14411439
elseif hd === :enter
14421440
l = stmt.args[1]::Int
1443-
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
14441441
# propagate type info to exception handler
14451442
old = s[l]
14461443
newstate_catch = stupdate!(old, changes)
@@ -1452,11 +1449,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14521449
s[l] = newstate_catch
14531450
end
14541451
typeassert(s[l], VarTable)
1455-
frame.handler_at[l] = frame.cur_hand
14561452
elseif hd === :leave
1457-
for i = 1:((stmt.args[1])::Int)
1458-
frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second
1459-
end
14601453
else
14611454
if hd === :(=)
14621455
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
@@ -1482,16 +1475,22 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14821475
frame.src.ssavaluetypes[pc] = t
14831476
end
14841477
end
1485-
if frame.cur_hand !== nothing && isa(changes, StateUpdate)
1486-
# propagate new type info to exception handler
1487-
# the handling for Expr(:enter) propagates all changes from before the try/catch
1488-
# so this only needs to propagate any changes
1489-
l = frame.cur_hand.first::Int
1490-
if stupdate1!(s[l]::VarTable, changes::StateUpdate) !== false
1491-
if l < frame.pc´´
1492-
frame.pc´´ = l
1478+
if isa(changes, StateUpdate)
1479+
let cur_hand = frame.handler_at[pc], l, enter
1480+
while cur_hand != 0
1481+
enter = frame.src.code[cur_hand]
1482+
l = (enter::Expr).args[1]::Int
1483+
# propagate new type info to exception handler
1484+
# the handling for Expr(:enter) propagates all changes from before the try/catch
1485+
# so this only needs to propagate any changes
1486+
if stupdate1!(s[l]::VarTable, changes::StateUpdate) !== false
1487+
if l < frame.pc´´
1488+
frame.pc´´ = l
1489+
end
1490+
push!(W, l)
1491+
end
1492+
cur_hand = frame.handler_at[cur_hand]
14931493
end
1494-
push!(W, l)
14951494
end
14961495
end
14971496
end
@@ -1504,7 +1503,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
15041503
end
15051504

15061505
pc´ > n && break # can't proceed with the fast-path fall-through
1507-
frame.handler_at[pc´] = frame.cur_hand
15081506
newstate = stupdate!(s[pc´], changes)
15091507
if isa(stmt, GotoNode) && frame.pc´´ < pc´
15101508
# if we are processing a goto node anyways,
@@ -1515,7 +1513,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
15151513
s[pc´] = newstate
15161514
end
15171515
push!(W, pc´)
1518-
pc = frame.pc´´
1516+
break
15191517
elseif newstate !== nothing
15201518
s[pc´] = newstate
15211519
pc = pc´
@@ -1525,6 +1523,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
15251523
break
15261524
end
15271525
end
1526+
frame.pc´´ = _bits_findnext(W.bits, frame.pc´´)::Int # next program-counter
15281527
end
15291528
frame.dont_work_on_me = false
15301529
nothing

base/compiler/inferencestate.jl

Lines changed: 94 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ mutable struct InferenceState
2828
pc´´::LineNum
2929
nstmts::Int
3030
# current exception handler info
31-
cur_hand #::Union{Nothing, Pair{LineNum, prev_handler}}
32-
handler_at::Vector{Any}
33-
n_handlers::Int
31+
handler_at::Vector{LineNum}
3432
# ssavalue sparsity and restart info
3533
ssavalue_uses::Vector{BitSet}
3634
throw_blocks::BitSet
@@ -57,8 +55,9 @@ mutable struct InferenceState
5755
function InferenceState(result::InferenceResult, src::CodeInfo,
5856
cached::Bool, interp::AbstractInterpreter)
5957
linfo = result.linfo
58+
def = linfo.def
6059
code = src.code::Array{Any,1}
61-
toplevel = !isa(linfo.def, Method)
60+
toplevel = !isa(def, Method)
6261

6362
sp = sptypes_from_meth_instance(linfo::MethodInstance)
6463

@@ -87,30 +86,21 @@ mutable struct InferenceState
8786
throw_blocks = find_throw_blocks(code)
8887

8988
# exception handlers
90-
cur_hand = nothing
91-
handler_at = Any[ nothing for i=1:n ]
92-
n_handlers = 0
93-
94-
W = BitSet()
95-
push!(W, 1) #initial pc to visit
96-
97-
if !toplevel
98-
meth = linfo.def
99-
inmodule = meth.module
100-
else
101-
inmodule = linfo.def::Module
102-
end
89+
ip = BitSet()
90+
handler_at = compute_trycatch(src.code, ip)
91+
push!(ip, 1)
10392

93+
mod = isa(def, Method) ? def.module : def
10494
valid_worlds = WorldRange(src.min_world,
10595
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
96+
10697
frame = new(
10798
InferenceParams(interp), result, linfo,
108-
sp, slottypes, inmodule, 0,
99+
sp, slottypes, mod, 0,
109100
IdSet{InferenceState}(), IdSet{InferenceState}(),
110101
src, get_world_counter(interp), valid_worlds,
111102
nargs, s_types, s_edges, stmt_info,
112-
Union{}, W, 1, n,
113-
cur_hand, handler_at, n_handlers,
103+
Union{}, ip, 1, n, handler_at,
114104
ssavalue_uses, throw_blocks,
115105
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
116106
Vector{InferenceState}(), # callers_in_cycle
@@ -124,6 +114,90 @@ mutable struct InferenceState
124114
end
125115
end
126116

117+
function compute_trycatch(code::Vector{Any}, ip::BitSet)
118+
# The goal initially is to record the frame like this for the state at exit:
119+
# 1: (enter 3) # == 0
120+
# 3: (expr) # == 1
121+
# 3: (leave 1) # == 1
122+
# 4: (expr) # == 0
123+
# then we can find all trys by walking backwards from :enter statements,
124+
# and all catches by looking at the statement after the :enter
125+
n = length(code)
126+
empty!(ip)
127+
ip.offset = 0 # for _bits_findnext
128+
push!(ip, n + 1)
129+
handler_at = fill(0, n)
130+
131+
# start from all :enter statements and record the location of the try
132+
for pc = 1:n
133+
stmt = code[pc]
134+
if isexpr(stmt, :enter)
135+
l = stmt.args[1]::Int
136+
handler_at[pc + 1] = pc
137+
push!(ip, pc + 1)
138+
handler_at[l] = pc
139+
push!(ip, l)
140+
end
141+
end
142+
143+
# now forward those marks to all :leave statements
144+
pc´´ = 0
145+
while true
146+
# make progress on the active ip set
147+
pc = _bits_findnext(ip.bits, pc´´)::Int
148+
pc > n && break
149+
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
150+
pc´ = pc + 1 # next program-counter (after executing instruction)
151+
if pc == pc´´
152+
pc´´ = pc´
153+
end
154+
delete!(ip, pc)
155+
cur_hand = handler_at[pc]
156+
@assert cur_hand != 0 "unbalanced try/catch"
157+
stmt = code[pc]
158+
if isa(stmt, GotoNode)
159+
pc´ = stmt.label
160+
elseif isa(stmt, GotoIfNot)
161+
l = stmt.dest::Int
162+
if handler_at[l] != cur_hand
163+
@assert handler_at[l] == 0 "unbalanced try/catch"
164+
handler_at[l] = cur_hand
165+
if l < pc´´
166+
pc´´ = l
167+
end
168+
push!(ip, l)
169+
end
170+
elseif isa(stmt, ReturnNode)
171+
@assert !isdefined(stmt, :val) "unbalanced try/catch"
172+
break
173+
elseif isa(stmt, Expr)
174+
head = stmt.head
175+
if head === :enter
176+
cur_hand = pc
177+
elseif head === :leave
178+
l = stmt.args[1]::Int
179+
for i = 1:l
180+
cur_hand = handler_at[cur_hand]
181+
end
182+
cur_hand == 0 && break
183+
end
184+
end
185+
186+
pc´ > n && break # can't proceed with the fast-path fall-through
187+
if handler_at[pc´] != cur_hand
188+
@assert handler_at[pc´] == 0 "unbalanced try/catch"
189+
handler_at[pc´] = cur_hand
190+
elseif !in(pc´, ip)
191+
break # already visited
192+
end
193+
pc = pc´
194+
end
195+
end
196+
197+
@assert first(ip) == n + 1
198+
return handler_at
199+
end
200+
127201
method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table
128202

129203
function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)

test/compiler/inference.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3040,3 +3040,48 @@ Base.return_types((Union{Int,Nothing},)) do x
30403040
end
30413041
x
30423042
end == [Int]
3043+
3044+
# issue #42022
3045+
let x = Tuple{Int,Any}[
3046+
#= 1=# (0, Expr(:(=), Core.SlotNumber(3), 1))
3047+
#= 2=# (0, Expr(:enter, 18))
3048+
#= 3=# (2, Expr(:(=), Core.SlotNumber(3), 2.0))
3049+
#= 4=# (2, Expr(:enter, 12))
3050+
#= 5=# (4, Expr(:(=), Core.SlotNumber(3), '3'))
3051+
#= 6=# (4, Core.GotoIfNot(Core.SlotNumber(2), 9))
3052+
#= 7=# (4, Expr(:leave, 2))
3053+
#= 8=# (0, Core.ReturnNode(1))
3054+
#= 9=# (4, Expr(:call, GlobalRef(Main, :throw)))
3055+
#=10=# (4, Expr(:leave, 1))
3056+
#=11=# (2, Core.GotoNode(16))
3057+
#=12=# (4, Expr(:leave, 1))
3058+
#=13=# (2, Expr(:(=), Core.SlotNumber(4), Expr(:the_exception)))
3059+
#=14=# (2, Expr(:call, GlobalRef(Main, :rethrow)))
3060+
#=15=# (2, Expr(:pop_exception, Core.SSAValue(4)))
3061+
#=16=# (2, Expr(:leave, 1))
3062+
#=17=# (0, Core.GotoNode(22))
3063+
#=18=# (2, Expr(:leave, 1))
3064+
#=19=# (0, Expr(:(=), Core.SlotNumber(5), Expr(:the_exception)))
3065+
#=20=# (0, nothing)
3066+
#=21=# (0, Expr(:pop_exception, Core.SSAValue(2)))
3067+
#=22=# (0, Core.ReturnNode(Core.SlotNumber(3)))
3068+
]
3069+
handler_at = Core.Compiler.compute_trycatch(last.(x), Core.Compiler.BitSet())
3070+
@test handler_at == first.(x)
3071+
end
3072+
3073+
@test only(Base.return_types((Bool,)) do y
3074+
x = 1
3075+
try
3076+
x = 2.0
3077+
try
3078+
x = '3'
3079+
y ? (return 1) : throw()
3080+
catch ex1
3081+
rethrow()
3082+
end
3083+
catch ex2
3084+
nothing
3085+
end
3086+
return x
3087+
end) === Union{Int, Float64, Char}

0 commit comments

Comments
 (0)