Skip to content

Commit 795935f

Browse files
authored
more type-stable type-inference (#41697)
(this PR is the final output of my demo at [our workshop](https://github.com/aviatesk/juliacon2021-workshop-pkgdev)) This PR eliminated much of runtime dispatches within our type inference routine, that are reported by the following JET analysis: ```julia using JETTest const CC = Core.Compiler function function_filter(@nospecialize(ft)) ft === typeof(CC.isprimitivetype) && return false ft === typeof(CC.ismutabletype) && return false ft === typeof(CC.isbitstype) && return false ft === typeof(CC.widenconst) && return false ft === typeof(CC.widenconditional) && return false ft === typeof(CC.widenwrappedconditional) && return false ft === typeof(CC.maybe_extract_const_bool) && return false ft === typeof(CC.ignorelimited) && return false return true end function frame_filter((; linfo) = sv) meth = linfo.def isa(meth, Method) || return true return occursin("compiler/", string(meth.file)) end report_dispatch(CC.typeinf, (CC.NativeInterpreter, CC.InferenceState); function_filter, frame_filter) ``` > on master ``` ═════ 137 possible errors found ═════ ... ``` > on this PR ``` ═════ 51 possible errors found ═════ ... ``` And it seems like this PR makes JIT slightly faster: > on master ```julia ~/julia/julia master ❯ ./usr/bin/julia -e '@time using Plots; @time plot(rand(10,3));' 3.659865 seconds (7.19 M allocations: 497.982 MiB, 3.94% gc time, 0.39% compilation time) 2.696410 seconds (3.62 M allocations: 202.905 MiB, 7.49% gc time, 56.39% compilation time) ``` > on this PR ```julia ~/julia/julia avi/jetdemo* 7s ❯ ./usr/bin/julia -e '@time using Plots; @time plot(rand(10,3));' 3.396974 seconds (7.16 M allocations: 491.442 MiB, 4.80% gc time, 0.28% compilation time) 2.591130 seconds (3.48 M allocations: 196.026 MiB, 7.29% gc time, 56.72% compilation time) ```
1 parent 66f9b55 commit 795935f

File tree

9 files changed

+91
-73
lines changed

9 files changed

+91
-73
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
8585
push!(edges, edge)
8686
end
8787
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
88-
const_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
89-
if const_rt !== rt && const_rt rt
90-
rt = const_rt
88+
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
89+
if const_result !== nothing
90+
const_rt, const_result = const_result
91+
if const_rt !== rt && const_rt rt
92+
rt = const_rt
93+
end
9194
end
9295
push!(const_results, const_result)
9396
if const_result !== nothing
@@ -107,9 +110,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
107110
# try constant propagation with argtypes for this match
108111
# this is in preparation for inlining, or improving the return result
109112
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
110-
const_this_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
111-
if const_this_rt !== this_rt && const_this_rt this_rt
112-
this_rt = const_this_rt
113+
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
114+
if const_result !== nothing
115+
const_this_rt, const_result = const_result
116+
if const_this_rt !== this_rt && const_this_rt this_rt
117+
this_rt = const_this_rt
118+
end
113119
end
114120
push!(const_results, const_result)
115121
if const_result !== nothing
@@ -523,33 +529,35 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
523529
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
524530
sv::InferenceState, va_override::Bool)
525531
mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv)
526-
mi === nothing && return Any, nothing
532+
mi === nothing && return nothing
527533
# try constant prop'
528534
inf_cache = get_inference_cache(interp)
529535
inf_result = cache_lookup(mi, argtypes, inf_cache)
530536
if inf_result === nothing
531537
# if there might be a cycle, check to make sure we don't end up
532538
# calling ourselves here.
533-
if result.edgecycle && _any(InfStackUnwind(sv)) do infstate
534-
# if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`)
535-
# we can relax the cycle detection by comparing `MethodInstance`s and allow inference to
536-
# propagate different constant elements if the recursion is finite over the lattice
537-
return (result.edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) &&
538-
any(infstate.result.overridden_by_const)
539+
let result = result # prevent capturing
540+
if result.edgecycle && _any(InfStackUnwind(sv)) do infstate
541+
# if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`)
542+
# we can relax the cycle detection by comparing `MethodInstance`s and allow inference to
543+
# propagate different constant elements if the recursion is finite over the lattice
544+
return (result.edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) &&
545+
any(infstate.result.overridden_by_const)
546+
end
547+
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
548+
return nothing
539549
end
540-
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
541-
return Any, nothing
542550
end
543551
inf_result = InferenceResult(mi, argtypes, va_override)
544552
frame = InferenceState(inf_result, #=cache=#false, interp)
545-
frame === nothing && return Any, nothing # this is probably a bad generated function (unsound), but just ignore it
553+
frame === nothing && return nothing # this is probably a bad generated function (unsound), but just ignore it
546554
frame.parent = sv
547555
push!(inf_cache, inf_result)
548-
typeinf(interp, frame) || return Any, nothing
556+
typeinf(interp, frame) || return nothing
549557
end
550558
result = inf_result.result
551559
# if constant inference hits a cycle, just bail out
552-
isa(result, InferenceState) && return Any, nothing
560+
isa(result, InferenceState) && return nothing
553561
add_backedge!(mi, sv)
554562
return result, inf_result
555563
end
@@ -1174,7 +1182,8 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
11741182
nargtype === Bottom && return CallMeta(Bottom, false)
11751183
nargtype isa DataType || return CallMeta(Any, false) # other cases are not implemented below
11761184
isdispatchelem(ft) || return CallMeta(Any, false) # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
1177-
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)
1185+
ft = ft::DataType
1186+
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type
11781187
nargtype = Tuple{ft, nargtype.parameters...}
11791188
argtype = Tuple{ft, argtype.parameters...}
11801189
result = findsup(types, method_table(interp))
@@ -1196,12 +1205,14 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
11961205
# t, a = ti.parameters[i], argtypes′[i]
11971206
# argtypes′[i] = t ⊑ a ? t : a
11981207
# end
1199-
const_rt, const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false)
1200-
if const_rt !== rt && const_rt rt
1201-
return CallMeta(const_rt, InvokeCallInfo(match, const_result))
1202-
else
1203-
return CallMeta(rt, InvokeCallInfo(match, nothing))
1208+
const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false)
1209+
if const_result !== nothing
1210+
const_rt, const_result = const_result
1211+
if const_rt !== rt && const_rt rt
1212+
return CallMeta(const_rt, InvokeCallInfo(match, const_result))
1213+
end
12041214
end
1215+
return CallMeta(rt, InvokeCallInfo(match, nothing))
12051216
end
12061217

12071218
# call where the function is known exactly
@@ -1301,19 +1312,20 @@ end
13011312
function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
13021313
pushfirst!(argtypes, closure.env)
13031314
sig = argtypes_to_type(argtypes)
1304-
(; rt, edge) = result = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, sv)
1315+
(; rt, edge) = result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
13051316
edge !== nothing && add_backedge!(edge, sv)
13061317
tt = closure.typ
1307-
sigT = unwrap_unionall(tt).parameters[1]
1308-
match = MethodMatch(sig, Core.svec(), closure.source::Method, sig <: rewrap_unionall(sigT, tt))
1318+
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
1319+
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
13091320
info = OpaqueClosureCallInfo(match)
13101321
if !result.edgecycle
1311-
const_rettype, const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
1322+
const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
13121323
match, sv, closure.isva)
1313-
if const_rettype rt
1314-
rt = const_rettype
1315-
end
13161324
if const_result !== nothing
1325+
const_rettype, const_result = const_result
1326+
if const_rettype rt
1327+
rt = const_rettype
1328+
end
13171329
info = ConstCallInfo(info, Union{Nothing,InferenceResult}[const_result])
13181330
end
13191331
end
@@ -1323,7 +1335,7 @@ end
13231335
function most_general_argtypes(closure::PartialOpaque)
13241336
ret = Any[]
13251337
cc = widenconst(closure)
1326-
argt = unwrap_unionall(cc).parameters[1]
1338+
argt = (unwrap_unionall(cc)::DataType).parameters[1]
13271339
if !isa(argt, DataType) || argt.name !== typename(Tuple)
13281340
argt = Tuple
13291341
end
@@ -1338,8 +1350,8 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
13381350
f = argtype_to_function(ft)
13391351
if isa(ft, PartialOpaque)
13401352
return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv)
1341-
elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure)
1342-
return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false)
1353+
elseif (uft = unwrap_unionall(ft); isa(uft, DataType) && uft.name === typename(Core.OpaqueClosure))
1354+
return CallMeta(rewrap_unionall((uft::DataType).parameters[2], ft), false)
13431355
elseif f === nothing
13441356
# non-constant function, but the number of arguments is known
13451357
# and the ft is not a Builtin or IntrinsicFunction
@@ -1534,12 +1546,12 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15341546
if length(e.args) == 2 && isconcretetype(t) && !ismutabletype(t)
15351547
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
15361548
n = fieldcount(t)
1537-
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val) &&
1538-
let t = t; _all(i->getfield(at.val, i) isa fieldtype(t, i), 1:n); end
1549+
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) &&
1550+
let t = t; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
15391551
t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val))
1540-
elseif isa(at, PartialStruct) && at Tuple && n == length(at.fields) &&
1541-
let t = t, at = at; _all(i->at.fields[i] fieldtype(t, i), 1:n); end
1542-
t = PartialStruct(t, at.fields)
1552+
elseif isa(at, PartialStruct) && at Tuple && n == length(at.fields::Vector{Any}) &&
1553+
let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] fieldtype(t, i), 1:n); end
1554+
t = PartialStruct(t, at.fields::Vector{Any})
15431555
end
15441556
end
15451557
elseif e.head === :new_opaque_closure
@@ -1587,7 +1599,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15871599
sym = e.args[1]
15881600
t = Bool
15891601
if isa(sym, SlotNumber)
1590-
vtyp = vtypes[slot_id(sym)]
1602+
vtyp = vtypes[slot_id(sym)]::VarState
15911603
if vtyp.typ === Bottom
15921604
t = Const(false) # never assigned previously
15931605
elseif !vtyp.undef
@@ -1602,7 +1614,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
16021614
t = Const(true)
16031615
end
16041616
elseif isa(sym, Expr) && sym.head === :static_parameter
1605-
n = sym.args[1]
1617+
n = sym.args[1]::Int
16061618
if 1 <= n <= length(sv.sptypes)
16071619
spty = sv.sptypes[n]
16081620
if isa(spty, Const)
@@ -1637,7 +1649,7 @@ function abstract_eval_global(M::Module, s::Symbol)
16371649
end
16381650

16391651
function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo)
1640-
typ = src.ssavaluetypes[s.id]
1652+
typ = (src.ssavaluetypes::Vector{Any})[s.id]
16411653
if typ === NOT_FOUND
16421654
return Bottom
16431655
end
@@ -1725,6 +1737,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
17251737
isva = isa(def, Method) && def.isva
17261738
nslots = nargs - isva
17271739
slottypes = frame.slottypes
1740+
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
17281741
while frame.pc´´ <= n
17291742
# make progress on the active ip set
17301743
local pc::Int = frame.pc´´ # current program-counter
@@ -1828,7 +1841,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18281841
for (caller, caller_pc) in frame.cycle_backedges
18291842
# notify backedges of updated type information
18301843
typeassert(caller.stmt_types[caller_pc], VarTable) # we must have visited this statement before
1831-
if !(caller.src.ssavaluetypes[caller_pc] === Any)
1844+
if !((caller.src.ssavaluetypes::Vector{Any})[caller_pc] === Any)
18321845
# no reason to revisit if that call-site doesn't affect the final result
18331846
if caller_pc < caller.pc´´
18341847
caller.pc´´ = caller_pc
@@ -1838,6 +1851,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18381851
end
18391852
end
18401853
elseif hd === :enter
1854+
stmt = stmt::Expr
18411855
l = stmt.args[1]::Int
18421856
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
18431857
# propagate type info to exception handler
@@ -1853,21 +1867,24 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18531867
typeassert(states[l], VarTable)
18541868
frame.handler_at[l] = frame.cur_hand
18551869
elseif hd === :leave
1870+
stmt = stmt::Expr
18561871
for i = 1:((stmt.args[1])::Int)
18571872
frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second
18581873
end
18591874
else
18601875
if hd === :(=)
1876+
stmt = stmt::Expr
18611877
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
18621878
if t === Bottom
18631879
break
18641880
end
1865-
frame.src.ssavaluetypes[pc] = t
1881+
ssavaluetypes[pc] = t
18661882
lhs = stmt.args[1]
18671883
if isa(lhs, SlotNumber)
18681884
changes = StateUpdate(lhs, VarState(t, false), changes, false)
18691885
end
18701886
elseif hd === :method
1887+
stmt = stmt::Expr
18711888
fname = stmt.args[1]
18721889
if isa(fname, SlotNumber)
18731890
changes = StateUpdate(fname, VarState(Any, false), changes, false)
@@ -1882,7 +1899,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18821899
if !isempty(frame.ssavalue_uses[pc])
18831900
record_ssa_assign(pc, t, frame)
18841901
else
1885-
frame.src.ssavaluetypes[pc] = t
1902+
ssavaluetypes[pc] = t
18861903
end
18871904
end
18881905
if frame.cur_hand !== nothing && isa(changes, StateUpdate)
@@ -1903,7 +1920,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
19031920

19041921
if t === nothing
19051922
# mark other reached expressions as `Any` to indicate they don't throw
1906-
frame.src.ssavaluetypes[pc] = Any
1923+
ssavaluetypes[pc] = Any
19071924
end
19081925

19091926
pc´ > n && break # can't proceed with the fast-path fall-through

base/compiler/inferencestate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
179179
while temp isa UnionAll
180180
temp = temp.body
181181
end
182-
sigtypes = temp.parameters
182+
sigtypes = (temp::DataType).parameters
183183
for j = 1:length(sigtypes)
184184
tj = sigtypes[j]
185185
if isType(tj) && tj.parameters[1] === Pi

base/compiler/ssair/legacy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int)
4747
for metanode in ir.meta
4848
push!(ci.code, metanode)
4949
push!(ci.codelocs, 1)
50-
push!(ci.ssavaluetypes, Any)
50+
push!(ci.ssavaluetypes::Vector{Any}, Any)
5151
push!(ci.ssaflags, 0x00)
5252
end
5353
# Translate BB Edges to statement edges

base/compiler/ssair/passes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ function type_lift_pass!(ir::IRCode)
10661066
if haskey(processed, id)
10671067
val = processed[id]
10681068
else
1069-
push!(worklist, (id, up_id, new_phi, i))
1069+
push!(worklist, (id, up_id, new_phi::SSAValue, i))
10701070
continue
10711071
end
10721072
else

base/compiler/ssair/slot2ssa.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
873873
changed = false
874874
for new_idx in type_refine_phi
875875
node = new_nodes.stmts[new_idx]
876-
new_typ = recompute_type(node[:inst], ci, ir, ir.sptypes, slottypes)
876+
new_typ = recompute_type(node[:inst]::Union{PhiNode,PhiCNode}, ci, ir, ir.sptypes, slottypes)
877877
if !(node[:type] new_typ) || !(new_typ node[:type])
878878
node[:type] = new_typ
879879
changed = true

base/compiler/tfuncs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,7 +1564,7 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
15641564
if length(argtypes) - 1 == tf[2]
15651565
argtypes = argtypes[1:end-1]
15661566
else
1567-
vatype = argtypes[end]
1567+
vatype = argtypes[end]::Core.TypeofVararg
15681568
argtypes = argtypes[1:end-1]
15691569
while length(argtypes) < tf[1]
15701570
push!(argtypes, unwrapva(vatype))
@@ -1670,7 +1670,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
16701670
aft = argtypes[2]
16711671
if isa(aft, Const) || (isType(aft) && !has_free_typevars(aft)) ||
16721672
(isconcretetype(aft) && !(aft <: Builtin))
1673-
af_argtype = isa(tt, Const) ? tt.val : tt.parameters[1]
1673+
af_argtype = isa(tt, Const) ? tt.val : (tt::DataType).parameters[1]
16741674
if isa(af_argtype, DataType) && af_argtype <: Tuple
16751675
argtypes_vec = Any[aft, af_argtype.parameters...]
16761676
if contains_is(argtypes_vec, Union{})

0 commit comments

Comments
 (0)