Skip to content

Commit 1725b3b

Browse files
committed
optimizer: inline abstract union-split callsite
Currently the optimizer handles abstract callsite only when there is a single dispatch candidate (in most cases), and so inlining and static-dispatch are prohibited when the callsite is union-split (in other word, union-split happens only when all the dispatch candidates are concrete). However, there are certain patterns of code (most notably our Julia-level compiler code) that inherently need to deal with abstract callsite. The following example is taken from `Core.Compiler` utility: ```julia julia> @inline isType(@nospecialize t) = isa(t, DataType) && t.name === Type.body.name isType (generic function with 1 method) julia> code_typed((Any,)) do x # abstract, but no union-split, successful inlining isType(x) end |> only CodeInfo( 1 ─ %1 = (x isa Main.DataType)::Bool └── goto #3 if not %1 2 ─ %3 = π (x, DataType) │ %4 = Base.getfield(%3, :name)::Core.TypeName │ %5 = Base.getfield(Type{T}, :name)::Core.TypeName │ %6 = (%4 === %5)::Bool └── goto #4 3 ─ goto #4 4 ┄ %9 = φ (#2 => %6, #3 => false)::Bool └── return %9 ) => Bool julia> code_typed((Union{Type,Nothing},)) do x # abstract, union-split, unsuccessful inlining isType(x) end |> only CodeInfo( 1 ─ %1 = (isa)(x, Nothing)::Bool └── goto #3 if not %1 2 ─ goto #4 3 ─ %4 = Main.isType(x)::Bool └── goto #4 4 ┄ %6 = φ (#2 => false, #3 => %4)::Bool └── return %6 ) => Bool ``` (note that this is a limitation of the inlining algorithm, and so any user-provided hints like callsite inlining annotation doesn't help here) This commit enables inlining and static dispatch for abstract union-split callsite. The core idea here is that we can simulate our dispatch semantics by generating `isa` checks in order of the specialities of dispatch candidates: ```julia julia> code_typed((Union{Type,Nothing},)) do x # union-split, unsuccessful inlining isType(x) end |> only CodeInfo( 1 ─ %1 = (isa)(x, Nothing)::Bool └── goto #3 if not %1 2 ─ goto #9 3 ─ %4 = (isa)(x, Type)::Bool └── goto #8 if not %4 4 ─ %6 = π (x, Type) │ %7 = (%6 isa Main.DataType)::Bool └── goto #6 if not %7 5 ─ %9 = π (%6, DataType) │ %10 = Base.getfield(%9, :name)::Core.TypeName │ %11 = Base.getfield(Type{T}, :name)::Core.TypeName │ %12 = (%10 === %11)::Bool └── goto #7 6 ─ goto #7 7 ┄ %15 = φ (#5 => %12, #6 => false)::Bool └── goto #9 8 ─ Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{} └── unreachable 9 ┄ %19 = φ (#2 => false, #7 => %15)::Bool └── return %19 ) => Bool ``` Inlining/static-dispatch of abstract union-split callsite will improve the performance in such situations (and so this commit will improve the latency of our JIT compilation). Especially, this commit helps us avoid excessive specializations of `Core.Compiler` code by statically-resolving `@nospecialize`d callsites, and as the result, the # of precompiled statements is now reduced from `1956` ([`master`](dc45d77)) to `1901` (this commit). And also, as a side effect, the implementation of our inlining algorithm gets much simplified now since we no longer need the previous special handlings for abstract callsites. One possible drawback would be increased code size. This change seems to certainly increase the size of sysimage, but I think these numbers are in an acceptable range: > [`master`](dc45d77) ``` ❯ du -sh usr/lib/julia/* 17M usr/lib/julia/corecompiler.ji 188M usr/lib/julia/sys-o.a 164M usr/lib/julia/sys.dylib 23M usr/lib/julia/sys.dylib.dSYM 101M usr/lib/julia/sys.ji ``` > this commit ``` ❯ du -sh usr/lib/julia/* 17M usr/lib/julia/corecompiler.ji 190M usr/lib/julia/sys-o.a 166M usr/lib/julia/sys.dylib 23M usr/lib/julia/sys.dylib.dSYM 102M usr/lib/julia/sys.ji ```
1 parent dc45d77 commit 1725b3b

File tree

3 files changed

+75
-87
lines changed

3 files changed

+75
-87
lines changed

base/compiler/ssair/inlining.jl

+28-82
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int,
241241
push!(from_bbs, length(state.new_cfg_blocks))
242242
# TODO: Right now we unconditionally generate a fallback block
243243
# in case of subtyping errors - This is probably unnecessary.
244-
if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig)))
244+
if i != length(cases) || (!fully_covered || (!params.trust_inference))
245245
# This block will have the next condition or the final else case
246246
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx)))
247247
push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks))
@@ -313,7 +313,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
313313
spec = item.spec::ResolvedInliningSpec
314314
sparam_vals = item.mi.sparam_vals
315315
def = item.mi.def::Method
316-
inline_cfg = spec.ir.cfg
317316
linetable_offset::Int32 = length(linetable)
318317
# Append the linetable of the inlined function to our line table
319318
inlined_at = Int(compact.result[idx][:line])
@@ -472,6 +471,10 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
472471
pn = PhiNode()
473472
local bb = compact.active_result_bb
474473
@assert length(bbs) >= length(cases)
474+
# we may deal with abstract union-split callsites here,
475+
# and we need to sort all inlining candidates by signature speciality
476+
# so that the generated `isa` checks simulates the dispatch semantics
477+
sort!(cases; lt=morespecific, by=case::InliningCase->case.sig)
475478
for i in 1:length(cases)
476479
ithcase = cases[i]
477480
mtype = ithcase.sig::DataType # checked within `handle_cases!`
@@ -480,8 +483,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
480483
cond = true
481484
nparams = fieldcount(atype)
482485
@assert nparams == fieldcount(mtype)
483-
if i != length(cases) || !fully_covered ||
484-
(!params.trust_inference && isdispatchtuple(cases[i].sig))
486+
if i != length(cases) || !fully_covered || !params.trust_inference
485487
for i = 1:nparams
486488
a, m = fieldtype(atype, i), fieldtype(mtype, i)
487489
# If this is always true, we don't need to check for it
@@ -538,7 +540,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
538540
bb += 1
539541
# We're now in the fall through block, decide what to do
540542
if fully_covered
541-
if !params.trust_inference && isdispatchtuple(cases[end].sig)
543+
if !params.trust_inference
542544
e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR)
543545
insert_node_here!(compact, NewInstruction(e, Union{}, line))
544546
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
@@ -561,7 +563,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect
561563
state = CFGInliningState(ir)
562564
for (idx, item) in todo
563565
if isa(item, UnionSplit)
564-
cfg_inline_unionsplit!(ir, idx, item::UnionSplit, state, params)
566+
cfg_inline_unionsplit!(ir, idx, item, state, params)
565567
else
566568
item = item::InliningTodo
567569
spec = item.spec::ResolvedInliningSpec
@@ -1175,12 +1177,8 @@ function analyze_single_call!(
11751177
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
11761178
argtypes = sig.argtypes
11771179
cases = InliningCase[]
1178-
local only_method = nothing # keep track of whether there is one matching method
1179-
local meth::MethodLookupResult
1180+
local any_fully_covered = false
11801181
local handled_all_cases = true
1181-
local any_covers_full = false
1182-
local revisit_idx = nothing
1183-
11841182
for i in 1:length(infos)
11851183
meth = infos[i].results
11861184
if meth.ambig
@@ -1191,66 +1189,20 @@ function analyze_single_call!(
11911189
# No applicable methods; try next union split
11921190
handled_all_cases = false
11931191
continue
1194-
else
1195-
if length(meth) == 1 && only_method !== false
1196-
if only_method === nothing
1197-
only_method = meth[1].method
1198-
elseif only_method !== meth[1].method
1199-
only_method = false
1200-
end
1201-
else
1202-
only_method = false
1203-
end
12041192
end
1205-
for (j, match) in enumerate(meth)
1206-
any_covers_full |= match.fully_covers
1207-
if !isdispatchtuple(match.spec_types)
1208-
if !match.fully_covers
1209-
handled_all_cases = false
1210-
continue
1211-
end
1212-
if revisit_idx === nothing
1213-
revisit_idx = (i, j)
1214-
else
1215-
handled_all_cases = false
1216-
revisit_idx = nothing
1217-
end
1218-
else
1219-
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
1220-
end
1193+
for match in meth
1194+
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
1195+
any_fully_covered |= match.fully_covers
12211196
end
12221197
end
12231198

1224-
atype = argtypes_to_type(argtypes)
1225-
if handled_all_cases && revisit_idx !== nothing
1226-
# If there's only one case that's not a dispatchtuple, we can
1227-
# still unionsplit by visiting all the other cases first.
1228-
# This is useful for code like:
1229-
# foo(x::Int) = 1
1230-
# foo(@nospecialize(x::Any)) = 2
1231-
# where we where only a small number of specific dispatchable
1232-
# cases are split off from an ::Any typed fallback.
1233-
(i, j) = revisit_idx
1234-
match = infos[i].results[j]
1235-
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
1236-
elseif length(cases) == 0 && only_method isa Method
1237-
# if the signature is fully covered and there is only one applicable method,
1238-
# we can try to inline it even if the signature is not a dispatch tuple.
1239-
# -- But don't try it if we already tried to handle the match in the revisit_idx
1240-
# case, because that'll (necessarily) be the same method.
1241-
if length(infos) > 1
1242-
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
1243-
atype, only_method.sig)::SimpleVector
1244-
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
1245-
else
1246-
@assert length(meth) == 1
1247-
match = meth[1]
1248-
end
1249-
handle_match!(match, argtypes, flag, state, cases, true) || return nothing
1250-
any_covers_full = handled_all_cases = match.fully_covers
1199+
if !handled_all_cases
1200+
# if we've not seen all candidates, union split is valid only for dispatch tuples
1201+
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
12511202
end
12521203

1253-
handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
1204+
handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases,
1205+
handled_all_cases & any_fully_covered, todo, state.params)
12541206
end
12551207

12561208
# similar to `analyze_single_call!`, but with constant results
@@ -1261,8 +1213,8 @@ function handle_const_call!(
12611213
(; call, results) = cinfo
12621214
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
12631215
cases = InliningCase[]
1216+
local any_fully_covered = false
12641217
local handled_all_cases = true
1265-
local any_covers_full = false
12661218
local j = 0
12671219
for i in 1:length(infos)
12681220
meth = infos[i].results
@@ -1278,32 +1230,26 @@ function handle_const_call!(
12781230
for match in meth
12791231
j += 1
12801232
result = results[j]
1281-
any_covers_full |= match.fully_covers
1233+
any_fully_covered |= match.fully_covers
12821234
if isa(result, ConstResult)
12831235
case = const_result_item(result, state)
12841236
push!(cases, InliningCase(result.mi.specTypes, case))
12851237
elseif isa(result, InferenceResult)
1286-
handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases)
1238+
handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases, true)
12871239
else
12881240
@assert result === nothing
1289-
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
1241+
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
12901242
end
12911243
end
12921244
end
12931245

1294-
# if the signature is fully covered and there is only one applicable method,
1295-
# we can try to inline it even if the signature is not a dispatch tuple
1296-
atype = argtypes_to_type(argtypes)
1297-
if length(cases) == 0
1298-
length(results) == 1 || return nothing
1299-
result = results[1]
1300-
isa(result, InferenceResult) || return nothing
1301-
handle_inf_result!(result, argtypes, flag, state, cases, true) || return nothing
1302-
spec_types = cases[1].sig
1303-
any_covers_full = handled_all_cases = atype <: spec_types
1246+
if !handled_all_cases
1247+
# if we've not seen all candidates, union split is valid only for dispatch tuples
1248+
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
13041249
end
13051250

1306-
handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
1251+
handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases,
1252+
handled_all_cases & any_fully_covered, todo, state.params)
13071253
end
13081254

13091255
function handle_match!(
@@ -1313,7 +1259,6 @@ function handle_match!(
13131259
allow_abstract || isdispatchtuple(spec_types) || return false
13141260
item = analyze_method!(match, argtypes, flag, state)
13151261
item === nothing && return false
1316-
_any(case->case.sig === spec_types, cases) && return true
13171262
push!(cases, InliningCase(spec_types, item))
13181263
return true
13191264
end
@@ -1445,7 +1390,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
14451390

14461391
analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo)
14471392
end
1448-
todo
1393+
1394+
return todo
14491395
end
14501396

14511397
function linear_inline_eligible(ir::IRCode)

base/sort.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ module Sort
55
import ..@__MODULE__, ..parentmodule
66
const Base = parentmodule(@__MODULE__)
77
using .Base.Order
8-
using .Base: copymutable, LinearIndices, length, (:),
8+
using .Base: copymutable, LinearIndices, length, (:), iterate,
99
eachindex, axes, first, last, similar, zip, OrdinalRange,
1010
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
1111
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,

test/compiler/inline.jl

+46-4
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,48 @@ let
810810
@test invoke(Any[10]) === false
811811
end
812812

813+
# test union-split, non-dispatchtuple callsite inlining
814+
815+
@constprop :none @noinline abstract_unionsplit(@nospecialize x::Any) = Base.inferencebarrier(0)
816+
@constprop :none @noinline abstract_unionsplit(@nospecialize x::Integer) = Base.inferencebarrier(1)
817+
let src = code_typed1((Any,)) do x
818+
abstract_unionsplit(x)
819+
end
820+
@test count(isinvoke(:abstract_unionsplit), src.code) == 2
821+
@test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no need to insert a fallback dispatch
822+
end
823+
824+
@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Number) = Base.inferencebarrier(0)
825+
@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Integer) = Base.inferencebarrier(1)
826+
let src = code_typed1((Any,)) do x
827+
abstract_unionsplit_fallback(x)
828+
end
829+
@test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2
830+
@test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch
831+
end
832+
833+
@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Any) = (c && println("erase me"); typeof(x))
834+
@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Integer) = (c && println("erase me"); typeof(x))
835+
let src = code_typed1((Any,)) do x
836+
abstract_unionsplit(false, x)
837+
end
838+
@test count(iscall((src, typeof)), src.code) == 2
839+
@test count(isinvoke(:println), src.code) == 0
840+
@test count(iscall((src, println)), src.code) == 0
841+
@test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no need to insert a fallback dispatch
842+
end
843+
844+
@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x))
845+
@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Integer) = (c && println("erase me"); typeof(x))
846+
let src = code_typed1((Any,)) do x
847+
abstract_unionsplit_fallback(false, x)
848+
end
849+
@test count(iscall((src, typeof)), src.code) == 2
850+
@test count(isinvoke(:println), src.code) == 0
851+
@test count(iscall((src, println)), src.code) == 0
852+
@test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch
853+
end
854+
813855
# issue 43104
814856

815857
@inline isGoodType(@nospecialize x::Type) =
@@ -1090,11 +1132,11 @@ end
10901132

10911133
global x44200::Int = 0
10921134
function f44200()
1093-
global x = 0
1094-
while x < 10
1095-
x += 1
1135+
global x44200 = 0
1136+
while x44200 < 10
1137+
x44200 += 1
10961138
end
1097-
x
1139+
x44200
10981140
end
10991141
let src = code_typed1(f44200)
11001142
@test count(x -> isa(x, Core.PiNode), src.code) == 0

0 commit comments

Comments
 (0)