Skip to content

Commit 4baf148

Browse files
committed
lattice overhaul step 2: convert existing extended lattice wrappers to LatticeElement attributes
- pack `PartialStruct` into `LatticeElement.fields` - pack `Conditional`/`InterConditional` into `LatticeElement.conditional` - pack `Const` into `LatticeElement.constant` - pack `PartialTypeVar` into `LatticeElement.partialtypevar` - pack `LimitedAccuracy` into `LatticeElement.causes` - pack `PartialOpaque` into `LatticeElement.partialopaque` - pack `MaybeUndef` into `LatticeElement.maybeundef` - merge `LatticeElement.partialopaque` and `LatticeElement.partialopaque` There is not much value in keeping them separate, since a variable usually doesn't have these "special" attributes at the same time. - wrap `Vararg` in `LatticeElement.special::Vararg` - add HACK to allow `DelayedTyp` to sneak in `LatticeElement` system And now we can eliminate `AbstractLattice`, and our inference code works with `LatticeElement` (mostly). - define `SSAValueType(s)` / `Argtypes` aliases
1 parent 306b6b0 commit 4baf148

29 files changed

+1463
-1133
lines changed

base/boot.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -421,10 +421,7 @@ eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecial
421421
min_world::UInt, max_world::UInt) =
422422
ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
423423
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
424-
eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))))
425-
eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))))
426424
eval(Core, :(PartialOpaque(@nospecialize(typ), @nospecialize(env), isva::Bool, parent::MethodInstance, source::Method) = $(Expr(:new, :PartialOpaque, :typ, :env, :isva, :parent, :source))))
427-
eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype))))
428425
eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) =
429426
$(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))))
430427

@@ -495,13 +492,11 @@ Symbol(s::Symbol) = s
495492
module IR
496493
export CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
497494
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
498-
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
499-
Const, PartialStruct
495+
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode
500496

501497
import Core: CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
502498
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
503-
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
504-
Const, PartialStruct
499+
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode
505500

506501
end
507502

base/compiler/abstractinterpretation.jl

Lines changed: 271 additions & 259 deletions
Large diffs are not rendered by default.

base/compiler/compiler.jl

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -114,35 +114,11 @@ something(x::Any, y...) = x
114114
# compiler #
115115
############
116116

117-
# TODO remove me in the future, this is just to check the coverage of the overhaul
118-
import Core: Const, PartialStruct, InterConditional, PartialOpaque, TypeofVararg
119-
abstract type _AbstractLattice end
120-
const AbstractLattice = Union{
121-
Const, PartialStruct, InterConditional, PartialOpaque, TypeofVararg,
122-
_AbstractLattice}
123-
const Argtypes = Vector{AbstractLattice}
124-
125-
macro latticeop(mode, def)
126-
@assert is_function_def(def)
127-
sig, body = def.args
128-
if mode === :args || mode === :op
129-
nospecs = Symbol[]
130-
for arg in sig.args
131-
if isexpr(arg, :macrocall) && arg.args[1] === Symbol("@nospecialize")
132-
push!(nospecs, arg.args[3])
133-
end
134-
end
135-
idx = findfirst(x->!isa(x, LineNumberNode), body.args)
136-
for var in nospecs
137-
insert!(body.args, idx, Expr(:(=), var, Expr(:(::), var, :AbstractLattice)))
138-
end
139-
end
140-
if mode === :ret || mode === :op
141-
sig = Expr(:(::), sig, :AbstractLattice)
142-
end
143-
return esc(Expr(def.head, sig, body))
144-
end
145-
anymap(f::Function, a::Vector{AbstractLattice}) = Any[ f(a[i]) for i in 1:length(a) ]
117+
include("compiler/typelattice.jl")
118+
119+
const Argtypes = Vector{LatticeElement}
120+
const EMPTY_SLOTTYPES = Argtypes()
121+
anymap(f::Function, a::Argtypes) = Any[ f(a[i]) for i in 1:length(a) ]
146122

147123
include("compiler/cicache.jl")
148124
include("compiler/types.jl")
@@ -155,7 +131,6 @@ include("compiler/inferencestate.jl")
155131

156132
include("compiler/typeutils.jl")
157133
include("compiler/typelimits.jl")
158-
include("compiler/typelattice.jl")
159134
include("compiler/tfuncs.jl")
160135
include("compiler/stmtinfo.jl")
161136

@@ -176,6 +151,65 @@ function extrema(x::Array)
176151
return vmin, vmax
177152
end
178153

154+
# function show(io::IO, xs::Vector)
155+
# print(io, eltype(xs), '[')
156+
# show_itr(io, xs)
157+
# print(io, ']')
158+
# end
159+
# function show(io::IO, xs::Tuple)
160+
# print(io, '(')
161+
# show_itr(io, xs)
162+
# print(io, ')')
163+
# end
164+
# function show_itr(io::IO, xs)
165+
# n = length(xs)
166+
# for i in 1:n
167+
# show(io, xs[i])
168+
# i == n || print(io, ", ")
169+
# end
170+
# end
171+
# function show(io::IO, typ′::LatticeElement)
172+
# function name(x)
173+
# if isLimitedAccuracy(typ′)
174+
# return (nameof(x), '′',)
175+
# else
176+
# return (nameof(x),)
177+
# end
178+
# end
179+
# typ = ignorelimited(typ′)
180+
# if isConditional(typ)
181+
# show(io, conditional(typ))
182+
# elseif isConst(typ)
183+
# print(io, name(Const)..., '(', constant(typ), ')')
184+
# elseif isPartialStruct(typ)
185+
# print(io, name(PartialStruct)..., '(', widenconst(typ), ", [")
186+
# n = length(partialfields(typ))
187+
# for i in 1:n
188+
# show(io, partialfields(typ)[i])
189+
# i == n || print(io, ", ")
190+
# end
191+
# print(io, "])")
192+
# elseif isPartialTypeVar(typ)
193+
# print(io, name(PartialTypeVar)..., '(')
194+
# show(io, typ.partialtypevar.tv)
195+
# print(io, ')')
196+
# else
197+
# print(io, name(NativeType)..., '(', widenconst(typ), ')')
198+
# end
199+
# end
200+
# function show(io::IO, typ::ConditionalInfo)
201+
# if typ === __NULL_CONDITIONAL__
202+
# return print(io, "__NULL_CONDITIONAL__")
203+
# end
204+
# print(io, nameof(Conditional), '(')
205+
# show(io, typ.var)
206+
# print(io, ", ")
207+
# show(io, typ.vtype)
208+
# print(io, ", ")
209+
# show(io, typ.elsetype)
210+
# print(io, ')')
211+
# end
212+
179213
include("compiler/bootstrap.jl")
180214
ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)
181215

base/compiler/inferenceresult.jl

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
@latticeop args function is_argtype_match(@nospecialize(given_argtype),
4-
@nospecialize(cache_argtype),
3+
function is_argtype_match(given_argtype::LatticeElement,
4+
cache_argtype::LatticeElement,
55
overridden_by_const::Bool)
66
if is_forwardable_argtype(given_argtype)
77
return is_lattice_equal(given_argtype, cache_argtype)
@@ -10,10 +10,10 @@
1010
end
1111

1212
function is_forwardable_argtype(@nospecialize x)
13-
return isa(x, Const) ||
14-
isa(x, Conditional) ||
15-
isa(x, PartialStruct) ||
16-
isa(x, PartialOpaque)
13+
return isConst(x) ||
14+
isConditional(x) ||
15+
isPartialStruct(x) ||
16+
isPartialOpaque(x)
1717
end
1818

1919
# In theory, there could be a `cache` containing a matching `InferenceResult`
@@ -26,13 +26,13 @@ function matching_cache_argtypes(
2626
@assert isa(linfo.def, Method) # ensure the next line works
2727
nargs::Int = linfo.def.nargs
2828
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override)
29-
given_argtypes = Vector{AbstractLattice}(undef, length(argtypes))
29+
given_argtypes = Vector{LatticeElement}(undef, length(argtypes))
3030
local condargs = nothing
3131
for i in 1:length(argtypes)
3232
argtype = argtypes[i]
3333
# forward `Conditional` if it conveys a constraint on any other argument
34-
if isa(argtype, Conditional) && fargs !== nothing
35-
cnd = argtype
34+
if isConditional(argtype) && fargs !== nothing
35+
cnd = conditional(argtype)
3636
slotid = find_constrained_arg(cnd, fargs, sv)
3737
if slotid !== nothing
3838
# using union-split signature, we may be able to narrow down `Conditional`
@@ -48,26 +48,26 @@ function matching_cache_argtypes(
4848
condargs = Tuple{Int,Int}[]
4949
end
5050
push!(condargs, (slotid, i))
51-
given_argtypes[i] = Conditional(SlotNumber(slotid), vtype, elsetype)
51+
given_argtypes[i] = Conditional(slotid, vtype, elsetype)
5252
end
5353
continue
5454
end
5555
end
5656
given_argtypes[i] = widenconditional(argtype)
5757
end
5858
isva = va_override || linfo.def.isva
59-
if isva || isvarargtype(unwraptype(given_argtypes[end]))
60-
isva_given_argtypes = Vector{Any}(undef, nargs)
59+
if isva || isVararg(given_argtypes[end])
60+
isva_given_argtypes = Vector{LatticeElement}(undef, nargs)
6161
for i = 1:(nargs - isva)
6262
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
6363
end
6464
if isva
65-
if length(given_argtypes) < nargs && isvarargtype(unwraptype(given_argtypes[end]))
65+
if length(given_argtypes) < nargs && isVararg(given_argtypes[end])
6666
last = length(given_argtypes)
6767
else
6868
last = nargs
6969
end
70-
isva_given_argtypes[nargs] = TypeLattice(tuple_tfunc(anymap(unwraptype, given_argtypes[last:end])))
70+
isva_given_argtypes[nargs] = LatticeElement(tuple_tfunc(given_argtypes[last:end]))
7171
# invalidate `Conditional` imposed on varargs
7272
if condargs !== nothing
7373
for (slotid, i) in condargs
@@ -101,7 +101,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
101101
# For opaque closure, the closure environment is processed elsewhere
102102
nargs -= 1
103103
end
104-
cache_argtypes = Vector{AbstractLattice}(undef, nargs)
104+
cache_argtypes = Vector{LatticeElement}(undef, nargs)
105105
# First, if we're dealing with a varargs method, then we set the last element of `args`
106106
# to the appropriate `Tuple` type or `PartialStruct` instance.
107107
if !toplevel && isva
@@ -110,23 +110,24 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
110110
linfo_argtypes = Any[Any for i = 1:nargs]
111111
linfo_argtypes[end] = Vararg{Any}
112112
end
113-
vargtype = Tuple
113+
vargtype = NativeType(Tuple)
114114
else
115115
linfo_argtypes_length = length(linfo_argtypes)
116116
if nargs > linfo_argtypes_length
117117
va = linfo_argtypes[linfo_argtypes_length]
118118
if isvarargtype(va)
119119
new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes)
120-
vargtype = Tuple{new_va}
120+
vargtype = NativeType(Tuple{new_va})
121121
else
122-
vargtype = Tuple{}
122+
vargtype = NativeType(Tuple{})
123123
end
124124
else
125-
vargtype_elements = Any[]
125+
vargtype_elements = LatticeElement[]
126126
for i in nargs:linfo_argtypes_length
127127
p = linfo_argtypes[i]
128128
p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
129-
push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
129+
p = elim_free_typevars(rewrap_unionall(p, specTypes))
130+
push!(vargtype_elements, NativeType(p))
130131
end
131132
for i in 1:length(vargtype_elements)
132133
atyp = vargtype_elements[i]
@@ -140,7 +141,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
140141
vargtype = tuple_tfunc(vargtype_elements)
141142
end
142143
end
143-
cache_argtypes[nargs] = TypeLattice(vargtype)
144+
cache_argtypes[nargs] = vargtype
144145
nargs -= 1
145146
end
146147
# Now, we propagate type info from `linfo_argtypes` into `cache_argtypes`, improving some
@@ -168,10 +169,10 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
168169
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
169170
end
170171
i == n && (lastatype = atyp)
171-
cache_argtypes[i] = TypeLattice(atyp)
172+
cache_argtypes[i] = LatticeElement(atyp)
172173
end
173174
for i = (tail_index + 1):nargs
174-
cache_argtypes[i] = TypeLattice(lastatype)
175+
cache_argtypes[i] = LatticeElement(lastatype)
175176
end
176177
else
177178
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
@@ -199,7 +200,7 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing, va_override::
199200
return cache_argtypes, falses(length(cache_argtypes))
200201
end
201202

202-
function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{AbstractLattice}, cache::Vector{InferenceResult})
203+
function cache_lookup(linfo::MethodInstance, given_argtypes::Argtypes, cache::Vector{InferenceResult})
203204
method = linfo.def::Method
204205
nargs::Int = method.nargs
205206
method.isva && (nargs -= 1)
@@ -218,7 +219,7 @@ function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{AbstractLatt
218219
end
219220
end
220221
if method.isva && cache_match
221-
cache_match = is_argtype_match(TypeLattice(tuple_tfunc(anymap(unwraptype, given_argtypes[(nargs + 1):end]))),
222+
cache_match = is_argtype_match(LatticeElement(tuple_tfunc(anymap(unwraptype, given_argtypes[(nargs + 1):end]))),
222223
cache_argtypes[end],
223224
cache_overridden_by_const[end])
224225
end

base/compiler/inferencestate.jl

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,12 @@
22

33
const LineNum = Int
44

5-
# The type of a variable load is either a value or an UndefVarError
6-
# (only used in abstractinterpret, doesn't appear in optimize)
7-
struct VarState
8-
typ::AbstractLattice
9-
undef::Bool
10-
@latticeop args VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
11-
end
12-
13-
"""
14-
const VarTable = Vector{VarState}
15-
16-
The extended lattice that maps local variables to inferred type represented as `AbstractLattice`.
17-
Each index corresponds to the `id` of `SlotNumber` which identifies each local variable.
18-
Note that `InferenceState` will maintain multiple `VarTable`s at each SSA statement
19-
to enable flow-sensitive analysis.
20-
"""
21-
const VarTable = Vector{VarState}
22-
235
mutable struct InferenceState
246
params::InferenceParams
257
result::InferenceResult # remember where to put the result
268
linfo::MethodInstance
27-
sptypes::Vector{AbstractLattice} # types of static parameter
28-
slottypes::Vector{AbstractLattice}
9+
sptypes::Argtypes # types of static parameter
10+
slottypes::Argtypes
2911
mod::Module
3012
currpc::LineNum
3113
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
@@ -40,7 +22,7 @@ mutable struct InferenceState
4022
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
4123
stmt_info::Vector{Any}
4224
# return type
43-
bestguess::AbstractLattice
25+
bestguess::LatticeElement
4426
# current active instruction pointers
4527
ip::BitSet
4628
pc´´::LineNum
@@ -79,7 +61,9 @@ mutable struct InferenceState
7961
sp = sptypes_from_meth_instance(linfo::MethodInstance)
8062

8163
nssavalues = src.ssavaluetypes::Int
82-
src.ssavaluetypes = AbstractLattice[ NOT_FOUND for i = 1:nssavalues ]
64+
# NOTE we can't initialize `src.ssavaluetypes` as `Argtypes` to avoid
65+
# an allocation within `ir_to_codeinf!(src)` where we widen all ssavaluetypes to native Julia types
66+
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
8367
stmt_info = Any[ nothing for i = 1:length(code) ]
8468

8569
n = length(code)
@@ -91,9 +75,9 @@ mutable struct InferenceState
9175
argtypes = result.argtypes
9276
nargs = length(argtypes)
9377
s_argtypes = VarTable(undef, nslots)
94-
slottypes = Vector{AbstractLattice}(undef, nslots)
78+
slottypes = Vector{LatticeElement}(undef, nslots)
9579
for i in 1:nslots
96-
at = (i > nargs) ?: TypeLattice(argtypes[i])
80+
at = (i > nargs) ?: LatticeElement(argtypes[i])
9781
s_argtypes[i] = VarState(at, i > nargs)
9882
slottypes[i] = at
9983
end
@@ -316,9 +300,9 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
316300
else
317301
ty = Const(v)
318302
end
319-
sp[i] = TypeLattice(ty)
303+
sp[i] = LatticeElement(ty)
320304
end
321-
return collect(AbstractLattice, sp)
305+
return collect(LatticeElement, sp)
322306
end
323307

324308
_topmod(sv::InferenceState) = _topmod(sv.mod)
@@ -332,9 +316,9 @@ end
332316

333317
update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)
334318

335-
@latticeop args function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
336-
ssavaluetypes = frame.src.ssavaluetypes::Vector{AbstractLattice}
337-
old = ssavaluetypes[ssa_id]
319+
function record_ssa_assign(ssa_id::Int, new::LatticeElement, frame::InferenceState)
320+
ssavaluetypes = frame.src.ssavaluetypes::SSAValueTypes
321+
old = ssavaluetypes[ssa_id]::SSAValueType
338322
if old === NOT_FOUND || !(new old)
339323
# typically, we expect that old ⊑ new (that output information only
340324
# gets less precise with worse input information), but to actually

0 commit comments

Comments
 (0)