Skip to content

inference: propagate variable changes to all exception frames #42081 #42110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 22 additions & 23 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1346,19 +1346,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
n = frame.nstmts
while frame.pc´´ <= n
# make progress on the active ip set
local pc::Int = frame.pc´´ # current program-counter
local pc::Int = frame.pc´´
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
#print(pc,": ",s[pc],"\n")
local pc´::Int = pc + 1 # next program-counter (after executing instruction)
if pc == frame.pc´´
# need to update pc´´ to point at the new lowest instruction in W
min_pc = _bits_findnext(W.bits, pc + 1)
frame.pc´´ = min_pc == -1 ? n + 1 : min_pc
# want to update pc´´ to point at the new lowest instruction in W
frame.pc´´ = pc´
end
delete!(W, pc)
frame.currpc = pc
frame.cur_hand = frame.handler_at[pc]
frame.stmt_edges[pc] === nothing || empty!(frame.stmt_edges[pc])
edges = frame.stmt_edges[pc]
edges === nothing || empty!(edges)
frame.stmt_info[pc] = nothing
stmt = frame.src.code[pc]
changes = s[pc]::VarTable
Expand Down Expand Up @@ -1392,7 +1391,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
pc´ = l
else
# general case
frame.handler_at[l] = frame.cur_hand
changes_else = changes
if isa(condt, Conditional)
if condt.elsetype !== Any && condt.elsetype !== changes[slot_id(condt.var)]
Expand Down Expand Up @@ -1440,7 +1438,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
elseif hd === :enter
l = stmt.args[1]::Int
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
# propagate type info to exception handler
old = s[l]
newstate_catch = stupdate!(old, changes)
Expand All @@ -1452,11 +1449,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
s[l] = newstate_catch
end
typeassert(s[l], VarTable)
frame.handler_at[l] = frame.cur_hand
elseif hd === :leave
for i = 1:((stmt.args[1])::Int)
frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second
end
else
if hd === :(=)
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
Expand All @@ -1482,16 +1475,22 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
frame.src.ssavaluetypes[pc] = t
end
end
if frame.cur_hand !== nothing && isa(changes, StateUpdate)
# propagate new type info to exception handler
# the handling for Expr(:enter) propagates all changes from before the try/catch
# so this only needs to propagate any changes
l = frame.cur_hand.first::Int
if stupdate1!(s[l]::VarTable, changes::StateUpdate) !== false
if l < frame.pc´´
frame.pc´´ = l
if isa(changes, StateUpdate)
let cur_hand = frame.handler_at[pc], l, enter
while cur_hand != 0
enter = frame.src.code[cur_hand]
l = (enter::Expr).args[1]::Int
# propagate new type info to exception handler
# the handling for Expr(:enter) propagates all changes from before the try/catch
# so this only needs to propagate any changes
if stupdate1!(s[l]::VarTable, changes::StateUpdate) !== false
if l < frame.pc´´
frame.pc´´ = l
end
push!(W, l)
end
cur_hand = frame.handler_at[cur_hand]
end
push!(W, l)
end
end
end
Expand All @@ -1504,7 +1503,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end

pc´ > n && break # can't proceed with the fast-path fall-through
frame.handler_at[pc´] = frame.cur_hand
newstate = stupdate!(s[pc´], changes)
if isa(stmt, GotoNode) && frame.pc´´ < pc´
# if we are processing a goto node anyways,
Expand All @@ -1515,7 +1513,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
s[pc´] = newstate
end
push!(W, pc´)
pc = frame.pc´´
break
elseif newstate !== nothing
s[pc´] = newstate
pc = pc´
Expand All @@ -1525,6 +1523,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
break
end
end
frame.pc´´ = _bits_findnext(W.bits, frame.pc´´)::Int # next program-counter
end
frame.dont_work_on_me = false
nothing
Expand Down
114 changes: 94 additions & 20 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ mutable struct InferenceState
pc´´::LineNum
nstmts::Int
# current exception handler info
cur_hand #::Union{Nothing, Pair{LineNum, prev_handler}}
handler_at::Vector{Any}
n_handlers::Int
handler_at::Vector{LineNum}
# ssavalue sparsity and restart info
ssavalue_uses::Vector{BitSet}
throw_blocks::BitSet
Expand All @@ -57,8 +55,9 @@ mutable struct InferenceState
function InferenceState(result::InferenceResult, src::CodeInfo,
cached::Bool, interp::AbstractInterpreter)
linfo = result.linfo
def = linfo.def
code = src.code::Array{Any,1}
toplevel = !isa(linfo.def, Method)
toplevel = !isa(def, Method)

sp = sptypes_from_meth_instance(linfo::MethodInstance)

Expand Down Expand Up @@ -87,30 +86,21 @@ mutable struct InferenceState
throw_blocks = find_throw_blocks(code)

# exception handlers
cur_hand = nothing
handler_at = Any[ nothing for i=1:n ]
n_handlers = 0

W = BitSet()
push!(W, 1) #initial pc to visit

if !toplevel
meth = linfo.def
inmodule = meth.module
else
inmodule = linfo.def::Module
end
ip = BitSet()
handler_at = compute_trycatch(src.code, ip)
push!(ip, 1)

mod = isa(def, Method) ? def.module : def
valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)

frame = new(
InferenceParams(interp), result, linfo,
sp, slottypes, inmodule, 0,
sp, slottypes, mod, 0,
IdSet{InferenceState}(), IdSet{InferenceState}(),
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
Union{}, ip, 1, n, handler_at,
ssavalue_uses, throw_blocks,
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
Expand All @@ -124,6 +114,90 @@ mutable struct InferenceState
end
end

function compute_trycatch(code::Vector{Any}, ip::BitSet)
# The goal initially is to record the frame like this for the state at exit:
# 1: (enter 3) # == 0
# 3: (expr) # == 1
# 3: (leave 1) # == 1
# 4: (expr) # == 0
# then we can find all trys by walking backwards from :enter statements,
# and all catches by looking at the statement after the :enter
n = length(code)
empty!(ip)
ip.offset = 0 # for _bits_findnext
push!(ip, n + 1)
handler_at = fill(0, n)

# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isexpr(stmt, :enter)
l = stmt.args[1]::Int
handler_at[pc + 1] = pc
push!(ip, pc + 1)
handler_at[l] = pc
push!(ip, l)
end
end

# now forward those marks to all :leave statements
pc´´ = 0
while true
# make progress on the active ip set
pc = _bits_findnext(ip.bits, pc´´)::Int
pc > n && break
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
pc´ = pc + 1 # next program-counter (after executing instruction)
if pc == pc´´
pc´´ = pc´
end
delete!(ip, pc)
cur_hand = handler_at[pc]
@assert cur_hand != 0 "unbalanced try/catch"
stmt = code[pc]
if isa(stmt, GotoNode)
pc´ = stmt.label
elseif isa(stmt, GotoIfNot)
l = stmt.dest::Int
if handler_at[l] != cur_hand
@assert handler_at[l] == 0 "unbalanced try/catch"
handler_at[l] = cur_hand
if l < pc´´
pc´´ = l
end
push!(ip, l)
end
elseif isa(stmt, ReturnNode)
@assert !isdefined(stmt, :val) "unbalanced try/catch"
break
elseif isa(stmt, Expr)
head = stmt.head
if head === :enter
cur_hand = pc
elseif head === :leave
l = stmt.args[1]::Int
for i = 1:l
cur_hand = handler_at[cur_hand]
end
cur_hand == 0 && break
end
end

pc´ > n && break # can't proceed with the fast-path fall-through
if handler_at[pc´] != cur_hand
@assert handler_at[pc´] == 0 "unbalanced try/catch"
handler_at[pc´] = cur_hand
elseif !in(pc´, ip)
break # already visited
end
pc = pc´
end
end

@assert first(ip) == n + 1
return handler_at
end

method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table

function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)
Expand Down
45 changes: 45 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3040,3 +3040,48 @@ Base.return_types((Union{Int,Nothing},)) do x
end
x
end == [Int]

# issue #42022
let x = Tuple{Int,Any}[
#= 1=# (0, Expr(:(=), Core.SlotNumber(3), 1))
#= 2=# (0, Expr(:enter, 18))
#= 3=# (2, Expr(:(=), Core.SlotNumber(3), 2.0))
#= 4=# (2, Expr(:enter, 12))
#= 5=# (4, Expr(:(=), Core.SlotNumber(3), '3'))
#= 6=# (4, Core.GotoIfNot(Core.SlotNumber(2), 9))
#= 7=# (4, Expr(:leave, 2))
#= 8=# (0, Core.ReturnNode(1))
#= 9=# (4, Expr(:call, GlobalRef(Main, :throw)))
#=10=# (4, Expr(:leave, 1))
#=11=# (2, Core.GotoNode(16))
#=12=# (4, Expr(:leave, 1))
#=13=# (2, Expr(:(=), Core.SlotNumber(4), Expr(:the_exception)))
#=14=# (2, Expr(:call, GlobalRef(Main, :rethrow)))
#=15=# (2, Expr(:pop_exception, Core.SSAValue(4)))
#=16=# (2, Expr(:leave, 1))
#=17=# (0, Core.GotoNode(22))
#=18=# (2, Expr(:leave, 1))
#=19=# (0, Expr(:(=), Core.SlotNumber(5), Expr(:the_exception)))
#=20=# (0, nothing)
#=21=# (0, Expr(:pop_exception, Core.SSAValue(2)))
#=22=# (0, Core.ReturnNode(Core.SlotNumber(3)))
]
handler_at = Core.Compiler.compute_trycatch(last.(x), Core.Compiler.BitSet())
@test handler_at == first.(x)
end

@test only(Base.return_types((Bool,)) do y
x = 1
try
x = 2.0
try
x = '3'
y ? (return 1) : throw()
catch ex1
rethrow()
end
catch ex2
nothing
end
return x
end) === Union{Int, Float64, Char}