Skip to content

Commit a6b2bd8

Browse files
committed
Move nargs/isva to CodeInfo
This changes the canonical source of truth for va handling from `Method` to `CodeInfo`. There are multiple goals for this change: 1. This addresses a longstanding complaint about the way that CodeInfo-returning generated functions work. Previously, the va-ness or not of the returned CodeInfo always had to match that of the generator. For Cassette-like transforms that generally have one big generator function that is varargs (while then looking up lowered code that is not varargs), this could become quite annoying. It's possible to workaround, but there is really no good reason to tie the two together. As we observed when we implemented OpaqueClosures, the vararg-ness of the signature and the `vararg arguments`->`tuple` transformation are mostly independent concepts. With this PR, generated functions can return CodeInfos with whatever combination of nargs/isva is convenient. 2. This change requires clarifying where the va processing boundary is in inference. #54076 was already moving in that direction for irinterp, and this essentially does much of the same for regular inference. As a consequence the constprop cache is now using non-va-cooked signatures, which I think is preferable. 3. This further decouples codegen from the presence of a `Method` (which is already not assumed, since the code being generated could be a toplevel thunk, but some codegen features are only available to things that come from Methods). There are a number of upcoming features that will require codegen of things that are not quite method specializations (See design doc linked in #52797 and things like #50641). This helps pave the road for that. 4. I've previously considered expanding the kinds of vararg signatures that can be described (see e.g. #53851), which also requires a decoupling of the signature and ast notions of vararg. This again lays the groundwork for that, although I have no immediate plans to implement this change. Impact wise, this adds an internal field, which is not too breaking, but downstream clients vary in how they construct their `CodeInfo`s and the current way they're doing it will likely be incorrect after this change, so they will require a small two-line adjustment. We should perhaps consider pulling out some of the more common patterns into a more stable package, since interface in most of the last few releases, but that's a separate issue.
1 parent 2cf469d commit a6b2bd8

23 files changed

+276
-187
lines changed

base/compiler/abstractinterpretation.jl

+18-27
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,12 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
479479
if isa(rt, InterConditional) && rt.slot == i
480480
return rt
481481
else
482-
thentype = elsetype = tmeet(𝕃ᵢ, widenslotwrapper(argtypes[i]), fieldtype(sig, i))
482+
argt = widenslotwrapper(argtypes[i])
483+
if isvarargtype(argt)
484+
@assert fieldcount(sig) == i
485+
argt = unwrapva(argt)
486+
end
487+
thentype = elsetype = tmeet(𝕃ᵢ, argt, fieldtype(sig, i))
483488
condval = maybe_extract_const_bool(rt)
484489
condval === true && (elsetype = Bottom)
485490
condval === false && (thentype = Bottom)
@@ -986,15 +991,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
986991
# N.B. remarks are emitted within `const_prop_entry_heuristic`
987992
return nothing
988993
end
989-
nargs::Int = method.nargs
990-
method.isva && (nargs -= 1)
991-
length(arginfo.argtypes) < nargs && return nothing
992994
if !const_prop_argument_heuristic(interp, arginfo, sv)
993995
add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics")
994996
return nothing
995997
end
996998
all_overridden = is_all_overridden(interp, arginfo, sv)
997-
if !force && !const_prop_function_heuristic(interp, f, arginfo, nargs, all_overridden, sv)
999+
if !force && !const_prop_function_heuristic(interp, f, arginfo, all_overridden, sv)
9981000
add_remark!(interp, sv, "[constprop] Disabled by function heuristic")
9991001
return nothing
10001002
end
@@ -1113,9 +1115,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method:
11131115
end
11141116

11151117
function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f),
1116-
arginfo::ArgInfo, nargs::Int, all_overridden::Bool, sv::AbsIntState)
1118+
arginfo::ArgInfo, all_overridden::Bool, sv::AbsIntState)
11171119
argtypes = arginfo.argtypes
1118-
if nargs > 1
1120+
if length(argtypes) > 1
11191121
𝕃ᵢ = typeinf_lattice(interp)
11201122
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
11211123
arrty = argtypes[2]
@@ -1274,7 +1276,7 @@ function const_prop_call(interp::AbstractInterpreter,
12741276
end
12751277
overridden_by_const = falses(length(argtypes))
12761278
for i = 1:length(argtypes)
1277-
if argtypes[i] !== cache_argtypes[i]
1279+
if argtypes[i] !== argtype_by_index(cache_argtypes, i)
12781280
overridden_by_const[i] = true
12791281
end
12801282
end
@@ -1349,20 +1351,6 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
13491351
end
13501352
given_argtypes[i] = widenslotwrapper(argtype)
13511353
end
1352-
if condargs !== nothing
1353-
given_argtypes = let condargs=condargs
1354-
va_process_argtypes(𝕃, given_argtypes, mi) do isva_given_argtypes::Vector{Any}, last::Int
1355-
# invalidate `Conditional` imposed on varargs
1356-
for (slotid, i) in condargs
1357-
if slotid last && (1 i length(isva_given_argtypes)) # `Conditional` is already widened to vararg-tuple otherwise
1358-
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
1359-
end
1360-
end
1361-
end
1362-
end
1363-
else
1364-
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
1365-
end
13661354
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
13671355
end
13681356

@@ -1721,7 +1709,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
17211709
return CallMeta(res, exct, effects, retinfo)
17221710
end
17231711

1724-
function argtype_by_index(argtypes::Vector{Any}, i::Int)
1712+
function argtype_by_index(argtypes::Vector{Any}, i::Integer)
17251713
n = length(argtypes)
17261714
na = argtypes[n]
17271715
if isvarargtype(na)
@@ -2890,12 +2878,12 @@ end
28902878
struct BestguessInfo{Interp<:AbstractInterpreter}
28912879
interp::Interp
28922880
bestguess
2893-
nargs::Int
2881+
nargs::UInt
28942882
slottypes::Vector{Any}
28952883
changes::VarTable
2896-
function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::Int,
2884+
function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::UInt,
28972885
slottypes::Vector{Any}, changes::VarTable) where Interp<:AbstractInterpreter
2898-
new{Interp}(interp, bestguess, nargs, slottypes, changes)
2886+
new{Interp}(interp, bestguess, Int(nargs), slottypes, changes)
28992887
end
29002888
end
29012889

@@ -2970,7 +2958,7 @@ end
29702958
# pick up the first "interesting" slot, convert `rt` to its `Conditional`
29712959
# TODO: ideally we want `Conditional` and `InterConditional` to convey
29722960
# constraints on multiple slots
2973-
for slot_id = 1:info.nargs
2961+
for slot_id = 1:Int(info.nargs)
29742962
rt = bool_rt_to_conditional(rt, slot_id, info)
29752963
rt isa InterConditional && break
29762964
end
@@ -2981,6 +2969,9 @@ end
29812969
= (typeinf_lattice(info.interp))
29822970
old = info.slottypes[slot_id]
29832971
new = widenslotwrapper(info.changes[slot_id].typ) # avoid nested conditional
2972+
if isvarargtype(old) || isvarargtype(new)
2973+
return rt
2974+
end
29842975
if new ᵢ old && !(old ᵢ new)
29852976
if isa(rt, Const)
29862977
val = rt.val

base/compiler/inferenceresult.jl

+91-95
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,56 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
2424
for i = 1:length(argtypes)
2525
given_argtypes[i] = widenslotwrapper(argtypes[i])
2626
end
27-
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
2827
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
2928
end
3029

30+
function pick_const_arg(𝕃::AbstractLattice, @nospecialize(given_argtype), @nospecialize(cache_argtype))
31+
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
32+
# prefer the argtype we were given over the one computed from `mi`
33+
if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) &&
34+
!(𝕃, given_argtype, cache_argtype))
35+
# if the type information of this `PartialStruct` is less strict than
36+
# declared method signature, narrow it down using `tmeet`
37+
given_argtype = tmeet(𝕃, given_argtype, cache_argtype)
38+
end
39+
else
40+
given_argtype = cache_argtype
41+
end
42+
return given_argtype
43+
end
44+
3145
function pick_const_args!(𝕃::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any})
32-
nargtypes = length(given_argtypes)
33-
@assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`"
34-
for i = 1:nargtypes
35-
given_argtype = given_argtypes[i]
36-
cache_argtype = cache_argtypes[i]
37-
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
38-
# prefer the argtype we were given over the one computed from `mi`
39-
if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) &&
40-
!(𝕃, given_argtype, cache_argtype))
41-
# if the type information of this `PartialStruct` is less strict than
42-
# declared method signature, narrow it down using `tmeet`
43-
given_argtypes[i] = tmeet(𝕃, given_argtype, cache_argtype)
44-
end
46+
if length(given_argtypes) == 0 || length(cache_argtypes) == 0
47+
return Any[]
48+
end
49+
given_va = given_argtypes[end]
50+
cache_va = cache_argtypes[end]
51+
if isvarargtype(given_va)
52+
if isvarargtype(cache_va)
53+
# Process the common prefix, then join
54+
nprocessargs = max(length(given_argtypes)-1, length(cache_argtypes)-1)
55+
resize!(given_argtypes, nprocessargs+1)
56+
given_argtypes[end] = Vararg{pick_const_arg(𝕃, unwrapva(given_va), unwrapva(cache_va))}
4557
else
46-
given_argtypes[i] = cache_argtype
58+
nprocessargs = length(cache_argtypes)
59+
ngiven = length(given_argtypes)
60+
va = unwrapva(given_va)
61+
resize!(given_argtypes, nprocessargs)
62+
for i = ngiven:nprocessargs
63+
given_argtypes[i] = va
64+
end
4765
end
66+
elseif isvarargtype(cache_va)
67+
nprocessargs = length(given_argtypes)
68+
else
69+
@assert length(given_argtypes) == length(cache_argtypes)
70+
nprocessargs = length(given_argtypes)
71+
end
72+
for i = 1:nprocessargs
73+
given_argtype = given_argtypes[i]
74+
cache_argtype = argtype_by_index(cache_argtypes, i)
75+
given_argtype = pick_const_arg(𝕃, given_argtype, cache_argtype)
76+
given_argtypes[i] = given_argtype
4877
end
4978
return given_argtypes
5079
end
@@ -60,25 +89,33 @@ function is_argtype_match(𝕃::AbstractLattice,
6089
end
6190
end
6291

63-
va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) =
64-
va_process_argtypes(Returns(nothing), 𝕃, given_argtypes, mi)
65-
function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance)
66-
def = mi.def::Method
67-
isva = def.isva
68-
nargs = Int(def.nargs)
69-
if isva || isvarargtype(given_argtypes[end])
70-
isva_given_argtypes = Vector{Any}(undef, nargs)
92+
function va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, nargs::UInt, isva::Bool)
93+
if isva || (!isempty(given_argtypes) && isvarargtype(given_argtypes[end]))
94+
isva_given_argtypes = Vector{Any}(undef, Int(nargs))
7195
for i = 1:(nargs-isva)
72-
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
96+
newarg = argtype_by_index(given_argtypes, i)
97+
if isva && has_conditional(𝕃) && isa(newarg, Conditional)
98+
if newarg.slot > (nargs-isva)
99+
newarg = widenconditional(newarg)
100+
end
101+
end
102+
isva_given_argtypes[i] = newarg
73103
end
74104
if isva
75105
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
76106
last = length(given_argtypes)
77107
else
78108
last = nargs
109+
if has_conditional(𝕃)
110+
for i = last:length(given_argtypes)
111+
newarg = given_argtypes[i]
112+
if isa(newarg, Conditional) && newarg.slot > (nargs-isva)
113+
given_argtypes[i] = widenconditional(newarg)
114+
end
115+
end
116+
end
79117
end
80118
isva_given_argtypes[nargs] = tuple_tfunc(𝕃, given_argtypes[last:end])
81-
va_handler!(isva_given_argtypes, last)
82119
end
83120
return isva_given_argtypes
84121
end
@@ -87,84 +124,44 @@ function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, gi
87124
end
88125

89126
function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes))
90-
toplevel = method === nothing
91-
isva = !toplevel && method.isva
92127
mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...]
93-
nargs::Int = toplevel ? 0 : method.nargs
94-
cache_argtypes = Vector{Any}(undef, nargs)
95-
# First, if we're dealing with a varargs method, then we set the last element of `args`
96-
# to the appropriate `Tuple` type or `PartialStruct` instance.
97-
mi_argtypes_length = length(mi_argtypes)
98-
if !toplevel && isva
99-
if specTypes::Type == Tuple
100-
mi_argtypes = Any[Any for i = 1:nargs]
101-
if nargs > 1
102-
mi_argtypes[end] = Tuple
103-
end
104-
vargtype = Tuple
105-
else
106-
if nargs > mi_argtypes_length
107-
va = mi_argtypes[mi_argtypes_length]
108-
if isvarargtype(va)
109-
new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes)
110-
vargtype = Tuple{new_va}
111-
else
112-
vargtype = Tuple{}
113-
end
114-
else
115-
vargtype_elements = Any[]
116-
for i in nargs:mi_argtypes_length
117-
p = mi_argtypes[i]
118-
p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
119-
push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
120-
end
121-
for i in 1:length(vargtype_elements)
122-
atyp = vargtype_elements[i]
123-
if issingletontype(atyp)
124-
# replace singleton types with their equivalent Const object
125-
vargtype_elements[i] = Const(atyp.instance)
126-
elseif isconstType(atyp)
127-
vargtype_elements[i] = Const(atyp.parameters[1])
128-
end
129-
end
130-
vargtype = tuple_tfunc(fallback_lattice, vargtype_elements)
131-
end
132-
end
133-
cache_argtypes[nargs] = vargtype
134-
nargs -= 1
128+
nargtypes = length(mi_argtypes)
129+
nargs = isa(method, Method) ? method.nargs : 0
130+
if length(mi_argtypes) < nargs && isvarargtype(mi_argtypes[end])
131+
resize!(mi_argtypes, nargs)
135132
end
136133
# Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some
137134
# type info as we go (where possible). Note that if we're dealing with a varargs method,
138135
# we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
139136
# we don't overwrite the result of that work here).
140-
if mi_argtypes_length > 0
141-
tail_index = nargtypes = min(mi_argtypes_length, nargs)
142-
local lastatype
143-
for i = 1:nargtypes
144-
atyp = mi_argtypes[i]
145-
if i == nargtypes && isvarargtype(atyp)
146-
atyp = unwrapva(atyp)
147-
tail_index -= 1
148-
end
149-
atyp = unwraptv(atyp)
150-
if issingletontype(atyp)
151-
# replace singleton types with their equivalent Const object
152-
atyp = Const(atyp.instance)
153-
elseif isconstType(atyp)
154-
atyp = Const(atyp.parameters[1])
155-
else
156-
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
157-
end
158-
i == nargtypes && (lastatype = atyp)
159-
cache_argtypes[i] = atyp
137+
tail_index = min(nargtypes, nargs)
138+
local lastatype
139+
for i = 1:nargtypes
140+
atyp = mi_argtypes[i]
141+
wasva = false
142+
if i == nargtypes && isvarargtype(atyp)
143+
wasva = true
144+
atyp = unwrapva(atyp)
160145
end
161-
for i = (tail_index+1):nargs
162-
cache_argtypes[i] = lastatype
146+
atyp = unwraptv(atyp)
147+
if issingletontype(atyp)
148+
# replace singleton types with their equivalent Const object
149+
atyp = Const(atyp.instance)
150+
elseif isconstType(atyp)
151+
atyp = Const(atyp.parameters[1])
152+
else
153+
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
163154
end
164-
else
165-
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
155+
mi_argtypes[i] = atyp
156+
if wasva
157+
lastatype = atyp
158+
mi_argtypes[end] = Vararg{widenconst(atyp)}
159+
end
160+
end
161+
for i = (tail_index+1):(nargs-1)
162+
mi_argtypes[i] = lastatype
166163
end
167-
return cache_argtypes
164+
return mi_argtypes
168165
end
169166

170167
# eliminate free `TypeVar`s in order to make the life much easier down the road:
@@ -184,7 +181,6 @@ function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes:
184181
cache::Vector{InferenceResult})
185182
method = mi.def::Method
186183
nargtypes = length(given_argtypes)
187-
@assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`"
188184
for cached_result in cache
189185
cached_result.linfo === mi || @goto next_cache
190186
cache_argtypes = cached_result.argtypes

base/compiler/inferencestate.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ mutable struct InferenceState
302302
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
303303
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
304304
argtypes = result.argtypes
305+
306+
argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva)
307+
305308
nargtypes = length(argtypes)
306309
for i = 1:nslots
307310
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
@@ -766,10 +769,9 @@ function print_callstack(sv::InferenceState)
766769
end
767770

768771
function narguments(sv::InferenceState, include_va::Bool=true)
769-
def = sv.linfo.def
770-
nargs = length(sv.result.argtypes)
772+
nargs = sv.src.nargs
771773
if !include_va
772-
nargs -= isa(def, Method) && def.isva
774+
nargs -= sv.src.isva
773775
end
774776
return nargs
775777
end
@@ -831,7 +833,7 @@ function IRInterpretationState(interp::AbstractInterpreter,
831833
end
832834
method_info = MethodInfo(src)
833835
ir = inflate_ir(src, mi)
834-
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, mi)
836+
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva)
835837
return IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
836838
codeinst.min_world, codeinst.max_world)
837839
end

0 commit comments

Comments
 (0)