Skip to content

Commit 0afd997

Browse files
KenoElOceanografo
authored andcommitted
Give const prop'ed calls their own statement info (JuliaLang#39754)
My primary motivation here is to let Cthulhu mark cases where constant propagation improved the result, but this also lets us avoid the second (linear) lookup in the inference cache, which causes a marginal, but measurable (a few percent) improvement in sysimage build time.
1 parent 638c49e commit 0afd997

File tree

4 files changed

+105
-52
lines changed

4 files changed

+105
-52
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
166166
# if there's a possibility we could constant-propagate a better result
167167
# (hopefully without doing too much work), try to do that now
168168
# TODO: it feels like this could be better integrated into abstract_call_method / typeinf_edge
169-
const_rettype = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle)
169+
const_rettype, result = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle)
170170
if const_rettype rettype
171171
# use the better result, if it's a refinement of rettype
172172
rettype = const_rettype
173173
end
174+
if result !== nothing
175+
info = ConstCallInfo(info, result)
176+
end
174177
end
175178
if is_unused && !(rettype === Bottom)
176179
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
@@ -263,7 +266,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
263266
method = match.method
264267
nargs::Int = method.nargs
265268
method.isva && (nargs -= 1)
266-
length(argtypes) >= nargs || return Any
269+
length(argtypes) >= nargs || return Any, nothing
267270
haveconst = false
268271
allconst = true
269272
# see if any or all of the arguments are constant and propagating constants may be worthwhile
@@ -279,21 +282,21 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
279282
break
280283
end
281284
end
282-
haveconst || improvable_via_constant_propagation(rettype) || return Any
285+
haveconst || improvable_via_constant_propagation(rettype) || return Any, nothing
283286
force_inference = method.aggressive_constprop || InferenceParams(interp).aggressive_constant_propagation
284287
if !force_inference && nargs > 1
285288
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
286289
arrty = argtypes[2]
287290
# don't propagate constant index into indexing of non-constant array
288291
if arrty isa Type && arrty <: AbstractArray && !issingletontype(arrty)
289-
return Any
292+
return Any, nothing
290293
elseif arrty Array
291-
return Any
294+
return Any, nothing
292295
end
293296
elseif istopfunction(f, :iterate)
294297
itrty = argtypes[2]
295298
if itrty Array
296-
return Any
299+
return Any, nothing
297300
end
298301
end
299302
end
@@ -304,7 +307,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
304307
istopfunction(f, :<<) || istopfunction(f, :>>))
305308
# it is almost useless to inline the op of when all the same type,
306309
# but highly worthwhile to inline promote of a constant
307-
length(argtypes) > 2 || return Any
310+
length(argtypes) > 2 || return Any, nothing
308311
t1 = widenconst(argtypes[2])
309312
all_same = true
310313
for i in 3:length(argtypes)
@@ -313,18 +316,18 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
313316
break
314317
end
315318
end
316-
all_same && return Any
319+
all_same && return Any, nothing
317320
end
318321
if istopfunction(f, :getproperty) || istopfunction(f, :setproperty!)
319322
force_inference = true
320323
end
321324
force_inference |= allconst
322325
mi = specialize_method(match, !force_inference)
323-
mi === nothing && return Any
326+
mi === nothing && return Any, nothing
324327
mi = mi::MethodInstance
325328
# decide if it's likely to be worthwhile
326329
if !force_inference && !const_prop_heuristic(interp, method, mi)
327-
return Any
330+
return Any, nothing
328331
end
329332
inf_cache = get_inference_cache(interp)
330333
inf_result = cache_lookup(mi, argtypes, inf_cache)
@@ -336,7 +339,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
336339
cyclei = 0
337340
while !(infstate === nothing)
338341
if method === infstate.linfo.def && any(infstate.result.overridden_by_const)
339-
return Any
342+
return Any, nothing
340343
end
341344
if cyclei < length(infstate.callers_in_cycle)
342345
cyclei += 1
@@ -349,16 +352,16 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
349352
end
350353
inf_result = InferenceResult(mi, argtypes)
351354
frame = InferenceState(inf_result, #=cache=#false, interp)
352-
frame === nothing && return Any # this is probably a bad generated function (unsound), but just ignore it
355+
frame === nothing && return Any, nothing # this is probably a bad generated function (unsound), but just ignore it
353356
frame.parent = sv
354357
push!(inf_cache, inf_result)
355-
typeinf(interp, frame) || return Any
358+
typeinf(interp, frame) || return Any, nothing
356359
end
357360
result = inf_result.result
358361
# if constant inference hits a cycle, just bail out
359-
isa(result, InferenceState) && return Any
362+
isa(result, InferenceState) && return Any, nothing
360363
add_backedge!(inf_result.linfo, sv)
361-
return result
364+
return result, inf_result
362365
end
363366

364367
const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result."

base/compiler/ssair/inlining.jl

Lines changed: 74 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333
pass to apply its own inlining policy decisions.
3434
"""
3535
struct DelayedInliningSpec
36-
match::MethodMatch
36+
match::Union{MethodMatch, InferenceResult}
3737
atypes::Vector{Any}
3838
stmttype::Any
3939
end
@@ -44,7 +44,11 @@ struct InliningTodo
4444
spec::Union{ResolvedInliningSpec, DelayedInliningSpec}
4545
end
4646

47-
InliningTodo(mi::MethodInstance, match::MethodMatch, atypes::Vector{Any}, @nospecialize(stmttype)) = InliningTodo(mi, DelayedInliningSpec(match, atypes, stmttype))
47+
InliningTodo(mi::MethodInstance, match::MethodMatch,
48+
atypes::Vector{Any}, @nospecialize(stmttype)) = InliningTodo(mi, DelayedInliningSpec(match, atypes, stmttype))
49+
50+
InliningTodo(result::InferenceResult, atypes::Vector{Any}, @nospecialize(stmttype)) =
51+
InliningTodo(result.linfo, DelayedInliningSpec(result, atypes, stmttype))
4852

4953
struct ConstantCase
5054
val::Any
@@ -631,7 +635,10 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
631635
new_stmt = Expr(:call, argexprs[2], def, state...)
632636
state1 = insert_node!(ir, idx, call.rt, new_stmt)
633637
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
634-
if isa(call.info, MethodMatchInfo) || isa(call.info, UnionSplitInfo)
638+
if isa(call.info, ConstCallInfo)
639+
handle_const_call!(ir, state1.id, new_stmt, call.info, new_sig,
640+
call.rt, et, caches, false, todo)
641+
elseif isa(call.info, MethodMatchInfo) || isa(call.info, UnionSplitInfo)
635642
info = isa(call.info, MethodMatchInfo) ?
636643
MethodMatchInfo[call.info] : call.info.matches
637644
# See if we can inline this call to `iterate`
@@ -676,9 +683,32 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, match::Meth
676683
return mi
677684
end
678685

686+
function compileable_specialization(et::Union{EdgeTracker, Nothing}, result::InferenceResult)
687+
mi = specialize_method(result.linfo.def, result.linfo.specTypes,
688+
result.linfo.sparam_vals, false, true)
689+
mi !== nothing && et !== nothing && push!(et, mi::MethodInstance)
690+
return mi
691+
end
692+
679693
function resolve_todo(todo::InliningTodo, et::Union{EdgeTracker, Nothing}, caches::InferenceCaches)
680694
spec = todo.spec::DelayedInliningSpec
681-
isconst, src = find_inferred(todo.mi, spec.atypes, caches, spec.stmttype)
695+
696+
#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
697+
isconst, src = false, nothing
698+
if isa(spec.match, InferenceResult)
699+
let inferred_src = spec.match.src
700+
if isa(inferred_src, CodeInfo)
701+
isconst, src = false, inferred_src
702+
elseif isa(inferred_src, Const)
703+
if !is_inlineable_constant(inferred_src.val)
704+
return compileable_specialization(et, spec.match)
705+
end
706+
isconst, src = true, quoted(inferred_src.val)
707+
end
708+
end
709+
else
710+
isconst, src = find_inferred(todo.mi, spec.atypes, caches, spec.stmttype)
711+
end
682712

683713
if isconst && et !== nothing
684714
push!(et, todo.mi)
@@ -717,6 +747,13 @@ function resolve_todo!(todo::Vector{Pair{Int, Any}}, et::Union{EdgeTracker, Noth
717747
todo
718748
end
719749

750+
function validate_sparams(sparams::SimpleVector)
751+
for i = 1:length(sparams)
752+
(isa(sparams[i], TypeVar) || isa(sparams[i], Core.TypeofVararg)) && return false
753+
end
754+
return true
755+
end
756+
720757
function analyze_method!(match::MethodMatch, atypes::Vector{Any},
721758
et::Union{EdgeTracker, Nothing},
722759
caches::Union{InferenceCaches, Nothing},
@@ -737,9 +774,8 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
737774

738775
# Bail out if any static parameters are left as TypeVar
739776
ok = true
740-
for i = 1:length(match.sparams)
741-
(isa(match.sparams[i], TypeVar) || isa(match.sparams[i], Core.TypeofVararg)) && return nothing
742-
end
777+
validate_sparams(match.sparams) || return nothing
778+
743779

744780
if !params.inlining
745781
return compileable_specialization(et, match)
@@ -1146,6 +1182,28 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
11461182
return nothing
11471183
end
11481184

1185+
function handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
1186+
info::ConstCallInfo, sig::Signature, @nospecialize(calltype),
1187+
et::Union{EdgeTracker, Nothing}, caches::Union{InferenceCaches, Nothing},
1188+
isinvoke::Bool, todo::Vector{Pair{Int, Any}})
1189+
item = InliningTodo(info.result, sig.atypes, calltype)
1190+
validate_sparams(item.mi.sparam_vals) || return
1191+
mthd_sig = item.mi.def.sig
1192+
mistypes = item.mi.specTypes
1193+
caches !== nothing && (item = resolve_todo(item, et, caches))
1194+
if sig.atype <: mthd_sig
1195+
return handle_single_case!(ir, stmt, idx, item, isinvoke, todo)
1196+
else
1197+
item === nothing && return
1198+
# Union split out the error case
1199+
item = UnionSplit(false, sig.atype, Pair{Any, Any}[mistypes => item])
1200+
if isinvoke
1201+
stmt.args = rewrite_invoke_exprargs!(stmt.args)
1202+
end
1203+
push!(todo, idx=>item)
1204+
end
1205+
end
1206+
11491207
function assemble_inline_todo!(ir::IRCode, state::InliningState)
11501208
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
11511209
todo = Pair{Int, Any}[]
@@ -1173,6 +1231,15 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
11731231
end
11741232
end
11751233

1234+
# If inference arrived at this result by using constant propagation,
1235+
# it'll performed a specialized analysis for just this case. Use its
1236+
# result.
1237+
if isa(info, ConstCallInfo)
1238+
handle_const_call!(ir, idx, stmt, info, sig, calltype, state.et,
1239+
state.caches, invoke_data !== nothing, todo)
1240+
continue
1241+
end
1242+
11761243
# Ok, now figure out what method to call
11771244
if invoke_data !== nothing
11781245
inline_invoke!(ir, idx, sig, invoke_data, state, todo)
@@ -1387,35 +1454,6 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
13871454
end
13881455

13891456
function find_inferred(mi::MethodInstance, atypes::Vector{Any}, caches::InferenceCaches, @nospecialize(rettype))
1390-
if caches.inf_cache !== nothing
1391-
# see if the method has a InferenceResult in the current cache
1392-
# or an existing inferred code info store in `.inferred`
1393-
haveconst = false
1394-
for i in 1:length(atypes)
1395-
if has_nontrivial_const_info(atypes[i])
1396-
# have new information from argtypes that wasn't available from the signature
1397-
haveconst = true
1398-
break
1399-
end
1400-
end
1401-
if haveconst || improvable_via_constant_propagation(rettype)
1402-
inf_result = cache_lookup(mi, atypes, caches.inf_cache) # Union{Nothing, InferenceResult}
1403-
else
1404-
inf_result = nothing
1405-
end
1406-
#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
1407-
if isa(inf_result, InferenceResult)
1408-
let inferred_src = inf_result.src
1409-
if isa(inferred_src, CodeInfo)
1410-
return svec(false, inferred_src)
1411-
end
1412-
if isa(inferred_src, Const) && is_inlineable_constant(inferred_src.val)
1413-
return svec(true, quoted(inferred_src.val),)
1414-
end
1415-
end
1416-
end
1417-
end
1418-
14191457
linfo = get(caches.mi_cache, mi, nothing)
14201458
if linfo isa CodeInstance
14211459
if invoke_api(linfo) == 2

base/compiler/stmtinfo.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ struct UnionSplitApplyCallInfo
8282
infos::Vector{ApplyCallInfo}
8383
end
8484

85+
"""
86+
struct ConstCallInfo
87+
88+
Precision for this call was improved using constant information. This info
89+
keeps a reference to the result that was used (or created for these)
90+
constant information.
91+
"""
92+
struct ConstCallInfo
93+
call::Any
94+
result::InferenceResult
95+
end
96+
8597
# Stmt infos that are used by external consumers, but not by optimization.
8698
# These are not produced by default and must be explicitly opted into by
8799
# the AbstractInterpreter.

base/compiler/typeutils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ function unswitchtupleunion(u::Union)
234234
ts = uniontypes(u)
235235
n = -1
236236
for t in ts
237-
if t isa DataType && t.name === Tuple.name && !isvarargtype(t.parameters[end])
237+
if t isa DataType && t.name === Tuple.name && length(t.parameters) != 0 && !isvarargtype(t.parameters[end])
238238
if n == -1
239239
n = length(t.parameters)
240240
elseif n != length(t.parameters)

0 commit comments

Comments
 (0)