Skip to content

Commit 5bd974a

Browse files
committed
AbstractInterpreter: remove method_table(::AbstractInterpreter, ::InferenceState) interface (#44389)
In #44240 we removed the `CachedMethodTable` support as it turned out to be ineffective under the current compiler infrastructure. Because of this, there is no strong reason to keep a method table per `InferenceState`. This commit simply removes the `method_table(::AbstractInterpreter, ::InferenceState)` interface and should make it clearer which interface should be overloaded to implement a contextual dispatch.
1 parent c4f6c12 commit 5bd974a

File tree

5 files changed

+73
-28
lines changed

5 files changed

+73
-28
lines changed

base/compiler/abstractinterpretation.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
4747
end
4848

4949
argtypes = arginfo.argtypes
50-
matches = find_matching_methods(argtypes, atype, method_table(interp, sv), InferenceParams(interp).MAX_UNION_SPLITTING, max_methods)
50+
matches = find_matching_methods(argtypes, atype, method_table(interp), InferenceParams(interp).MAX_UNION_SPLITTING, max_methods)
5151
if isa(matches, FailedMethodMatch)
5252
add_remark!(interp, sv, matches.reason)
5353
tristate_merge!(sv, Effects())
@@ -637,7 +637,7 @@ end
637637

638638
function pure_eval_eligible(interp::AbstractInterpreter,
639639
@nospecialize(f), applicable::Vector{Any}, arginfo::ArgInfo, sv::InferenceState)
640-
return !isoverlayed(method_table(interp, sv)) &&
640+
return !isoverlayed(method_table(interp)) &&
641641
f !== nothing &&
642642
length(applicable) == 1 &&
643643
is_method_pure(applicable[1]::MethodMatch) &&
@@ -674,7 +674,7 @@ end
674674

675675
function concrete_eval_eligible(interp::AbstractInterpreter,
676676
@nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState)
677-
return !isoverlayed(method_table(interp, sv)) &&
677+
return !isoverlayed(method_table(interp)) &&
678678
f !== nothing &&
679679
result.edge !== nothing &&
680680
is_total_or_error(result.edge_effects) &&
@@ -2110,14 +2110,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
21102110
frame.dont_work_on_me = true # mark that this function is currently on the stack
21112111
W = frame.ip
21122112
states = frame.stmt_types
2113-
n = frame.nstmts
2113+
nstmts = frame.nstmts
21142114
nargs = frame.nargs
21152115
def = frame.linfo.def
21162116
isva = isa(def, Method) && def.isva
21172117
nslots = nargs - isva
21182118
slottypes = frame.slottypes
21192119
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
2120-
while frame.pc´´ <= n
2120+
while frame.pc´´ <= nstmts
21212121
# make progress on the active ip set
21222122
local pc::Int = frame.pc´´
21232123
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
@@ -2189,7 +2189,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
21892189
end
21902190
end
21912191
elseif isa(stmt, ReturnNode)
2192-
pc´ = n + 1
2192+
pc´ = nstmts + 1
21932193
bestguess = frame.bestguess
21942194
rt = abstract_eval_value(interp, stmt.val, changes, frame)
21952195
rt = widenreturn(rt, bestguess, nslots, slottypes, changes)
@@ -2310,7 +2310,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
23102310
ssavaluetypes[pc] = Any
23112311
end
23122312

2313-
pc´ > n && break # can't proceed with the fast-path fall-through
2313+
pc´ > nstmts && break # can't proceed with the fast-path fall-through
23142314
newstate = stupdate!(states[pc´], changes)
23152315
if isa(stmt, GotoNode) && frame.pc´´ < pc´
23162316
# if we are processing a goto node anyways,

base/compiler/inferencestate.jl

+11-20
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,6 @@ mutable struct InferenceState
6262
# Inferred purity flags
6363
ipo_effects::Effects
6464

65-
# The place to look up methods while working on this function.
66-
# In particular, we cache method lookup results for the same function to
67-
# fast path repeated queries.
68-
method_table::InternalMethodTable
69-
7065
# The interpreter that created this inference state. Not looked at by
7166
# NativeInterpreter. But other interpreters may use this to detect cycles
7267
interp::AbstractInterpreter
@@ -85,9 +80,9 @@ mutable struct InferenceState
8580
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
8681
stmt_info = Any[ nothing for i = 1:length(code) ]
8782

88-
n = length(code)
89-
s_types = Union{Nothing, VarTable}[ nothing for i = 1:n ]
90-
s_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:n ]
83+
nstmts = length(code)
84+
s_types = Union{Nothing, VarTable}[ nothing for i = 1:nstmts ]
85+
s_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ]
9186

9287
# initial types
9388
nslots = length(src.slotflags)
@@ -129,19 +124,17 @@ mutable struct InferenceState
129124
@assert cache === :no || cache === :local || cache === :global
130125
frame = new(
131126
params, result, linfo,
132-
sp, slottypes, mod, 0,
133-
IdSet{InferenceState}(), IdSet{InferenceState}(),
127+
sp, slottypes, mod, #=currpc=#0,
128+
#=pclimitations=#IdSet{InferenceState}(), #=limitations=#IdSet{InferenceState}(),
134129
src, get_world_counter(interp), valid_worlds,
135130
nargs, s_types, s_edges, stmt_info,
136-
Union{}, ip, 1, n, handler_at,
137-
ssavalue_uses,
138-
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
139-
Vector{InferenceState}(), # callers_in_cycle
131+
#=bestguess=#Union{}, ip, #=pc´´=#1, nstmts, handler_at, ssavalue_uses,
132+
#=cycle_backedges=#Vector{Tuple{InferenceState,LineNum}}(),
133+
#=callers_in_cycle=#Vector{InferenceState}(),
140134
#=parent=#nothing,
141-
cache === :global, false, false,
142-
Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE,
143-
inbounds_taints_consistency),
144-
method_table(interp),
135+
#=cached=#cache === :global,
136+
#=inferred=#false, #=dont_work_on_me=#false,
137+
#=ipo_effects=#Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, inbounds_taints_consistency),
145138
interp)
146139
result.result = frame
147140
cache !== :no && push!(get_inference_cache(interp), result)
@@ -267,8 +260,6 @@ function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState,
267260
(infstate::InferenceState, (infstate, cyclei))
268261
end
269262

270-
method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table
271-
272263
function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
273264
# prepare an InferenceState object for inferring lambda
274265
src = retrieve_code_info(result.linfo)

base/compiler/types.jl

+7
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,13 @@ may_compress(::AbstractInterpreter) = true
314314
may_discard_trees(::AbstractInterpreter) = true
315315
verbose_stmt_info(::AbstractInterpreter) = false
316316

317+
"""
318+
method_table(interp::AbstractInterpreter) -> MethodTableView
319+
320+
Returns a method table this `interp` uses for method lookup.
321+
External `AbstractInterpreter` can optionally return `OverlayMethodTable` here
322+
to incorporate customized dispatches for the overridden methods.
323+
"""
317324
method_table(interp::AbstractInterpreter) = InternalMethodTable(get_world_counter(interp))
318325

319326
"""

test/choosetests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ function choosetests(choices = [])
142142
filtertests!(tests, "subarray")
143143
filtertests!(tests, "compiler", ["compiler/inference", "compiler/validation",
144144
"compiler/ssair", "compiler/irpasses", "compiler/codegen",
145-
"compiler/inline", "compiler/contextual",
145+
"compiler/inline", "compiler/contextual", "compiler/AbstractInterpreter",
146146
"compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"])
147147
filtertests!(tests, "compiler/EscapeAnalysis", [
148148
"compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"])

test/compiler/AbstractInterpreter.jl

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
const CC = Core.Compiler
4+
import Core: MethodInstance, CodeInstance
5+
import .CC: WorldRange, WorldView
6+
7+
# define new `AbstractInterpreter` that satisfies the minimum interface requirements
8+
# while managing its cache independently
9+
macro newinterp(name)
10+
cachename = gensym(string(name, "Cache"))
11+
name = esc(name)
12+
quote
13+
struct $cachename
14+
dict::IdDict{MethodInstance,CodeInstance}
15+
end
16+
struct $name <: CC.AbstractInterpreter
17+
interp::CC.NativeInterpreter
18+
cache::$cachename
19+
$name(world = Base.get_world_counter();
20+
interp = CC.NativeInterpreter(world),
21+
cache = $cachename(IdDict{MethodInstance,CodeInstance}())
22+
) = new(interp, cache)
23+
end
24+
CC.InferenceParams(interp::$name) = CC.InferenceParams(interp.interp)
25+
CC.OptimizationParams(interp::$name) = CC.OptimizationParams(interp.interp)
26+
CC.get_world_counter(interp::$name) = CC.get_world_counter(interp.interp)
27+
CC.get_inference_cache(interp::$name) = CC.get_inference_cache(interp.interp)
28+
CC.code_cache(interp::$name) = WorldView(interp.cache, WorldRange(CC.get_world_counter(interp)))
29+
CC.get(wvc::WorldView{<:$cachename}, mi::MethodInstance, default) = get(wvc.cache.dict, mi, default)
30+
CC.getindex(wvc::WorldView{<:$cachename}, mi::MethodInstance) = getindex(wvc.cache.dict, mi)
31+
CC.haskey(wvc::WorldView{<:$cachename}, mi::MethodInstance) = haskey(wvc.cache.dict, mi)
32+
CC.setindex!(wvc::WorldView{<:$cachename}, ci::CodeInstance, mi::MethodInstance) = setindex!(wvc.cache.dict, ci, mi)
33+
end
34+
end
35+
36+
# `OverlayMethodTable`
37+
# --------------------
38+
import Base.Experimental: @MethodTable, @overlay
39+
40+
@newinterp MTOverlayInterp
41+
@MethodTable(OverlayedMT)
42+
CC.method_table(interp::MTOverlayInterp) = CC.OverlayMethodTable(CC.get_world_counter(interp), OverlayedMT)
43+
44+
@overlay OverlayedMT sin(x::Float64) = 1
45+
@test Base.return_types((Int,), MTOverlayInterp()) do x
46+
sin(x)
47+
end == Any[Int]

0 commit comments

Comments
 (0)