Skip to content

Commit aa20b32

Browse files
Ian AtolKeno
andauthored
Semi-concrete IR interpreter (#44803)
Co-authored-by: Keno Fischer <keno@juliacomputing.com>
1 parent db2c174 commit aa20b32

18 files changed

+709
-143
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 191 additions & 98 deletions
Large diffs are not rendered by default.

base/compiler/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ include("compiler/stmtinfo.jl")
164164

165165
include("compiler/abstractinterpretation.jl")
166166
include("compiler/typeinfer.jl")
167-
include("compiler/optimize.jl") # TODO: break this up further + extract utilities
167+
include("compiler/optimize.jl")
168168

169169
# required for bootstrap because sort.jl uses extrema
170170
# to decide whether to dispatch to counting sort.

base/compiler/inferenceresult.jl

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,36 @@ function is_forwardable_argtype(@nospecialize x)
1616
isa(x, PartialOpaque)
1717
end
1818

19+
function va_process_argtypes(given_argtypes::Vector{Any}, mi::MethodInstance,
20+
condargs::Union{Vector{Tuple{Int,Int}}, Nothing}=nothing)
21+
isva = mi.def.isva
22+
nargs = Int(mi.def.nargs)
23+
if isva || isvarargtype(given_argtypes[end])
24+
isva_given_argtypes = Vector{Any}(undef, nargs)
25+
for i = 1:(nargs - isva)
26+
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
27+
end
28+
if isva
29+
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
30+
last = length(given_argtypes)
31+
else
32+
last = nargs
33+
end
34+
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
35+
# invalidate `Conditional` imposed on varargs
36+
if condargs !== nothing
37+
for (slotid, i) in condargs
38+
if slotid last
39+
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
40+
end
41+
end
42+
end
43+
end
44+
return isva_given_argtypes
45+
end
46+
return given_argtypes
47+
end
48+
1949
# In theory, there could be a `cache` containing a matching `InferenceResult`
2050
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
2151
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
@@ -56,30 +86,7 @@ function matching_cache_argtypes(
5686
end
5787
given_argtypes[i] = widenconditional(argtype)
5888
end
59-
isva = def.isva
60-
if isva || isvarargtype(given_argtypes[end])
61-
isva_given_argtypes = Vector{Any}(undef, nargs)
62-
for i = 1:(nargs - isva)
63-
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
64-
end
65-
if isva
66-
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
67-
last = length(given_argtypes)
68-
else
69-
last = nargs
70-
end
71-
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
72-
# invalidate `Conditional` imposed on varargs
73-
if condargs !== nothing
74-
for (slotid, i) in condargs
75-
if slotid last
76-
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
77-
end
78-
end
79-
end
80-
end
81-
given_argtypes = isva_given_argtypes
82-
end
89+
given_argtypes = va_process_argtypes(given_argtypes, linfo, condargs)
8390
@assert length(given_argtypes) == nargs
8491
for i in 1:nargs
8592
given_argtype = given_argtypes[i]

base/compiler/inferencestate.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ function in(idx::Int, bsbmp::BitSetBoundedMinPrioritySet)
8080
return idx in bsbmp.elems
8181
end
8282

83+
function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
84+
for val in itr
85+
push!(bsbmp, val)
86+
end
87+
end
88+
8389
mutable struct InferenceState
8490
#= information about this method instance =#
8591
linfo::MethodInstance
@@ -209,8 +215,10 @@ Effects(state::InferenceState) = state.ipo_effects
209215
function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
210216
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
211217
end
218+
212219
merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) =
213220
merge_effects!(interp, caller, Effects(callee))
221+
merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing
214222

215223
is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect)
216224
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
@@ -226,15 +234,15 @@ function InferenceResult(
226234
return _InferenceResult(linfo, arginfo)
227235
end
228236

229-
add_remark!(::AbstractInterpreter, sv::InferenceState, remark) = return
237+
add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return
230238

231-
function bail_out_toplevel_call(::AbstractInterpreter, @nospecialize(callsig), sv::InferenceState)
232-
return sv.restrict_abstract_call_sites && !isdispatchtuple(callsig)
239+
function bail_out_toplevel_call(::AbstractInterpreter, @nospecialize(callsig), sv::Union{InferenceState, IRCode})
240+
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(callsig)
233241
end
234-
function bail_out_call(::AbstractInterpreter, @nospecialize(rt), sv::InferenceState)
242+
function bail_out_call(::AbstractInterpreter, @nospecialize(rt), sv::Union{InferenceState, IRCode})
235243
return rt === Any
236244
end
237-
function bail_out_apply(::AbstractInterpreter, @nospecialize(rt), sv::InferenceState)
245+
function bail_out_apply(::AbstractInterpreter, @nospecialize(rt), sv::Union{InferenceState, IRCode})
238246
return rt === Any
239247
end
240248

base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ import Core.Compiler: # Core.Compiler specific definitions
3131
isbitstype, isexpr, is_meta_expr_head, println, widenconst, argextype, singleton_type,
3232
fieldcount_noerror, try_compute_field, try_compute_fieldidx, hasintersect, ,
3333
intrinsic_nothrow, array_builtin_common_typecheck, arrayset_typecheck,
34-
setfield!_nothrow, alloc_array_ndims, check_effect_free!
34+
setfield!_nothrow, alloc_array_ndims, stmt_effect_free, check_effect_free!,
35+
SemiConcreteResult
3536

3637
include(x) = _TOP_MOD.include(@__MODULE__, x)
3738
if _TOP_MOD === Core.Compiler

base/compiler/ssair/EscapeAnalysis/interprocedural.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import Core.Compiler:
66
call_sig, argtypes_to_type, is_builtin, is_return_type, istopfunction, validate_sparams,
77
specialize_method, invoke_rewrite
88

9-
const Linfo = Union{MethodInstance,InferenceResult}
9+
const Linfo = Union{MethodInstance,InferenceResult,SemiConcreteResult}
1010
struct CallInfo
1111
linfos::Vector{Linfo}
1212
nothrow::Bool

base/compiler/ssair/driver.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ include("compiler/ssair/verify.jl")
2020
include("compiler/ssair/legacy.jl")
2121
include("compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl")
2222
include("compiler/ssair/passes.jl")
23+
include("compiler/ssair/irinterp.jl")

base/compiler/ssair/inlining.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,6 +1371,8 @@ function compute_inlining_cases(info::ConstCallInfo,
13711371
push!(cases, InliningCase(result.mi.specTypes, case))
13721372
elseif isa(result, ConstPropResult)
13731373
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, #=allow_abstract=#true)
1374+
elseif isa(result, SemiConcreteResult)
1375+
handled_all_cases &= handle_semi_concrete_result!(result, cases, #=allow_abstract=#true)
13741376
else
13751377
@assert result === nothing
13761378
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, #=allow_abstract=#true, #=allow_typevars=#false)
@@ -1434,6 +1436,15 @@ function handle_const_prop_result!(
14341436
return true
14351437
end
14361438

1439+
function handle_semi_concrete_result!(result::SemiConcreteResult, cases::Vector{InliningCase}, allow_abstract::Bool = false)
1440+
mi = result.mi
1441+
spec_types = mi.specTypes
1442+
allow_abstract || isdispatchtuple(spec_types) || return false
1443+
validate_sparams(mi.sparam_vals) || return false
1444+
push!(cases, InliningCase(spec_types, InliningTodo(mi, result.ir, result.effects)))
1445+
return true
1446+
end
1447+
14371448
function concrete_result_item(result::ConcreteResult, state::InliningState)
14381449
if !isdefined(result, :result) || !is_inlineable_constant(result.result)
14391450
case = compileable_specialization(state.et, result.mi, result.effects)

base/compiler/ssair/ir.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,15 +1050,22 @@ function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssas::Ve
10501050
end
10511051

10521052
# Used in inlining before we start compacting - Only works at the CFG level
1053-
function kill_edge!(bbs::Vector{BasicBlock}, from::Int, to::Int)
1053+
function kill_edge!(bbs::Vector{BasicBlock}, from::Int, to::Int, callback=nothing)
10541054
preds, succs = bbs[to].preds, bbs[from].succs
10551055
deleteat!(preds, findfirst(x->x === from, preds)::Int)
10561056
deleteat!(succs, findfirst(x->x === to, succs)::Int)
10571057
if length(preds) == 0
10581058
for succ in copy(bbs[to].succs)
1059-
kill_edge!(bbs, to, succ)
1059+
kill_edge!(bbs, to, succ, callback)
10601060
end
10611061
end
1062+
if callback !== nothing
1063+
callback(from, to)
1064+
end
1065+
end
1066+
1067+
function kill_edge!(ir::IRCode, from::Int, to::Int, callback=nothing)
1068+
kill_edge!(ir.cfg.blocks, from, to, callback)
10621069
end
10631070

10641071
# N.B.: from and to are non-renamed indices

0 commit comments

Comments
 (0)