Skip to content

Commit b4fe1cc

Browse files
committed
Refactor lattice code to expose layering and enable easy extension
There's been two threads of work involving the compiler's notion of the inference lattice. One is that the lattice has gotten to complicated and with too many internal constraints that are not manifest in the type system. #42596 attempted to address this, but it's quite disruptive as it changes the lattice types and all the signatures of the lattice operations, which are used quite extensively throughout the ecosystem (despite being internal), so that change is quite disruptive (and something we'd ideally only make the ecosystem do once). The other thread of work is that people would like to experiment with a variety of extended lattices outside of base (either to prototype potential additions to the lattice in base or to do custom abstract interpretation over the julia code). At the moment, the lattice is quite closely interwoven with the rest of the abstract interpreter. In response to this request in #40992, I had proposed a `CustomLattice` element with callbacks, but this doesn't compose particularly well, is cumbersome and imposes overhead on some of the hottest parts of the compiler, so it's a bit of a tough sell to merge into Base. In this PR, I'd like to propose a refactoring that is relatively non-invasive to non-Base users, but I think would allow easier experimentation with changes to the lattice for these two use cases. In essence, we're splitting the lattice into a ladder of 5 different lattices, each containing the previous lattice as a sub-lattice. These 5 lattices are: - JLTypeLattice (Anything that's a `Type`) - ConstsLattice ( + `Const`, `PartialTypeVar`) - PartialsLattice ( + `PartialStruct` ) - ConditionalsLattice ( + `Conditional` ) - InferenceLattice ( + `LimitedAccuracy`, `MaybeUndef` ) The idea is that where a lattice element contains another lattice element (e.g. in `PartialStruct` or `Conditional`), the element contained may only be from a wider lattice. In this PR, this is not enforced by the type system. This is quite deliberate, as I want to retain the types and object layouts of the lattice elements, but of course a future #42596-like change could add such type enforcement. Of particular note is that the `PartialsLattice` and `ConditionalsLattice` is parameterized and additional layers may be added in the stack. For example, in #40992, I had proposed a lattice element that refines `Int` and tracks symbolic expressions. In this setup, this could be accomplished by adding an appropriate lattice in between the `ConstsLattice` and the `PartialsLattice` (of course, additional hooks would be required to make the tfuncs work, but that is outside the scope of this PR). I don't think this is a full solution, but I think it'll help us play with some of these extended lattice options over the next 6-12 months in the packages that want to do this sort of thing. Presumably once we know what all the potential lattice extensions look like, we will want to take another look at this (likely together with whatever solution we come up with for the AbstractInterpreter composability problem and a rebase of #42596). WIP because I didn't bother updating and plumbing through the lattice in all the call sites yet, but that's mostly mechanical, so if we like this direction, I will make that change and hope to merge this in short order (because otherwise it'll accumulate massive merge conflicts).
1 parent 8fa066b commit b4fe1cc

File tree

14 files changed

+496
-212
lines changed

14 files changed

+496
-212
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 66 additions & 54 deletions
Large diffs are not rendered by default.

base/compiler/abstractlattice.jl

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
abstract type AbstractLattice; end
2+
function widen end
3+
4+
"""
5+
struct JLTypeLattice
6+
7+
A singleton type representing the lattice of Julia types, without any inference
8+
extensions.
9+
"""
10+
struct JLTypeLattice <: AbstractLattice; end
11+
widen(::JLTypeLattice) = error("Type lattice is the least-precise lattice available")
12+
is_valid_lattice(::JLTypeLattice, @nospecialize(elem)) = isa(elem, Type)
13+
14+
"""
15+
struct ConstsLattice
16+
17+
A lattice extending `JLTypeLattice` and adjoining `Const` and `PartialTypeVar`.
18+
"""
19+
struct ConstsLattice <: AbstractLattice; end
20+
widen(::ConstsLattice) = JLTypeLattice()
21+
is_valid_lattice(lattice::ConstsLattice, @nospecialize(elem)) =
22+
is_valid_lattice(widen(lattice), elem) || isa(elem, Const) || isa(elem, PartialTypeVar)
23+
24+
"""
25+
struct PartialsLattice{L}
26+
27+
A lattice extending lattice `L` and adjoining `PartialStruct` and `PartialOpaque`.
28+
"""
29+
struct PartialsLattice{L <: AbstractLattice} <: AbstractLattice
30+
parent::L
31+
end
32+
widen(L::PartialsLattice) = L.parent
33+
is_valid_lattice(lattice::PartialsLattice, @nospecialize(elem)) =
34+
is_valid_lattice(widen(lattice), elem) ||
35+
isa(elem, PartialStruct) || isa(elem, PartialOpaque)
36+
37+
"""
38+
struct ConditionalsLattice{L}
39+
40+
A lattice extending lattice `L` and adjoining `Conditional`.
41+
"""
42+
struct ConditionalsLattice{L <: AbstractLattice} <: AbstractLattice
43+
parent::L
44+
end
45+
widen(L::ConditionalsLattice) = L.parent
46+
is_valid_lattice(lattice::ConditionalsLattice, @nospecialize(elem)) =
47+
is_valid_lattice(widen(lattice), elem) || isa(elem, Conditional)
48+
49+
struct InterConditionalsLattice{L <: AbstractLattice} <: AbstractLattice
50+
parent::L
51+
end
52+
widen(L::InterConditionalsLattice) = L.parent
53+
is_valid_lattice(lattice::InterConditionalsLattice, @nospecialize(elem)) =
54+
is_valid_lattice(widen(lattice), elem) || isa(elem, InterConditional)
55+
56+
const AnyConditionalsLattice{L} = Union{ConditionalsLattice{L}, InterConditionalsLattice{L}}
57+
const BaseInferenceLattice = typeof(ConditionalsLattice(PartialsLattice(ConstsLattice())))
58+
const IPOResultLattice = typeof(InterConditionalsLattice(PartialsLattice(ConstsLattice())))
59+
60+
"""
61+
struct OptimizerLattice
62+
63+
The lattice used by the optimizer. Extends
64+
`BaseInferenceLattice` with `MaybeUndef`.
65+
"""
66+
struct OptimizerLattice <: AbstractLattice; end
67+
widen(L::OptimizerLattice) = BaseInferenceLattice.instance
68+
is_valid_lattice(lattice::OptimizerLattice, @nospecialize(elem)) =
69+
is_valid_lattice(widen(lattice), elem) || isa(elem, MaybeUndef)
70+
71+
"""
72+
struct InferenceLattice{L}
73+
74+
The full lattice used for abstract interpration during inference. Takes
75+
a base lattice and adjoins `LimitedAccuracy`.
76+
"""
77+
struct InferenceLattice{L} <: AbstractLattice
78+
parent::L
79+
end
80+
widen(L::InferenceLattice) = L.parent
81+
is_valid_lattice(lattice::InferenceLattice, @nospecialize(elem)) =
82+
is_valid_lattice(widen(lattice), elem) || isa(elem, LimitedAccuracy)
83+
84+
"""
85+
tmeet(lattice, a, b::Type)
86+
87+
Compute the lattice meet of lattice elements `a` and `b` over the lattice
88+
`lattice`. If `lattice` is `JLTypeLattice`, this is equiavalent to type
89+
intersection. Note that currently `b` is restricted to being a type (interpreted
90+
as a lattice element in the JLTypeLattice sub-lattice of `lattice`).
91+
"""
92+
function tmeet end
93+
94+
function tmeet(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type))
95+
ti = typeintersect(a, b)
96+
valid_as_lattice(ti) || return Bottom
97+
return ti
98+
end
99+
100+
"""
101+
tmerge(lattice, a, b)
102+
103+
Compute a lattice join of elements `a` and `b` over the lattice `lattice`.
104+
Note that the computed element need not be the least upper bound of `a` and
105+
`b`, but rather, we impose some heuristic limits on the complexity of the
106+
joined element, ideally without losing too much precision in common cases and
107+
remaining mostly associative and commutative.
108+
"""
109+
function tmerge end
110+
111+
"""
112+
⊑(lattice, a, b)
113+
114+
Compute the lattice ordering (i.e. less-than-or-equal) relationship between
115+
lattice elements `a` and `b` over the lattice `lattice`. If `lattice` is
116+
`JLTypeLattice`, this is equiavalent to subtyping.
117+
"""
118+
function end
119+
120+
(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type)) = a <: b
121+
122+
"""
123+
⊏(lattice, a, b) -> Bool
124+
125+
The strict partial order over the type inference lattice.
126+
This is defined as the irreflexive kernel of `⊑`.
127+
"""
128+
(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b)) = (lattice, a, b) && !(lattice, b, a)
129+
130+
"""
131+
⋤(lattice, a, b) -> Bool
132+
133+
This order could be used as a slightly more efficient version of the strict order `⊏`,
134+
where we can safely assume `a ⊑ b` holds.
135+
"""
136+
(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b)) = !(lattice, b, a)
137+
138+
"""
139+
is_lattice_equal(lattice, a, b) -> Bool
140+
141+
Check if two lattice elements are partial order equivalent.
142+
This is basically `a ⊑ b && b ⊑ a` but (optionally) with extra performance optimizations.
143+
"""
144+
function is_lattice_equal(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b))
145+
a === b && return true
146+
(lattice, a, b) && (lattice, b, a)
147+
end
148+
149+
"""
150+
has_nontrivial_const_info(lattice, t) -> Bool
151+
152+
Determine whether the given lattice element `t` of `lattice` has non-trivial
153+
constant information that would not be available from the type itself.
154+
"""
155+
has_nontrivial_const_info(lattice::AbstractLattice, @nospecialize t) =
156+
has_nontrivial_const_info(widen(lattice), t)
157+
has_nontrivial_const_info(::JLTypeLattice, @nospecialize(t)) = false
158+
159+
# Curried versions
160+
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
161+
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
162+
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
163+
164+
# Fallbacks for external packages using these methods
165+
const fallback_lattice = InferenceLattice(BaseInferenceLattice.instance)
166+
const fallback_ipo_lattice = InferenceLattice(IPOResultLattice.instance)
167+
168+
(@nospecialize(a), @nospecialize(b)) = (fallback_lattice, a, b)
169+
tmeet(@nospecialize(a), @nospecialize(b)) = tmeet(fallback_lattice, a, b)
170+
tmerge(@nospecialize(a), @nospecialize(b)) = tmerge(fallback_lattice, a, b)
171+
(@nospecialize(a), @nospecialize(b)) = (fallback_lattice, a, b)
172+
(@nospecialize(a), @nospecialize(b)) = (fallback_lattice, a, b)
173+
is_lattice_equal(@nospecialize(a), @nospecialize(b)) = is_lattice_equal(fallback_lattice, a, b)

base/compiler/compiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ include("compiler/ssair/ir.jl")
156156
include("compiler/inferenceresult.jl")
157157
include("compiler/inferencestate.jl")
158158

159+
include("compiler/abstractlattice.jl")
159160
include("compiler/typeutils.jl")
160161
include("compiler/typelimits.jl")
161162
include("compiler/typelattice.jl")

base/compiler/optimize.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,13 @@ Returns a tuple of (effect_free_and_nothrow, nothrow) for a given statement.
211211
"""
212212
function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact})
213213
# TODO: We're duplicating analysis from inference here.
214+
lattice = OptimizerLattice()
215+
= (lattice)
214216
isa(stmt, PiNode) && return (true, true)
215217
isa(stmt, PhiNode) && return (true, true)
216218
isa(stmt, ReturnNode) && return (false, true)
217219
isa(stmt, GotoNode) && return (false, true)
218-
isa(stmt, GotoIfNot) && return (false, argextype(stmt.cond, src) Bool)
220+
isa(stmt, GotoIfNot) && return (false, argextype(stmt.cond, src) Bool)
219221
isa(stmt, Slot) && return (false, false) # Slots shouldn't occur in the IR at this point, but let's be defensive here
220222
if isa(stmt, GlobalRef)
221223
nothrow = isdefined(stmt.mod, stmt.name)
@@ -248,7 +250,7 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR
248250
return (total, total)
249251
end
250252
rt === Bottom && return (false, false)
251-
nothrow = _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt)
253+
nothrow = _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt, lattice)
252254
nothrow || return (false, false)
253255
return (contains_is(_EFFECT_FREE_BUILTINS, f), nothrow)
254256
elseif head === :new
@@ -262,7 +264,7 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR
262264
for fld_idx in 1:(length(args) - 1)
263265
eT = argextype(args[fld_idx + 1], src)
264266
fT = fieldtype(typ, fld_idx)
265-
eT fT || return (false, false)
267+
eT fT || return (false, false)
266268
end
267269
return (true, true)
268270
elseif head === :foreigncall
@@ -277,11 +279,11 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR
277279
typ = argextype(args[1], src)
278280
typ, isexact = instanceof_tfunc(typ)
279281
isexact || return (false, false)
280-
typ Tuple || return (false, false)
282+
typ Tuple || return (false, false)
281283
rt_lb = argextype(args[2], src)
282284
rt_ub = argextype(args[3], src)
283285
source = argextype(args[4], src)
284-
if !(rt_lb Type && rt_ub Type && source Method)
286+
if !(rt_lb Type && rt_ub Type && source Method)
285287
return (false, false)
286288
end
287289
return (true, true)

base/compiler/ssair/inlining.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ function CFGInliningState(ir::IRCode)
114114
)
115115
end
116116

117+
(@nospecialize(a), @nospecialize(b)) = (OptimizerLattice(), a, b)
118+
117119
# Tells the inliner that we're now inlining into block `block`, meaning
118120
# all previous blocks have been processed and can be added to the new cfg
119121
function inline_into_block!(state::CFGInliningState, block::Int)
@@ -381,7 +383,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
381383
nonva_args = argexprs[1:end-1]
382384
va_arg = argexprs[end]
383385
tuple_call = Expr(:call, TOP_TUPLE, def, nonva_args...)
384-
tuple_type = tuple_tfunc(Any[argextype(arg, compact) for arg in nonva_args])
386+
tuple_type = tuple_tfunc(OptimizerLattice(), Any[argextype(arg, compact) for arg in nonva_args])
385387
tupl = insert_node_here!(compact, NewInstruction(tuple_call, tuple_type, topline))
386388
apply_iter_expr = Expr(:call, Core._apply_iterate, iterate, Core._compute_sparams, tupl, va_arg)
387389
sparam_vals = insert_node_here!(compact,
@@ -476,7 +478,7 @@ function fix_va_argexprs!(compact::IncrementalCompact,
476478
push!(tuple_call.args, arg)
477479
push!(tuple_typs, argextype(arg, compact))
478480
end
479-
tuple_typ = tuple_tfunc(tuple_typs)
481+
tuple_typ = tuple_tfunc(OptimizerLattice(), tuple_typs)
480482
tuple_inst = NewInstruction(tuple_call, tuple_typ, line_idx)
481483
push!(newargexprs, insert_node_here!(compact, tuple_inst))
482484
return newargexprs
@@ -1080,8 +1082,8 @@ function inline_apply!(
10801082
nonempty_idx = 0
10811083
for i = (arg_start + 1):length(argtypes)
10821084
ti = argtypes[i]
1083-
ti Tuple{} && continue
1084-
if ti Tuple && nonempty_idx == 0
1085+
ti Tuple{} && continue
1086+
if ti Tuple && nonempty_idx == 0
10851087
nonempty_idx = i
10861088
continue
10871089
end
@@ -1123,9 +1125,9 @@ end
11231125
# TODO: this test is wrong if we start to handle Unions of function types later
11241126
is_builtin(s::Signature) =
11251127
isa(s.f, IntrinsicFunction) ||
1126-
s.ft IntrinsicFunction ||
1128+
s.ft IntrinsicFunction ||
11271129
isa(s.f, Builtin) ||
1128-
s.ft Builtin
1130+
s.ft Builtin
11291131

11301132
function inline_invoke!(
11311133
ir::IRCode, idx::Int, stmt::Expr, info::InvokeCallInfo, flag::UInt8,
@@ -1165,7 +1167,7 @@ function narrow_opaque_closure!(ir::IRCode, stmt::Expr, @nospecialize(info), sta
11651167
ub, exact = instanceof_tfunc(ubt)
11661168
exact || return
11671169
# Narrow opaque closure type
1168-
newT = widenconst(tmeet(tmerge(lb, info.unspec.rt), ub))
1170+
newT = widenconst(tmeet(OptimizerLattice(), tmerge(OptimizerLattice(), lb, info.unspec.rt), ub))
11691171
if newT != ub
11701172
# N.B.: Narrowing the ub requires a backdge on the mi whose type
11711173
# information we're using, since a change in that function may
@@ -1222,7 +1224,7 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
12221224
ir.stmts[idx][:inst] = earlyres.val
12231225
return nothing
12241226
end
1225-
if (sig.f === modifyfield! || sig.ft typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6
1227+
if (sig.f === modifyfield! || sig.ft typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6
12261228
let info = ir.stmts[idx][:info]
12271229
info isa MethodResultPure && (info = info.info)
12281230
info isa ConstCallInfo && (info = info.call)
@@ -1240,7 +1242,7 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
12401242
end
12411243

12421244
if check_effect_free!(ir, idx, stmt, rt)
1243-
if sig.f === typeassert || sig.ft typeof(typeassert)
1245+
if sig.f === typeassert || sig.ft typeof(typeassert)
12441246
# typeassert is a no-op if effect free
12451247
ir.stmts[idx][:inst] = stmt.args[2]
12461248
return nothing
@@ -1637,7 +1639,7 @@ function early_inline_special_case(
16371639
elseif ispuretopfunction(f) || contains_is(_PURE_BUILTINS, f)
16381640
return SomeCase(quoted(val))
16391641
elseif contains_is(_EFFECT_FREE_BUILTINS, f)
1640-
if _builtin_nothrow(f, argtypes[2:end], type)
1642+
if _builtin_nothrow(f, argtypes[2:end], type, OptimizerLattice())
16411643
return SomeCase(quoted(val))
16421644
end
16431645
elseif f === Core.get_binding_type
@@ -1683,17 +1685,17 @@ function late_inline_special_case!(
16831685
elseif length(argtypes) == 3 && istopfunction(f, :(>:))
16841686
# special-case inliner for issupertype
16851687
# that works, even though inference generally avoids inferring the `>:` Method
1686-
if isa(type, Const) && _builtin_nothrow(<:, Any[argtypes[3], argtypes[2]], type)
1688+
if isa(type, Const) && _builtin_nothrow(<:, Any[argtypes[3], argtypes[2]], type, OptimizerLattice())
16871689
return SomeCase(quoted(type.val))
16881690
end
16891691
subtype_call = Expr(:call, GlobalRef(Core, :(<:)), stmt.args[3], stmt.args[2])
16901692
return SomeCase(subtype_call)
1691-
elseif f === TypeVar && 2 <= length(argtypes) <= 4 && (argtypes[2] Symbol)
1693+
elseif f === TypeVar && 2 <= length(argtypes) <= 4 && (argtypes[2] Symbol)
16921694
typevar_call = Expr(:call, GlobalRef(Core, :_typevar), stmt.args[2],
16931695
length(stmt.args) < 4 ? Bottom : stmt.args[3],
16941696
length(stmt.args) == 2 ? Any : stmt.args[end])
16951697
return SomeCase(typevar_call)
1696-
elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] TypeVar)
1698+
elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] TypeVar)
16971699
unionall_call = Expr(:foreigncall, QuoteNode(:jl_type_unionall), Any, svec(Any, Any),
16981700
0, QuoteNode(:ccall), stmt.args[2], stmt.args[3])
16991701
return SomeCase(unionall_call)

0 commit comments

Comments
 (0)