Skip to content

Commit 8f76c69

Browse files
authored
minor refactoring on find_method_matches (JuliaLang#53741)
So that it can be tested in isolation easier.
1 parent 8e67f99 commit 8f76c69

File tree

2 files changed

+71
-67
lines changed

2 files changed

+71
-67
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 70 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
2222
end
2323

2424
argtypes = arginfo.argtypes
25-
matches = find_matching_methods(𝕃ᵢ, argtypes, atype, method_table(interp),
26-
InferenceParams(interp).max_union_splitting, max_methods)
25+
matches = find_method_matches(interp, argtypes, atype; max_methods)
2726
if isa(matches, FailedMethodMatch)
2827
add_remark!(interp, sv, matches.reason)
2928
return CallMeta(Any, Any, Effects(), NoCallInfo())
@@ -255,73 +254,79 @@ struct UnionSplitMethodMatches
255254
end
256255
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)
257256

258-
function find_matching_methods(𝕃::AbstractLattice,
259-
argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
260-
max_union_splitting::Int, max_methods::Int)
261-
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
262-
if 1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting
263-
split_argtypes = switchtupleunion(𝕃, argtypes)
264-
infos = MethodMatchInfo[]
265-
applicable = Any[]
266-
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
267-
valid_worlds = WorldRange()
268-
mts = MethodTable[]
269-
fullmatches = Bool[]
270-
for i in 1:length(split_argtypes)
271-
arg_n = split_argtypes[i]::Vector{Any}
272-
sig_n = argtypes_to_type(arg_n)
273-
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
274-
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
275-
mt = mt::MethodTable
276-
matches = findall(sig_n, method_table; limit = max_methods)
277-
if matches === nothing
278-
return FailedMethodMatch("For one of the union split cases, too many methods matched")
279-
end
280-
push!(infos, MethodMatchInfo(matches))
281-
for m in matches
282-
push!(applicable, m)
283-
push!(applicable_argtypes, arg_n)
284-
end
285-
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
286-
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
287-
found = false
288-
for (i, mt′) in enumerate(mts)
289-
if mt′ === mt
290-
fullmatches[i] &= thisfullmatch
291-
found = true
292-
break
293-
end
294-
end
295-
if !found
296-
push!(mts, mt)
297-
push!(fullmatches, thisfullmatch)
298-
end
299-
end
300-
return UnionSplitMethodMatches(applicable,
301-
applicable_argtypes,
302-
UnionSplitInfo(infos),
303-
valid_worlds,
304-
mts,
305-
fullmatches)
306-
else
307-
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
308-
if mt === nothing
309-
return FailedMethodMatch("Could not identify method table for call")
310-
end
257+
function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
258+
max_union_splitting::Int = InferenceParams(interp).max_union_splitting,
259+
max_methods::Int = InferenceParams(interp).max_methods)
260+
if is_union_split_eligible(typeinf_lattice(interp), argtypes, max_union_splitting)
261+
return find_union_split_method_matches(interp, argtypes, atype, max_methods)
262+
end
263+
return find_simple_method_matches(interp, atype, max_methods)
264+
end
265+
266+
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
267+
is_union_split_eligible(𝕃::AbstractLattice, argtypes::Vector{Any}, max_union_splitting::Int) =
268+
1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting
269+
270+
function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any},
271+
@nospecialize(atype), max_methods::Int)
272+
split_argtypes = switchtupleunion(typeinf_lattice(interp), argtypes)
273+
infos = MethodMatchInfo[]
274+
applicable = Any[]
275+
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
276+
valid_worlds = WorldRange()
277+
mts = MethodTable[]
278+
fullmatches = Bool[]
279+
for i in 1:length(split_argtypes)
280+
arg_n = split_argtypes[i]::Vector{Any}
281+
sig_n = argtypes_to_type(arg_n)
282+
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
283+
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
311284
mt = mt::MethodTable
312-
matches = findall(atype, method_table; limit = max_methods)
285+
matches = findall(sig_n, method_table(interp); limit = max_methods)
313286
if matches === nothing
314-
# this means too many methods matched
315-
# (assume this will always be true, so we don't compute / update valid age in this case)
316-
return FailedMethodMatch("Too many methods matched")
287+
return FailedMethodMatch("For one of the union split cases, too many methods matched")
288+
end
289+
push!(infos, MethodMatchInfo(matches))
290+
for m in matches
291+
push!(applicable, m)
292+
push!(applicable_argtypes, arg_n)
293+
end
294+
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
295+
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
296+
found = false
297+
for (i, mt′) in enumerate(mts)
298+
if mt′ === mt
299+
fullmatches[i] &= thisfullmatch
300+
found = true
301+
break
302+
end
317303
end
318-
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
319-
return MethodMatches(matches.matches,
320-
MethodMatchInfo(matches),
321-
matches.valid_worlds,
322-
mt,
323-
fullmatch)
304+
if !found
305+
push!(mts, mt)
306+
push!(fullmatches, thisfullmatch)
307+
end
308+
end
309+
info = UnionSplitInfo(infos)
310+
return UnionSplitMethodMatches(
311+
applicable, applicable_argtypes, info, valid_worlds, mts, fullmatches)
312+
end
313+
314+
function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(atype), max_methods::Int)
315+
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
316+
if mt === nothing
317+
return FailedMethodMatch("Could not identify method table for call")
318+
end
319+
mt = mt::MethodTable
320+
matches = findall(atype, method_table(interp); limit = max_methods)
321+
if matches === nothing
322+
# this means too many methods matched
323+
# (assume this will always be true, so we don't compute / update valid age in this case)
324+
return FailedMethodMatch("Too many methods matched")
324325
end
326+
info = MethodMatchInfo(matches)
327+
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
328+
return MethodMatches(
329+
matches.matches, info, matches.valid_worlds, mt, fullmatch)
325330
end
326331

327332
"""

base/compiler/tfuncs.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3019,8 +3019,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
30193019
isvarargtype(argtypes[2]) && return CallMeta(Bool, Any, EFFECTS_UNKNOWN, NoCallInfo())
30203020
argtypes = argtypes[2:end]
30213021
atype = argtypes_to_type(argtypes)
3022-
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
3023-
InferenceParams(interp).max_union_splitting, max_methods)
3022+
matches = find_method_matches(interp, argtypes, atype; max_methods)
30243023
if isa(matches, FailedMethodMatch)
30253024
rt = Bool # too many matches to analyze
30263025
else

0 commit comments

Comments
 (0)