Skip to content

Commit bdb1515

Browse files
committed
lattice overhaul step 1: separate contexts where native Julia types are expected from those where extended lattice wrappers are
1 parent 7499513 commit bdb1515

28 files changed

+591
-450
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 199 additions & 189 deletions
Large diffs are not rendered by default.

base/compiler/compiler.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,36 @@ something(x::Any, y...) = x
112112
# compiler #
113113
############
114114

115+
# TODO remove me in the future, this is just to check the coverage of the overhaul
116+
import Core: Const, PartialStruct, InterConditional, PartialOpaque, TypeofVararg
117+
abstract type _AbstractLattice end
118+
const AbstractLattice = Union{
119+
Const, PartialStruct, InterConditional, PartialOpaque, TypeofVararg,
120+
_AbstractLattice}
121+
const Argtypes = Vector{AbstractLattice}
122+
123+
macro latticeop(mode, def)
124+
@assert is_function_def(def)
125+
sig, body = def.args
126+
if mode === :args || mode === :op
127+
nospecs = Symbol[]
128+
for arg in sig.args
129+
if isexpr(arg, :macrocall) && arg.args[1] === Symbol("@nospecialize")
130+
push!(nospecs, arg.args[3])
131+
end
132+
end
133+
idx = findfirst(x->!isa(x, LineNumberNode), body.args)
134+
for var in nospecs
135+
insert!(body.args, idx, Expr(:(=), var, Expr(:(::), var, :AbstractLattice)))
136+
end
137+
end
138+
if mode === :ret || mode === :op
139+
sig = Expr(:(::), sig, :AbstractLattice)
140+
end
141+
return esc(Expr(def.head, sig, body))
142+
end
143+
anymap(f::Function, a::Vector{AbstractLattice}) = Any[ f(a[i]) for i in 1:length(a) ]
144+
115145
include("compiler/cicache.jl")
116146
include("compiler/types.jl")
117147
include("compiler/utilities.jl")

base/compiler/inferenceresult.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
function is_argtype_match(@nospecialize(given_argtype),
3+
@latticeop args function is_argtype_match(@nospecialize(given_argtype),
44
@nospecialize(cache_argtype),
55
overridden_by_const::Bool)
66
if is_forwardable_argtype(given_argtype)
@@ -26,7 +26,7 @@ function matching_cache_argtypes(
2626
@assert isa(linfo.def, Method) # ensure the next line works
2727
nargs::Int = linfo.def.nargs
2828
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override)
29-
given_argtypes = Vector{Any}(undef, length(argtypes))
29+
given_argtypes = Vector{AbstractLattice}(undef, length(argtypes))
3030
local condargs = nothing
3131
for i in 1:length(argtypes)
3232
argtype = argtypes[i]
@@ -42,7 +42,7 @@ function matching_cache_argtypes(
4242
if vtype === Bottom && elsetype === Bottom
4343
# we accidentally proved this method match is impossible
4444
# TODO bail out here immediately rather than just propagating Bottom ?
45-
given_argtypes[i] = Bottom
45+
given_argtypes[i] =
4646
else
4747
if condargs === nothing
4848
condargs = Tuple{Int,Int}[]
@@ -56,18 +56,18 @@ function matching_cache_argtypes(
5656
given_argtypes[i] = widenconditional(argtype)
5757
end
5858
isva = va_override || linfo.def.isva
59-
if isva || isvarargtype(given_argtypes[end])
59+
if isva || isvarargtype(unwraptype(given_argtypes[end]))
6060
isva_given_argtypes = Vector{Any}(undef, nargs)
6161
for i = 1:(nargs - isva)
6262
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
6363
end
6464
if isva
65-
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
65+
if length(given_argtypes) < nargs && isvarargtype(unwraptype(given_argtypes[end]))
6666
last = length(given_argtypes)
6767
else
6868
last = nargs
6969
end
70-
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
70+
isva_given_argtypes[nargs] = TypeLattice(tuple_tfunc(anymap(unwraptype, given_argtypes[last:end])))
7171
# invalidate `Conditional` imposed on varargs
7272
if condargs !== nothing
7373
for (slotid, i) in condargs
@@ -101,7 +101,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
101101
# For opaque closure, the closure environment is processed elsewhere
102102
nargs -= 1
103103
end
104-
cache_argtypes = Vector{Any}(undef, nargs)
104+
cache_argtypes = Vector{AbstractLattice}(undef, nargs)
105105
# First, if we're dealing with a varargs method, then we set the last element of `args`
106106
# to the appropriate `Tuple` type or `PartialStruct` instance.
107107
if !toplevel && isva
@@ -140,7 +140,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
140140
vargtype = tuple_tfunc(vargtype_elements)
141141
end
142142
end
143-
cache_argtypes[nargs] = vargtype
143+
cache_argtypes[nargs] = TypeLattice(vargtype)
144144
nargs -= 1
145145
end
146146
# Now, we propagate type info from `linfo_argtypes` into `cache_argtypes`, improving some
@@ -168,10 +168,10 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
168168
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
169169
end
170170
i == n && (lastatype = atyp)
171-
cache_argtypes[i] = atyp
171+
cache_argtypes[i] = TypeLattice(atyp)
172172
end
173173
for i = (tail_index + 1):nargs
174-
cache_argtypes[i] = lastatype
174+
cache_argtypes[i] = TypeLattice(lastatype)
175175
end
176176
else
177177
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
@@ -199,7 +199,7 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing, va_override::
199199
return cache_argtypes, falses(length(cache_argtypes))
200200
end
201201

202-
function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{InferenceResult})
202+
function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{AbstractLattice}, cache::Vector{InferenceResult})
203203
method = linfo.def::Method
204204
nargs::Int = method.nargs
205205
method.isva && (nargs -= 1)
@@ -218,7 +218,7 @@ function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache:
218218
end
219219
end
220220
if method.isva && cache_match
221-
cache_match = is_argtype_match(tuple_tfunc(given_argtypes[(nargs + 1):end]),
221+
cache_match = is_argtype_match(TypeLattice(tuple_tfunc(anymap(unwraptype, given_argtypes[(nargs + 1):end]))),
222222
cache_argtypes[end],
223223
cache_overridden_by_const[end])
224224
end

base/compiler/inferencestate.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ const LineNum = Int
55
# The type of a variable load is either a value or an UndefVarError
66
# (only used in abstractinterpret, doesn't appear in optimize)
77
struct VarState
8-
typ
8+
typ::AbstractLattice
99
undef::Bool
10-
VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
10+
@latticeop args VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
1111
end
1212

1313
"""
@@ -24,8 +24,8 @@ mutable struct InferenceState
2424
params::InferenceParams
2525
result::InferenceResult # remember where to put the result
2626
linfo::MethodInstance
27-
sptypes::Vector{Any} # types of static parameter
28-
slottypes::Vector{Any}
27+
sptypes::Vector{AbstractLattice} # types of static parameter
28+
slottypes::Vector{AbstractLattice}
2929
mod::Module
3030
currpc::LineNum
3131
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
@@ -40,7 +40,7 @@ mutable struct InferenceState
4040
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
4141
stmt_info::Vector{Any}
4242
# return type
43-
bestguess #::Type
43+
bestguess::AbstractLattice
4444
# current active instruction pointers
4545
ip::BitSet
4646
pc´´::LineNum
@@ -79,7 +79,7 @@ mutable struct InferenceState
7979
sp = sptypes_from_meth_instance(linfo::MethodInstance)
8080

8181
nssavalues = src.ssavaluetypes::Int
82-
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
82+
src.ssavaluetypes = AbstractLattice[ NOT_FOUND for i = 1:nssavalues ]
8383
stmt_info = Any[ nothing for i = 1:length(code) ]
8484

8585
n = length(code)
@@ -91,9 +91,9 @@ mutable struct InferenceState
9191
argtypes = result.argtypes
9292
nargs = length(argtypes)
9393
s_argtypes = VarTable(undef, nslots)
94-
slottypes = Vector{Any}(undef, nslots)
94+
slottypes = Vector{AbstractLattice}(undef, nslots)
9595
for i in 1:nslots
96-
at = (i > nargs) ? Bottom : argtypes[i]
96+
at = (i > nargs) ? : TypeLattice(argtypes[i])
9797
s_argtypes[i] = VarState(at, i > nargs)
9898
slottypes[i] = at
9999
end
@@ -120,7 +120,7 @@ mutable struct InferenceState
120120
IdSet{InferenceState}(), IdSet{InferenceState}(),
121121
src, get_world_counter(interp), valid_worlds,
122122
nargs, s_types, s_edges, stmt_info,
123-
Union{}, ip, 1, n, handler_at,
123+
, ip, 1, n, handler_at,
124124
ssavalue_uses,
125125
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
126126
Vector{InferenceState}(), # callers_in_cycle
@@ -316,9 +316,9 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
316316
else
317317
ty = Const(v)
318318
end
319-
sp[i] = ty
319+
sp[i] = TypeLattice(ty)
320320
end
321-
return sp
321+
return collect(AbstractLattice, sp)
322322
end
323323

324324
_topmod(sv::InferenceState) = _topmod(sv.mod)
@@ -332,8 +332,8 @@ end
332332

333333
update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)
334334

335-
function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
336-
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
335+
@latticeop args function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
336+
ssavaluetypes = frame.src.ssavaluetypes::Vector{AbstractLattice}
337337
old = ssavaluetypes[ssa_id]
338338
if old === NOT_FOUND || !(new old)
339339
# typically, we expect that old ⊑ new (that output information only

0 commit comments

Comments
 (0)