Skip to content

Commit 61c044c

Browse files
gbaralditopolarity
andauthored
Inline statically known method errors. (#54972)
This replaces the `Expr(:call, ...)` with a call of a new builtin `Core.throw_methoderror` This is useful because it makes very clear if something is a static method error or a plain dynamic dispatch that always errors. Tools such as AllocCheck or juliac can notice that this is not a genuine dynamic dispatch, and prevent it from becoming a false positive compile-time error. Dependent on #55705 --------- Co-authored-by: Cody Tapscott <topolarity@tapscott.me>
1 parent f808606 commit 61c044c

File tree

9 files changed

+111
-65
lines changed

9 files changed

+111
-65
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
209209
rettype = exctype = Any
210210
all_effects = Effects()
211211
else
212-
if (matches isa MethodMatches ? (!matches.fullmatch || any_ambig(matches)) :
213-
(!all(matches.fullmatches) || any_ambig(matches)))
212+
if !fully_covering(matches) || any_ambig(matches)
214213
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
215214
all_effects = Effects(all_effects; nothrow=false)
216215
exctype = exctype ₚ MethodError
@@ -275,21 +274,23 @@ struct MethodMatches
275274
applicable::Vector{Any}
276275
info::MethodMatchInfo
277276
valid_worlds::WorldRange
278-
mt::MethodTable
279-
fullmatch::Bool
280277
end
281-
any_ambig(info::MethodMatchInfo) = info.results.ambig
278+
any_ambig(result::MethodLookupResult) = result.ambig
279+
any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
282280
any_ambig(m::MethodMatches) = any_ambig(m.info)
281+
fully_covering(info::MethodMatchInfo) = info.fullmatch
282+
fully_covering(m::MethodMatches) = fully_covering(m.info)
283283

284284
struct UnionSplitMethodMatches
285285
applicable::Vector{Any}
286286
applicable_argtypes::Vector{Vector{Any}}
287287
info::UnionSplitInfo
288288
valid_worlds::WorldRange
289-
mts::Vector{MethodTable}
290-
fullmatches::Vector{Bool}
291289
end
292-
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)
290+
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.matches)
291+
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
292+
fully_covering(info::UnionSplitInfo) = all(info.fullmatches)
293+
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)
293294

294295
function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
295296
max_union_splitting::Int = InferenceParams(interp).max_union_splitting,
@@ -307,7 +308,7 @@ is_union_split_eligible(𝕃::AbstractLattice, argtypes::Vector{Any}, max_union_
307308
function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any},
308309
@nospecialize(atype), max_methods::Int)
309310
split_argtypes = switchtupleunion(typeinf_lattice(interp), argtypes)
310-
infos = MethodMatchInfo[]
311+
infos = MethodLookupResult[]
311312
applicable = Any[]
312313
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
313314
valid_worlds = WorldRange()
@@ -323,29 +324,29 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
323324
if matches === nothing
324325
return FailedMethodMatch("For one of the union split cases, too many methods matched")
325326
end
326-
push!(infos, MethodMatchInfo(matches))
327+
push!(infos, matches)
327328
for m in matches
328329
push!(applicable, m)
329330
push!(applicable_argtypes, arg_n)
330331
end
331332
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
332333
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
333-
found = false
334+
mt_found = false
334335
for (i, mt′) in enumerate(mts)
335336
if mt′ === mt
336337
fullmatches[i] &= thisfullmatch
337-
found = true
338+
mt_found = true
338339
break
339340
end
340341
end
341-
if !found
342+
if !mt_found
342343
push!(mts, mt)
343344
push!(fullmatches, thisfullmatch)
344345
end
345346
end
346-
info = UnionSplitInfo(infos)
347+
info = UnionSplitInfo(infos, mts, fullmatches)
347348
return UnionSplitMethodMatches(
348-
applicable, applicable_argtypes, info, valid_worlds, mts, fullmatches)
349+
applicable, applicable_argtypes, info, valid_worlds)
349350
end
350351

351352
function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(atype), max_methods::Int)
@@ -360,10 +361,9 @@ function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(a
360361
# (assume this will always be true, so we don't compute / update valid age in this case)
361362
return FailedMethodMatch("Too many methods matched")
362363
end
363-
info = MethodMatchInfo(matches)
364364
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
365-
return MethodMatches(
366-
matches.matches, info, matches.valid_worlds, mt, fullmatch)
365+
info = MethodMatchInfo(matches, mt, fullmatch)
366+
return MethodMatches(matches.matches, info, matches.valid_worlds)
367367
end
368368

369369
"""
@@ -584,9 +584,10 @@ function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype)
584584
# also need an edge to the method table in case something gets
585585
# added that did not intersect with any existing method
586586
if isa(matches, MethodMatches)
587-
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
587+
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
588588
else
589-
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
589+
matches::UnionSplitMethodMatches
590+
for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
590591
thisfullmatch || add_mt_backedge!(sv, mt, atype)
591592
end
592593
end

base/compiler/ssair/inlining.jl

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,13 @@ struct InliningCase
5050
end
5151

5252
struct UnionSplit
53-
fully_covered::Bool
53+
handled_all_cases::Bool # All possible dispatches are included in the cases
54+
fully_covered::Bool # All handled cases are fully covering
5455
atype::DataType
5556
cases::Vector{InliningCase}
5657
bbs::Vector{Int}
57-
UnionSplit(fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) =
58-
new(fully_covered, atype, cases, Int[])
58+
UnionSplit(handled_all_cases::Bool, fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) =
59+
new(handled_all_cases, fully_covered, atype, cases, Int[])
5960
end
6061

6162
struct InliningEdgeTracker
@@ -215,7 +216,7 @@ end
215216

216217
function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
217218
state::CFGInliningState, params::OptimizationParams)
218-
(; fully_covered, #=atype,=# cases, bbs) = union_split
219+
(; handled_all_cases, fully_covered, #=atype,=# cases, bbs) = union_split
219220
inline_into_block!(state, block_for_inst(ir, idx))
220221
from_bbs = Int[]
221222
delete!(state.split_targets, length(state.new_cfg_blocks))
@@ -235,7 +236,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
235236
end
236237
end
237238
push!(from_bbs, length(state.new_cfg_blocks))
238-
if !(i == length(cases) && fully_covered)
239+
if !(i == length(cases) && (handled_all_cases && fully_covered))
239240
# This block will have the next condition or the final else case
240241
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx)))
241242
push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks))
@@ -244,7 +245,10 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
244245
end
245246
end
246247
# The edge from the fallback block.
247-
fully_covered || push!(from_bbs, length(state.new_cfg_blocks))
248+
# NOTE This edge is only required for `!handled_all_cases` and not `!fully_covered`,
249+
# since in the latter case we inline `Core.throw_methoderror` into the fallback
250+
# block, which is must-throw, making the subsequent code path unreachable.
251+
!handled_all_cases && push!(from_bbs, length(state.new_cfg_blocks))
248252
# This block will be the block everyone returns to
249253
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx), from_bbs, orig_succs))
250254
join_bb = length(state.new_cfg_blocks)
@@ -523,7 +527,7 @@ assuming their order stays the same post-discovery in `ml_matches`.
523527
function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::Vector{Any},
524528
union_split::UnionSplit, boundscheck::Symbol,
525529
todo_bbs::Vector{Tuple{Int,Int}}, interp::AbstractInterpreter)
526-
(; fully_covered, atype, cases, bbs) = union_split
530+
(; handled_all_cases, fully_covered, atype, cases, bbs) = union_split
527531
stmt, typ, line = compact.result[idx][:stmt], compact.result[idx][:type], compact.result[idx][:line]
528532
join_bb = bbs[end]
529533
pn = PhiNode()
@@ -538,7 +542,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
538542
cond = true
539543
nparams = fieldcount(atype)
540544
@assert nparams == fieldcount(mtype)
541-
if !(i == ncases && fully_covered)
545+
if !(i == ncases && fully_covered && handled_all_cases)
542546
for i = 1:nparams
543547
aft, mft = fieldtype(atype, i), fieldtype(mtype, i)
544548
# If this is always true, we don't need to check for it
@@ -597,14 +601,18 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
597601
end
598602
bb += 1
599603
# We're now in the fall through block, decide what to do
600-
if !fully_covered
604+
if !handled_all_cases
601605
ssa = insert_node_here!(compact, NewInstruction(stmt, typ, line))
602606
push!(pn.edges, bb)
603607
push!(pn.values, ssa)
604608
insert_node_here!(compact, NewInstruction(GotoNode(join_bb), Any, line))
605609
finish_current_bb!(compact, 0)
610+
elseif !fully_covered
611+
insert_node_here!(compact, NewInstruction(Expr(:call, GlobalRef(Core, :throw_methoderror), argexprs...), Union{}, line))
612+
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
613+
finish_current_bb!(compact, 0)
614+
ncases == 0 && return insert_node_here!(compact, NewInstruction(nothing, Any, line))
606615
end
607-
608616
# We're now in the join block.
609617
return insert_node_here!(compact, NewInstruction(pn, typ, line))
610618
end
@@ -1348,10 +1356,6 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
13481356
# Too many applicable methods
13491357
# Or there is a (partial?) ambiguity
13501358
return nothing
1351-
elseif length(meth) == 0
1352-
# No applicable methods; try next union split
1353-
handled_all_cases = false
1354-
continue
13551359
end
13561360
local split_fully_covered = false
13571361
for (j, match) in enumerate(meth)
@@ -1392,22 +1396,26 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
13921396
handled_all_cases &= handle_any_const_result!(cases,
13931397
result, match, argtypes, info, flag, state; allow_typevars=true)
13941398
end
1399+
if !fully_covered
1400+
atype = argtypes_to_type(sig.argtypes)
1401+
# We will emit an inline MethodError so we need a backedge to the MethodTable
1402+
add_uncovered_edges!(state.edges, info, atype)
1403+
end
13951404
elseif !isempty(cases)
13961405
# if we've not seen all candidates, union split is valid only for dispatch tuples
13971406
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
13981407
end
1399-
1400-
return cases, (handled_all_cases & fully_covered), joint_effects
1408+
return cases, handled_all_cases, fully_covered, joint_effects
14011409
end
14021410

14031411
function handle_call!(todo::Vector{Pair{Int,Any}},
14041412
ir::IRCode, idx::Int, stmt::Expr, @nospecialize(info::CallInfo), flag::UInt32, sig::Signature,
14051413
state::InliningState)
14061414
cases = compute_inlining_cases(info, flag, sig, state)
14071415
cases === nothing && return nothing
1408-
cases, all_covered, joint_effects = cases
1416+
cases, handled_all_cases, fully_covered, joint_effects = cases
14091417
atype = argtypes_to_type(sig.argtypes)
1410-
handle_cases!(todo, ir, idx, stmt, atype, cases, all_covered, joint_effects)
1418+
handle_cases!(todo, ir, idx, stmt, atype, cases, handled_all_cases, fully_covered, joint_effects)
14111419
end
14121420

14131421
function handle_match!(cases::Vector{InliningCase},
@@ -1496,19 +1504,19 @@ function concrete_result_item(result::ConcreteResult, @nospecialize(info::CallIn
14961504
end
14971505

14981506
function handle_cases!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stmt::Expr,
1499-
@nospecialize(atype), cases::Vector{InliningCase}, all_covered::Bool,
1507+
@nospecialize(atype), cases::Vector{InliningCase}, handled_all_cases::Bool, fully_covered::Bool,
15001508
joint_effects::Effects)
15011509
# If we only have one case and that case is fully covered, we may either
15021510
# be able to do the inlining now (for constant cases), or push it directly
15031511
# onto the todo list
1504-
if all_covered && length(cases) == 1
1512+
if fully_covered && handled_all_cases && length(cases) == 1
15051513
handle_single_case!(todo, ir, idx, stmt, cases[1].item)
1506-
elseif length(cases) > 0
1514+
elseif length(cases) > 0 || handled_all_cases
15071515
isa(atype, DataType) || return nothing
15081516
for case in cases
15091517
isa(case.sig, DataType) || return nothing
15101518
end
1511-
push!(todo, idx=>UnionSplit(all_covered, atype, cases))
1519+
push!(todo, idx=>UnionSplit(handled_all_cases, fully_covered, atype, cases))
15121520
else
15131521
add_flag!(ir[SSAValue(idx)], flags_for_effects(joint_effects))
15141522
end

base/compiler/stmtinfo.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,13 @@ not a call to a generic function.
3333
"""
3434
struct MethodMatchInfo <: CallInfo
3535
results::MethodLookupResult
36+
mt::MethodTable
37+
fullmatch::Bool
3638
end
3739
nsplit_impl(info::MethodMatchInfo) = 1
3840
getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results)
3941
getresult_impl(::MethodMatchInfo, ::Int) = nothing
42+
add_uncovered_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, @nospecialize(atype)) = (!info.fullmatch && push!(edges, info.mt, atype); )
4043

4144
"""
4245
info::UnionSplitInfo <: CallInfo
@@ -48,20 +51,27 @@ each partition (`info.matches::Vector{MethodMatchInfo}`).
4851
This info is illegal on any statement that is not a call to a generic function.
4952
"""
5053
struct UnionSplitInfo <: CallInfo
51-
matches::Vector{MethodMatchInfo}
54+
matches::Vector{MethodLookupResult}
55+
mts::Vector{MethodTable}
56+
fullmatches::Vector{Bool}
5257
end
5358

5459
nmatches(info::MethodMatchInfo) = length(info.results)
5560
function nmatches(info::UnionSplitInfo)
5661
n = 0
5762
for mminfo in info.matches
58-
n += nmatches(mminfo)
63+
n += length(mminfo)
5964
end
6065
return n
6166
end
6267
nsplit_impl(info::UnionSplitInfo) = length(info.matches)
63-
getsplit_impl(info::UnionSplitInfo, idx::Int) = getsplit_impl(info.matches[idx], 1)
68+
getsplit_impl(info::UnionSplitInfo, idx::Int) = info.matches[idx]
6469
getresult_impl(::UnionSplitInfo, ::Int) = nothing
70+
function add_uncovered_edges_impl(edges::Vector{Any}, info::UnionSplitInfo, @nospecialize(atype))
71+
for (mt, fullmatch) in zip(info.mts, info.fullmatches)
72+
!fullmatch && push!(edges, mt, atype)
73+
end
74+
end
6575

6676
abstract type ConstResult end
6777

@@ -105,6 +115,7 @@ end
105115
nsplit_impl(info::ConstCallInfo) = nsplit(info.call)
106116
getsplit_impl(info::ConstCallInfo, idx::Int) = getsplit(info.call, idx)
107117
getresult_impl(info::ConstCallInfo, idx::Int) = info.results[idx]
118+
add_uncovered_edges_impl(edges::Vector{Any}, info::ConstCallInfo, @nospecialize(atype)) = add_uncovered_edges!(edges, info.call, atype)
108119

109120
"""
110121
info::MethodResultPure <: CallInfo

base/compiler/tfuncs.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2983,9 +2983,9 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
29832983
# also need an edge to the method table in case something gets
29842984
# added that did not intersect with any existing method
29852985
if isa(matches, MethodMatches)
2986-
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
2986+
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
29872987
else
2988-
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
2988+
for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
29892989
thisfullmatch || add_mt_backedge!(sv, mt, atype)
29902990
end
29912991
end
@@ -3001,8 +3001,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
30013001
add_backedge!(sv, edge)
30023002
end
30033003

3004-
if isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
3005-
(!all(matches.fullmatches) || any_ambig(matches))
3004+
if !fully_covering(matches) || any_ambig(matches)
30063005
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
30073006
rt = Bool
30083007
end

base/compiler/types.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,10 +450,16 @@ abstract type CallInfo end
450450

451451
nsplit(info::CallInfo) = nsplit_impl(info)::Union{Nothing,Int}
452452
getsplit(info::CallInfo, idx::Int) = getsplit_impl(info, idx)::MethodLookupResult
453+
add_uncovered_edges!(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = add_uncovered_edges_impl(edges, info, atype)
454+
453455
getresult(info::CallInfo, idx::Int) = getresult_impl(info, idx)
454456

457+
# must implement `nsplit`, `getsplit`, and `add_uncovered_edges!` to opt in to inlining
455458
nsplit_impl(::CallInfo) = nothing
456459
getsplit_impl(::CallInfo, ::Int) = error("unexpected call into `getsplit`")
460+
add_uncovered_edges_impl(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = error("unexpected call into `add_uncovered_edges!`")
461+
462+
# must implement `getresult` to opt in to extended lattice return information
457463
getresult_impl(::CallInfo, ::Int) = nothing
458464

459465
@specialize

test/compiler/AbstractInterpreter.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ end
409409
CC.nsplit_impl(info::NoinlineCallInfo) = CC.nsplit(info.info)
410410
CC.getsplit_impl(info::NoinlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
411411
CC.getresult_impl(info::NoinlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
412+
CC.add_uncovered_edges_impl(edges::Vector{Any}, info::NoinlineCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype)
412413

413414
function CC.abstract_call(interp::NoinlineInterpreter,
414415
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.InferenceState, max_methods::Int)
@@ -431,6 +432,8 @@ end
431432
@inline function inlined_usually(x, y, z)
432433
return x * y + z
433434
end
435+
foo_split(x::Float64) = 1
436+
foo_split(x::Int) = 2
434437

435438
# check if the inlining algorithm works as expected
436439
let src = code_typed1((Float64,Float64,Float64)) do x, y, z
@@ -444,6 +447,7 @@ let NoinlineModule = Module()
444447
main_func(x, y, z) = inlined_usually(x, y, z)
445448
@eval NoinlineModule noinline_func(x, y, z) = $inlined_usually(x, y, z)
446449
@eval OtherModule other_func(x, y, z) = $inlined_usually(x, y, z)
450+
@eval NoinlineModule bar_split_error() = $foo_split(Core.compilerbarrier(:type, nothing))
447451

448452
interp = NoinlineInterpreter(Set((NoinlineModule,)))
449453

@@ -473,6 +477,11 @@ let NoinlineModule = Module()
473477
@test count(isinvoke(:inlined_usually), src.code) == 0
474478
@test count(iscall((src, inlined_usually)), src.code) == 0
475479
end
480+
481+
let src = code_typed1(NoinlineModule.bar_split_error)
482+
@test count(iscall((src, foo_split)), src.code) == 0
483+
@test count(iscall((src, Core.throw_methoderror)), src.code) > 0
484+
end
476485
end
477486

478487
# Make sure that Core.Compiler has enough NamedTuple infrastructure

0 commit comments

Comments
 (0)