Skip to content

Commit 5aa2462

Browse files
authored
Merge pull request #30878 from JuliaLang/jb/sptypes
some improvements to static parameter handling in inference
2 parents e7e726b + 36d490a commit 5aa2462

File tree

12 files changed

+108
-101
lines changed

12 files changed

+108
-101
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -863,21 +863,6 @@ function abstract_eval_cfunction(e::Expr, vtypes::VarTable, sv::InferenceState)
863863
nothing
864864
end
865865

866-
# convert an inferred static parameter value to the inferred type of a static_parameter expression
867-
function sparam_type(@nospecialize(val))
868-
if isa(val, TypeVar)
869-
if Any <: val.ub
870-
# static param bound to typevar
871-
# if the tvar is not known to refer to anything more specific than Any,
872-
# the static param might actually be an integer, symbol, etc.
873-
return Any
874-
else
875-
return UnionAll(val, Type{val})
876-
end
877-
end
878-
return AbstractEvalConstant(val)
879-
end
880-
881866
function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
882867
if isa(e, QuoteNode)
883868
return AbstractEvalConstant((e::QuoteNode).value)
@@ -940,8 +925,8 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
940925
elseif e.head === :static_parameter
941926
n = e.args[1]
942927
t = Any
943-
if 1 <= n <= length(sv.sp)
944-
t = sparam_type(sv.sp[n])
928+
if 1 <= n <= length(sv.sptypes)
929+
t = sv.sptypes[n]
945930
end
946931
elseif e.head === :method
947932
t = (length(e.args) == 1) ? Any : Nothing
@@ -975,9 +960,9 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
975960
end
976961
elseif isa(sym, Expr) && sym.head === :static_parameter
977962
n = sym.args[1]
978-
if 1 <= n <= length(sv.sp)
979-
val = sv.sp[n]
980-
if !isa(val, TypeVar)
963+
if 1 <= n <= length(sv.sptypes)
964+
spty = sv.sptypes[n]
965+
if isa(spty, Const)
981966
t = Const(true)
982967
end
983968
end

base/compiler/inferencestate.jl

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ const LineNum = Int
55
mutable struct InferenceState
66
params::Params # describes how to compute the result
77
result::InferenceResult # remember where to put the result
8-
linfo::MethodInstance # used here for the tuple (specTypes, env, Method) and world-age validity
9-
sp::SimpleVector # static parameters
8+
linfo::MethodInstance # used here for the tuple (specTypes, env, Method) and world-age validity
9+
sptypes::Vector{Any} # types of static parameter
1010
slottypes::Vector{Any}
1111
mod::Module
1212
currpc::LineNum
@@ -48,7 +48,7 @@ mutable struct InferenceState
4848
code = src.code::Array{Any,1}
4949
toplevel = !isa(linfo.def, Method)
5050

51-
sp = spvals_from_meth_instance(linfo::MethodInstance)
51+
sp = sptypes_from_meth_instance(linfo::MethodInstance)
5252

5353
nssavalues = src.ssavaluetypes::Int
5454
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
@@ -120,7 +120,7 @@ function InferenceState(result::InferenceResult, cached::Bool, params::Params)
120120
return InferenceState(result, src, cached, params)
121121
end
122122

123-
function spvals_from_meth_instance(linfo::MethodInstance)
123+
function sptypes_from_meth_instance(linfo::MethodInstance)
124124
toplevel = !isa(linfo.def, Method)
125125
if !toplevel && isempty(linfo.sparam_vals) && !isempty(linfo.def.sparam_syms)
126126
# linfo is unspecialized
@@ -130,35 +130,54 @@ function spvals_from_meth_instance(linfo::MethodInstance)
130130
push!(sp, sig.var)
131131
sig = sig.body
132132
end
133-
sp = svec(sp...)
134133
else
135-
sp = linfo.sparam_vals
136-
if _any(t->isa(t,TypeVar), sp)
137-
sp = collect(Any, sp)
138-
end
134+
sp = collect(Any, linfo.sparam_vals)
139135
end
140-
if !isa(sp, SimpleVector)
141-
for i = 1:length(sp)
142-
v = sp[i]
143-
if v isa TypeVar
144-
ub = v.ub
145-
while ub isa TypeVar
146-
ub = ub.ub
147-
end
148-
if has_free_typevars(ub)
149-
ub = Any
136+
for i = 1:length(sp)
137+
v = sp[i]
138+
if v isa TypeVar
139+
ub = v.ub
140+
while ub isa TypeVar
141+
ub = ub.ub
142+
end
143+
if has_free_typevars(ub)
144+
ub = Any
145+
end
146+
lb = v.lb
147+
while lb isa TypeVar
148+
lb = lb.lb
149+
end
150+
if has_free_typevars(lb)
151+
lb = Bottom
152+
end
153+
if Any <: ub && lb <: Bottom
154+
ty = Any
155+
# if this parameter came from arg::Type{T}, we know that T::Type
156+
sig = linfo.def.sig
157+
temp = sig
158+
for j = 1:i-1
159+
temp = temp.body
150160
end
151-
lb = v.lb
152-
while lb isa TypeVar
153-
lb = lb.lb
161+
Pi = temp.var
162+
while temp isa UnionAll
163+
temp = temp.body
154164
end
155-
if has_free_typevars(lb)
156-
lb = Bottom
165+
sigtypes = temp.parameters
166+
for j = 1:length(sigtypes)
167+
tj = sigtypes[j]
168+
if isType(tj) && tj.parameters[1] === Pi
169+
ty = Type
170+
break
171+
end
157172
end
158-
sp[i] = TypeVar(v.name, lb, ub)
173+
else
174+
tv = TypeVar(v.name, lb, ub)
175+
ty = UnionAll(tv, Type{tv})
159176
end
177+
else
178+
ty = Const(v)
160179
end
161-
sp = svec(sp...)
180+
sp[i] = ty
162181
end
163182
return sp
164183
end

base/compiler/optimize.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ mutable struct OptimizationState
1313
min_valid::UInt
1414
max_valid::UInt
1515
params::Params
16-
sp::SimpleVector # static parameters
16+
sptypes::Vector{Any} # static parameters
1717
slottypes::Vector{Any}
1818
const_api::Bool
1919
function OptimizationState(frame::InferenceState)
@@ -27,7 +27,7 @@ mutable struct OptimizationState
2727
s_edges::Vector{Any},
2828
src, frame.mod, frame.nargs,
2929
frame.min_valid, frame.max_valid,
30-
frame.params, frame.sp, frame.slottypes, false)
30+
frame.params, frame.sptypes, frame.slottypes, false)
3131
end
3232
function OptimizationState(linfo::MethodInstance, src::CodeInfo,
3333
params::Params)
@@ -54,7 +54,7 @@ mutable struct OptimizationState
5454
s_edges::Vector{Any},
5555
src, inmodule, nargs,
5656
min_world(linfo), max_world(linfo),
57-
params, spvals_from_meth_instance(linfo), slottypes, false)
57+
params, sptypes_from_meth_instance(linfo), slottypes, false)
5858
end
5959
end
6060

@@ -135,7 +135,7 @@ function isinlineable(m::Method, me::OptimizationState, bonus::Int=0)
135135
end
136136
end
137137
if !inlineable
138-
inlineable = inline_worthy(me.src.code, me.src, me.sp, me.slottypes, me.params, cost_threshold + bonus)
138+
inlineable = inline_worthy(me.src.code, me.src, me.sptypes, me.slottypes, me.params, cost_threshold + bonus)
139139
end
140140
return inlineable
141141
end
@@ -148,7 +148,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
148148
return false
149149
end
150150
if isa(stmt, GotoIfNot)
151-
t = argextype(stmt.cond, ir, ir.spvals)
151+
t = argextype(stmt.cond, ir, ir.sptypes)
152152
return !(t Bool)
153153
end
154154
if isa(stmt, Expr)
@@ -175,7 +175,7 @@ function optimize(opt::OptimizationState, @nospecialize(result))
175175
proven_pure = true
176176
for i in 1:length(ir.stmts)
177177
stmt = ir.stmts[i]
178-
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, ir.types[i], ir, ir.spvals)
178+
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, ir.types[i], ir, ir.sptypes)
179179
proven_pure = false
180180
break
181181
end
@@ -268,19 +268,19 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
268268
# known return type
269269
isknowntype(@nospecialize T) = (T == Union{}) || isconcretetype(T)
270270

271-
function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector, slottypes::Vector{Any}, params::Params)
271+
function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::Params)
272272
head = ex.head
273273
if is_meta_expr_head(head)
274274
return 0
275275
elseif head === :call
276276
farg = ex.args[1]
277-
ftyp = argextype(farg, src, spvals, slottypes)
277+
ftyp = argextype(farg, src, sptypes, slottypes)
278278
if ftyp === IntrinsicFunction && farg isa SSAValue
279279
# if this comes from code that was already inlined into another function,
280280
# Consts have been widened. try to recover in simple cases.
281281
farg = src.code[farg.id]
282282
if isa(farg, GlobalRef) || isa(farg, QuoteNode) || isa(farg, IntrinsicFunction) || isexpr(farg, :static_parameter)
283-
ftyp = argextype(farg, src, spvals, slottypes)
283+
ftyp = argextype(farg, src, sptypes, slottypes)
284284
end
285285
end
286286
f = singleton_type(ftyp)
@@ -302,7 +302,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
302302
# return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty)
303303
return 0
304304
elseif f === Main.Core.arrayref && length(ex.args) >= 3
305-
atyp = argextype(ex.args[3], src, spvals, slottypes)
305+
atyp = argextype(ex.args[3], src, sptypes, slottypes)
306306
return isknowntype(atyp) ? 4 : params.inline_nonleaf_penalty
307307
end
308308
fidx = find_tfunc(f)
@@ -325,7 +325,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
325325
elseif head === :return
326326
a = ex.args[1]
327327
if a isa Expr
328-
return statement_cost(a, -1, src, spvals, slottypes, params)
328+
return statement_cost(a, -1, src, sptypes, slottypes, params)
329329
end
330330
return 0
331331
elseif head === :(=)
@@ -336,7 +336,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
336336
end
337337
a = ex.args[2]
338338
if a isa Expr
339-
cost = plus_saturate(cost, statement_cost(a, -1, src, spvals, slottypes, params))
339+
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, params))
340340
end
341341
return cost
342342
elseif head === :copyast
@@ -357,13 +357,13 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
357357
return 0
358358
end
359359

360-
function inline_worthy(body::Array{Any,1}, src::CodeInfo, spvals::SimpleVector, slottypes::Vector{Any},
360+
function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any},
361361
params::Params, cost_threshold::Integer=params.inline_cost_threshold)
362362
bodycost::Int = 0
363363
for line = 1:length(body)
364364
stmt = body[line]
365365
if stmt isa Expr
366-
thiscost = statement_cost(stmt, line, src, spvals, slottypes, params)::Int
366+
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, params)::Int
367367
elseif stmt isa GotoNode
368368
# loops are generally always expensive
369369
# but assume that forward jumps are already counted for from
@@ -378,11 +378,11 @@ function inline_worthy(body::Array{Any,1}, src::CodeInfo, spvals::SimpleVector,
378378
return true
379379
end
380380

381-
function is_known_call(e::Expr, @nospecialize(func), src, spvals::SimpleVector, slottypes::Vector{Any} = empty_slottypes)
381+
function is_known_call(e::Expr, @nospecialize(func), src, sptypes::Vector{Any}, slottypes::Vector{Any} = empty_slottypes)
382382
if e.head !== :call
383383
return false
384384
end
385-
f = argextype(e.args[1], src, spvals, slottypes)
385+
f = argextype(e.args[1], src, sptypes, slottypes)
386386
return isa(f, Const) && f.val === func
387387
end
388388

base/compiler/ssair/driver.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ function just_construct_ssa(ci::CodeInfo, code::Vector{Any}, nargs::Int, sv::Opt
104104
@timeit "domtree 1" domtree = construct_domtree(cfg)
105105
ir = let code = Any[nothing for _ = 1:length(code)]
106106
argtypes = sv.slottypes[1:(nargs+1)]
107-
IRCode(code, Any[], ci.codelocs, flags, cfg, collect(LineInfoNode, ci.linetable), argtypes, meta, sv.sp)
107+
IRCode(code, Any[], ci.codelocs, flags, cfg, collect(LineInfoNode, ci.linetable), argtypes, meta, sv.sptypes)
108108
end
109-
@timeit "construct_ssa" ir = construct_ssa!(ci, code, ir, domtree, defuse_insts, nargs, sv.sp, sv.slottypes)
109+
@timeit "construct_ssa" ir = construct_ssa!(ci, code, ir, domtree, defuse_insts, nargs, sv.sptypes, sv.slottypes)
110110
return ir
111111
end
112112

base/compiler/ssair/inlining.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
787787
isempty(eargs) && continue
788788
arg1 = eargs[1]
789789

790-
ft = argextype(arg1, ir, sv.sp)
790+
ft = argextype(arg1, ir, sv.sptypes)
791791
has_free_typevars(ft) && continue
792792
f = singleton_type(ft)
793793
f === Core.Intrinsics.llvmcall && continue
@@ -797,7 +797,7 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
797797
atypes[1] = ft
798798
ok = true
799799
for i = 2:length(stmt.args)
800-
a = argextype(stmt.args[i], ir, sv.sp)
800+
a = argextype(stmt.args[i], ir, sv.sptypes)
801801
(a === Bottom || isvarargtype(a)) && (ok = false; break)
802802
atypes[i] = a
803803
end

base/compiler/ssair/ir.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,20 +213,20 @@ struct IRCode
213213
lines::Vector{Int32}
214214
flags::Vector{UInt8}
215215
argtypes::Vector{Any}
216-
spvals::SimpleVector
216+
sptypes::Vector{Any}
217217
linetable::Vector{LineInfoNode}
218218
cfg::CFG
219219
new_nodes::Vector{NewNode}
220220
meta::Vector{Any}
221221

222222
function IRCode(stmts::Vector{Any}, types::Vector{Any}, lines::Vector{Int32}, flags::Vector{UInt8},
223223
cfg::CFG, linetable::Vector{LineInfoNode}, argtypes::Vector{Any}, meta::Vector{Any},
224-
spvals::SimpleVector)
225-
return new(stmts, types, lines, flags, argtypes, spvals, linetable, cfg, NewNode[], meta)
224+
sptypes::Vector{Any})
225+
return new(stmts, types, lines, flags, argtypes, sptypes, linetable, cfg, NewNode[], meta)
226226
end
227227
function IRCode(ir::IRCode, stmts::Vector{Any}, types::Vector{Any}, lines::Vector{Int32}, flags::Vector{UInt8},
228228
cfg::CFG, new_nodes::Vector{NewNode})
229-
return new(stmts, types, lines, flags, ir.argtypes, ir.spvals, ir.linetable, cfg, new_nodes, ir.meta)
229+
return new(stmts, types, lines, flags, ir.argtypes, ir.sptypes, ir.linetable, cfg, new_nodes, ir.meta)
230230
end
231231
end
232232
copy(code::IRCode) = IRCode(code, copy(code.stmts), copy(code.types),
@@ -1143,7 +1143,7 @@ function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing
11431143
if compact_exprtype(compact, SSAValue(idx)) === Bottom
11441144
effect_free = false
11451145
else
1146-
effect_free = stmt_effect_free(stmt, compact.result_types[idx], compact, compact.ir.spvals)
1146+
effect_free = stmt_effect_free(stmt, compact.result_types[idx], compact, compact.ir.sptypes)
11471147
end
11481148
if effect_free
11491149
for ops in userefs(stmt)

base/compiler/ssair/legacy.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
inflate_ir(ci::CodeInfo) = inflate_ir(ci, Core.svec(), Any[ Any for i = 1:length(ci.slotnames) ])
3+
inflate_ir(ci::CodeInfo) = inflate_ir(ci, Any[], Any[ Any for i = 1:length(ci.slotnames) ])
44

55
function inflate_ir(ci::CodeInfo, linfo::MethodInstance)
6-
spvals = spvals_from_meth_instance(linfo)
6+
sptypes = sptypes_from_meth_instance(linfo)
77
if ci.inferred
88
argtypes, _ = matching_cache_argtypes(linfo, nothing)
99
else
1010
argtypes = Any[ Any for i = 1:length(ci.slotnames) ]
1111
end
12-
return inflate_ir(ci, spvals, argtypes)
12+
return inflate_ir(ci, sptypes, argtypes)
1313
end
1414

15-
function inflate_ir(ci::CodeInfo, spvals::SimpleVector, argtypes::Vector{Any})
15+
function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
1616
code = copy_exprargs(ci.code)
1717
for i = 1:length(code)
1818
if isa(code[i], Expr)
@@ -46,7 +46,7 @@ function inflate_ir(ci::CodeInfo, spvals::SimpleVector, argtypes::Vector{Any})
4646
end
4747
ssavaluetypes = ci.ssavaluetypes isa Vector{Any} ? copy(ci.ssavaluetypes) : Any[ Any for i = 1:(ci.ssavaluetypes::Int) ]
4848
ir = IRCode(code, ssavaluetypes, copy(ci.codelocs), copy(ci.ssaflags), cfg, collect(LineInfoNode, ci.linetable),
49-
argtypes, Any[], spvals)
49+
argtypes, Any[], sptypes)
5050
return ir
5151
end
5252

0 commit comments

Comments
 (0)