Skip to content

Commit 5320bd9

Browse files
committed
Move nargs/isva to CodeInfo
This changes the canonical source of truth for va handling from `Method` to `CodeInfo`. There are multiple goals for this change:
1. This addresses a longstanding complaint about the way that CodeInfo-returning generated functions work. Previously, the va-ness or not of the returned CodeInfo always had to match that of the generator. For Cassette-like transforms that generally have one big generator function that is varargs (while then looking up lowered code that is not varargs), this could become quite annoying. It's possible to work around, but there is really no good reason to tie the two together. As we observed when we implemented OpaqueClosures, the vararg-ness of the signature and the `vararg arguments`->`tuple` transformation are mostly independent concepts. With this PR, generated functions can return CodeInfos with whatever combination of nargs/isva is convenient.
2. This change requires clarifying where the va processing boundary is in inference. #54076 was already moving in that direction for irinterp, and this essentially does much of the same for regular inference. As a consequence the constprop cache is now using non-va-cooked signatures, which I think is preferable.
3. This further decouples codegen from the presence of a `Method` (which is already not assumed, since the code being generated could be a toplevel thunk, but some codegen features are only available to things that come from Methods). There are a number of upcoming features that will require codegen of things that are not quite method specializations (see the design doc linked in #52797 and things like #50641). This helps pave the road for that.
4. I've previously considered expanding the kinds of vararg signatures that can be described (see e.g. #53851), which also requires a decoupling of the signature and AST notions of vararg. This again lays the groundwork for that, although I have no immediate plans to implement this change.
Impact-wise, this adds an internal field, which is not too breaking, but downstream clients vary in how they construct their `CodeInfo`s and the current way they're doing it will likely be incorrect after this change, so they will require a small two-line adjustment. We should perhaps consider pulling out some of the more common patterns into a more stable package, since this interface has changed in most of the last few releases, but that's a separate issue.
1 parent 0b70d26 commit 5320bd9

15 files changed

+189
-153
lines changed

base/compiler/abstractinterpretation.jl

+17-26
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,12 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
479479
if isa(rt, InterConditional) && rt.slot == i
480480
return rt
481481
else
482-
thentype = elsetype = tmeet(𝕃ᵢ, widenslotwrapper(argtypes[i]), fieldtype(sig, i))
482+
argt = widenslotwrapper(argtypes[i])
483+
if isvarargtype(argt)
484+
@assert fieldcount(sig) == i
485+
argt = unwrapva(argt)
486+
end
487+
thentype = elsetype = tmeet(𝕃ᵢ, argt, fieldtype(sig, i))
483488
condval = maybe_extract_const_bool(rt)
484489
condval === true && (elsetype = Bottom)
485490
condval === false && (thentype = Bottom)
@@ -986,15 +991,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
986991
# N.B. remarks are emitted within `const_prop_entry_heuristic`
987992
return nothing
988993
end
989-
nargs::Int = method.nargs
990-
method.isva && (nargs -= 1)
991-
length(arginfo.argtypes) < nargs && return nothing
992994
if !const_prop_argument_heuristic(interp, arginfo, sv)
993995
add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics")
994996
return nothing
995997
end
996998
all_overridden = is_all_overridden(interp, arginfo, sv)
997-
if !force && !const_prop_function_heuristic(interp, f, arginfo, nargs, all_overridden, sv)
999+
if !force && !const_prop_function_heuristic(interp, f, arginfo, all_overridden, sv)
9981000
add_remark!(interp, sv, "[constprop] Disabled by function heuristic")
9991001
return nothing
10001002
end
@@ -1113,9 +1115,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method:
11131115
end
11141116

11151117
function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f),
1116-
arginfo::ArgInfo, nargs::Int, all_overridden::Bool, sv::AbsIntState)
1118+
arginfo::ArgInfo, all_overridden::Bool, sv::AbsIntState)
11171119
argtypes = arginfo.argtypes
1118-
if nargs > 1
1120+
if length(argtypes) > 1
11191121
𝕃ᵢ = typeinf_lattice(interp)
11201122
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
11211123
arrty = argtypes[2]
@@ -1349,20 +1351,6 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
13491351
end
13501352
given_argtypes[i] = widenslotwrapper(argtype)
13511353
end
1352-
if condargs !== nothing
1353-
given_argtypes = let condargs=condargs
1354-
va_process_argtypes(𝕃, given_argtypes, mi) do isva_given_argtypes::Vector{Any}, last::Int
1355-
# invalidate `Conditional` imposed on varargs
1356-
for (slotid, i) in condargs
1357-
if slotid ≥ last && (1 ≤ i ≤ length(isva_given_argtypes)) # `Conditional` is already widened to vararg-tuple otherwise
1358-
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
1359-
end
1360-
end
1361-
end
1362-
end
1363-
else
1364-
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
1365-
end
13661354
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
13671355
end
13681356

@@ -1721,7 +1709,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
17211709
return CallMeta(res, exct, effects, retinfo)
17221710
end
17231711

1724-
function argtype_by_index(argtypes::Vector{Any}, i::Int)
1712+
function argtype_by_index(argtypes::Vector{Any}, i::Integer)
17251713
n = length(argtypes)
17261714
na = argtypes[n]
17271715
if isvarargtype(na)
@@ -2890,12 +2878,12 @@ end
28902878
struct BestguessInfo{Interp<:AbstractInterpreter}
28912879
interp::Interp
28922880
bestguess
2893-
nargs::Int
2881+
nargs::UInt
28942882
slottypes::Vector{Any}
28952883
changes::VarTable
2896-
function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::Int,
2884+
function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::UInt,
28972885
slottypes::Vector{Any}, changes::VarTable) where Interp<:AbstractInterpreter
2898-
new{Interp}(interp, bestguess, nargs, slottypes, changes)
2886+
new{Interp}(interp, bestguess, Int(nargs), slottypes, changes)
28992887
end
29002888
end
29012889

@@ -2970,7 +2958,7 @@ end
29702958
# pick up the first "interesting" slot, convert `rt` to its `Conditional`
29712959
# TODO: ideally we want `Conditional` and `InterConditional` to convey
29722960
# constraints on multiple slots
2973-
for slot_id = 1:info.nargs
2961+
for slot_id = 1:Int(info.nargs)
29742962
rt = bool_rt_to_conditional(rt, slot_id, info)
29752963
rt isa InterConditional && break
29762964
end
@@ -2981,6 +2969,9 @@ end
29812969
⊑ᵢ = ⊑(typeinf_lattice(info.interp))
29822970
old = info.slottypes[slot_id]
29832971
new = widenslotwrapper(info.changes[slot_id].typ) # avoid nested conditional
2972+
if isvarargtype(old) || isvarargtype(new)
2973+
return rt
2974+
end
29842975
if new ⊑ᵢ old && !(old ⊑ᵢ new)
29852976
if isa(rt, Const)
29862977
val = rt.val

base/compiler/inferenceresult.jl

+88-95
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,53 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
2424
for i = 1:length(argtypes)
2525
given_argtypes[i] = widenslotwrapper(argtypes[i])
2626
end
27-
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
2827
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
2928
end
3029

30+
function pick_const_arg(𝕃::AbstractLattice, @nospecialize(given_argtype), @nospecialize(cache_argtype))
31+
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
32+
# prefer the argtype we were given over the one computed from `mi`
33+
if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) &&
34+
!⊑(𝕃, given_argtype, cache_argtype))
35+
# if the type information of this `PartialStruct` is less strict than
36+
# declared method signature, narrow it down using `tmeet`
37+
given_argtype = tmeet(𝕃, given_argtype, cache_argtype)
38+
end
39+
else
40+
given_argtype = cache_argtype
41+
end
42+
return given_argtype
43+
end
44+
3145
function pick_const_args!(𝕃::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any})
32-
nargtypes = length(given_argtypes)
33-
@assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`"
34-
for i = 1:nargtypes
35-
given_argtype = given_argtypes[i]
36-
cache_argtype = cache_argtypes[i]
37-
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
38-
# prefer the argtype we were given over the one computed from `mi`
39-
if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) &&
40-
!⊑(𝕃, given_argtype, cache_argtype))
41-
# if the type information of this `PartialStruct` is less strict than
42-
# declared method signature, narrow it down using `tmeet`
43-
given_argtypes[i] = tmeet(𝕃, given_argtype, cache_argtype)
44-
end
46+
if length(given_argtypes) == 0 || length(cache_argtypes) == 0
47+
return Any[]
48+
end
49+
given_va = given_argtypes[end]
50+
cache_va = cache_argtypes[end]
51+
if isvarargtype(given_va)
52+
if isvarargtype(cache_va)
53+
# Process the common prefix, then join
54+
nprocessargs = max(length(given_argtypes)-1, length(cache_argtypes)-1)
55+
resize!(given_argtypes, nprocessargs+1)
56+
given_argtypes[end] = Vararg{pick_const_arg(𝕃, unwrapva(given_va), unwrapva(cache_va))}
4557
else
46-
given_argtypes[i] = cache_argtype
58+
nprocessargs = length(cache_argtypes)
59+
resize!(given_argtypes, nprocessargs)
4760
end
61+
elseif isvarargtype(cache_va)
62+
nprocessargs = length(given_argtypes)
63+
resize!(given_argtypes, nprocessargs)
64+
else
65+
@assert length(given_argtypes) == length(cache_argtypes)
66+
nprocessargs = length(given_argtypes)
67+
resize!(given_argtypes, nprocessargs)
68+
end
69+
for i = 1:nprocessargs
70+
given_argtype = argtype_by_index(given_argtypes, i)
71+
cache_argtype = argtype_by_index(cache_argtypes, i)
72+
given_argtype = pick_const_arg(𝕃, given_argtype, cache_argtype)
73+
given_argtypes[i] = given_argtype
4874
end
4975
return given_argtypes
5076
end
@@ -60,25 +86,33 @@ function is_argtype_match(𝕃::AbstractLattice,
6086
end
6187
end
6288

63-
va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) =
64-
va_process_argtypes(Returns(nothing), 𝕃, given_argtypes, mi)
65-
function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance)
66-
def = mi.def::Method
67-
isva = def.isva
68-
nargs = Int(def.nargs)
69-
if isva || isvarargtype(given_argtypes[end])
70-
isva_given_argtypes = Vector{Any}(undef, nargs)
89+
function va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, nargs::UInt, isva::Bool)
90+
if isva || (!isempty(given_argtypes) && isvarargtype(given_argtypes[end]))
91+
isva_given_argtypes = Vector{Any}(undef, Int(nargs))
7192
for i = 1:(nargs-isva)
72-
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
93+
newarg = argtype_by_index(given_argtypes, i)
94+
if isva && has_conditional(𝕃) && isa(newarg, Conditional)
95+
if newarg.slotid > (nargs-isva)
96+
newarg = widenconditional(newarg)
97+
end
98+
end
99+
isva_given_argtypes[i] = newarg
73100
end
74101
if isva
75102
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
76103
last = length(given_argtypes)
77104
else
78105
last = nargs
106+
if has_conditional(𝕃)
107+
for i = last:length(given_argtypes)
108+
newarg = given_argtypes[i]
109+
if isa(newarg, Conditional) && newarg.slotid > (nargs-isva)
110+
given_argtypes[i] = widenconditional(newarg)
111+
end
112+
end
113+
end
79114
end
80115
isva_given_argtypes[nargs] = tuple_tfunc(𝕃, given_argtypes[last:end])
81-
va_handler!(isva_given_argtypes, last)
82116
end
83117
return isva_given_argtypes
84118
end
@@ -87,84 +121,44 @@ function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, gi
87121
end
88122

89123
function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes))
90-
toplevel = method === nothing
91-
isva = !toplevel && method.isva
92124
mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...]
93-
nargs::Int = toplevel ? 0 : method.nargs
94-
cache_argtypes = Vector{Any}(undef, nargs)
95-
# First, if we're dealing with a varargs method, then we set the last element of `args`
96-
# to the appropriate `Tuple` type or `PartialStruct` instance.
97-
mi_argtypes_length = length(mi_argtypes)
98-
if !toplevel && isva
99-
if specTypes::Type == Tuple
100-
mi_argtypes = Any[Any for i = 1:nargs]
101-
if nargs > 1
102-
mi_argtypes[end] = Tuple
103-
end
104-
vargtype = Tuple
105-
else
106-
if nargs > mi_argtypes_length
107-
va = mi_argtypes[mi_argtypes_length]
108-
if isvarargtype(va)
109-
new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes)
110-
vargtype = Tuple{new_va}
111-
else
112-
vargtype = Tuple{}
113-
end
114-
else
115-
vargtype_elements = Any[]
116-
for i in nargs:mi_argtypes_length
117-
p = mi_argtypes[i]
118-
p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
119-
push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
120-
end
121-
for i in 1:length(vargtype_elements)
122-
atyp = vargtype_elements[i]
123-
if issingletontype(atyp)
124-
# replace singleton types with their equivalent Const object
125-
vargtype_elements[i] = Const(atyp.instance)
126-
elseif isconstType(atyp)
127-
vargtype_elements[i] = Const(atyp.parameters[1])
128-
end
129-
end
130-
vargtype = tuple_tfunc(fallback_lattice, vargtype_elements)
131-
end
132-
end
133-
cache_argtypes[nargs] = vargtype
134-
nargs -= 1
125+
nargtypes = length(mi_argtypes)
126+
nargs = isa(method, Method) ? method.nargs : 0
127+
if length(mi_argtypes) < nargs && isvarargtype(mi_argtypes[end])
128+
resize!(mi_argtypes, nargs)
135129
end
136130
# Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some
137131
# type info as we go (where possible). Note that if we're dealing with a varargs method,
138132
# we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
139133
# we don't overwrite the result of that work here).
140-
if mi_argtypes_length > 0
141-
tail_index = nargtypes = min(mi_argtypes_length, nargs)
142-
local lastatype
143-
for i = 1:nargtypes
144-
atyp = mi_argtypes[i]
145-
if i == nargtypes && isvarargtype(atyp)
146-
atyp = unwrapva(atyp)
147-
tail_index -= 1
148-
end
149-
atyp = unwraptv(atyp)
150-
if issingletontype(atyp)
151-
# replace singleton types with their equivalent Const object
152-
atyp = Const(atyp.instance)
153-
elseif isconstType(atyp)
154-
atyp = Const(atyp.parameters[1])
155-
else
156-
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
157-
end
158-
i == nargtypes && (lastatype = atyp)
159-
cache_argtypes[i] = atyp
134+
tail_index = min(nargtypes, nargs)
135+
local lastatype
136+
for i = 1:nargtypes
137+
atyp = mi_argtypes[i]
138+
wasva = false
139+
if i == nargtypes && isvarargtype(atyp)
140+
wasva = true
141+
atyp = unwrapva(atyp)
160142
end
161-
for i = (tail_index+1):nargs
162-
cache_argtypes[i] = lastatype
143+
atyp = unwraptv(atyp)
144+
if issingletontype(atyp)
145+
# replace singleton types with their equivalent Const object
146+
atyp = Const(atyp.instance)
147+
elseif isconstType(atyp)
148+
atyp = Const(atyp.parameters[1])
149+
else
150+
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
163151
end
164-
else
165-
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
152+
mi_argtypes[i] = atyp
153+
if wasva
154+
lastatype = atyp
155+
mi_argtypes[end] = Vararg{atyp}
156+
end
157+
end
158+
for i = (tail_index+1):(nargs-1)
159+
mi_argtypes[i] = lastatype
166160
end
167-
return cache_argtypes
161+
return mi_argtypes
168162
end
169163

170164
# eliminate free `TypeVar`s in order to make the life much easier down the road:
@@ -184,7 +178,6 @@ function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes:
184178
cache::Vector{InferenceResult})
185179
method = mi.def::Method
186180
nargtypes = length(given_argtypes)
187-
@assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`"
188181
for cached_result in cache
189182
cached_result.linfo === mi || @goto next_cache
190183
cache_argtypes = cached_result.argtypes

base/compiler/inferencestate.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ mutable struct InferenceState
302302
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
303303
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
304304
argtypes = result.argtypes
305+
306+
argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva)
307+
305308
nargtypes = length(argtypes)
306309
for i = 1:nslots
307310
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
@@ -766,10 +769,9 @@ function print_callstack(sv::InferenceState)
766769
end
767770

768771
function narguments(sv::InferenceState, include_va::Bool=true)
769-
def = sv.linfo.def
770-
nargs = length(sv.result.argtypes)
772+
nargs = sv.src.nargs
771773
if !include_va
772-
nargs -= isa(def, Method) && def.isva
774+
nargs -= sv.src.isva
773775
end
774776
return nargs
775777
end
@@ -831,7 +833,7 @@ function IRInterpretationState(interp::AbstractInterpreter,
831833
end
832834
method_info = MethodInfo(src)
833835
ir = inflate_ir(src, mi)
834-
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, mi)
836+
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva)
835837
return IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
836838
codeinst.min_world, codeinst.max_world)
837839
end

base/compiler/optimize.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -1264,14 +1264,13 @@ end
12641264
function slot2reg(ir::IRCode, ci::CodeInfo, sv::OptimizationState)
12651265
# need `ci` for the slot metadata, IR for the code
12661266
svdef = sv.linfo.def
1267-
nargs = isa(svdef, Method) ? Int(svdef.nargs) : 0
12681267
@timeit "domtree 1" domtree = construct_domtree(ir)
1269-
defuse_insts = scan_slot_def_use(nargs, ci, ir.stmts.stmt)
1268+
defuse_insts = scan_slot_def_use(ci.nargs, ci, ir.stmts.stmt)
12701269
𝕃ₒ = optimizer_lattice(sv.inlining.interp)
12711270
@timeit "construct_ssa" ir = construct_ssa!(ci, ir, sv, domtree, defuse_insts, 𝕃ₒ) # consumes `ir`
12721271
# NOTE now we have converted `ir` to the SSA form and eliminated slots
12731272
# let's resize `argtypes` now and remove unnecessary types for the eliminated slots
1274-
resize!(ir.argtypes, nargs)
1273+
resize!(ir.argtypes, ci.nargs)
12751274
return ir
12761275
end
12771276

0 commit comments

Comments
 (0)