Skip to content

Commit 0dbe6fd

Browse files
committed
Make an inference hot-path slightly faster
This aims to improve performance of inference slightly by removing a dynamic dispatch from calls to `widenwrappedconditional`, which appears in various hot paths and showed up in profiling of inference. There's two changes here: 1. Improve inlining for calls to functions of the form ``` f(x::Int) = 1 f(@nospecialize(x::Any)) = 2 ``` Previously, we would peel of the `x::Int` case and then generate a dynamic dispatch for the `x::Any` case. After this change, we directly emit an `:invoke` for the `x::Any` case (as well as enabling inlining of it in general). 2. Refactor `widenwrappedconditional` itself to avoid a signature with a union in it, since ironically union splitting cannot currently deal with that (it can only split unions if they're manifest in the call arguments).
1 parent bf6d9de commit 0dbe6fd

File tree

4 files changed

+68
-27
lines changed

4 files changed

+68
-27
lines changed

base/compiler/ssair/inlining.jl

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int,
241241
push!(from_bbs, length(state.new_cfg_blocks))
242242
# TODO: Right now we unconditionally generate a fallback block
243243
# in case of subtyping errors - This is probably unnecessary.
244-
if i != length(cases) || (!fully_covered || !params.trust_inference)
244+
if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig)))
245245
# This block will have the next condition or the final else case
246246
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx)))
247247
push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks))
@@ -481,7 +481,8 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
481481
cond = true
482482
aparams, mparams = atype.parameters::SimpleVector, metharg.parameters::SimpleVector
483483
@assert length(aparams) == length(mparams)
484-
if i != length(cases) || !fully_covered || !params.trust_inference
484+
if i != length(cases) || !fully_covered ||
485+
(!params.trust_inference && isdispatchtuple(cases[i].sig))
485486
for i in 1:length(aparams)
486487
a, m = aparams[i], mparams[i]
487488
# If this is always true, we don't need to check for it
@@ -538,7 +539,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
538539
bb += 1
539540
# We're now in the fall through block, decide what to do
540541
if fully_covered
541-
if !params.trust_inference
542+
if !params.trust_inference && isdispatchtuple(cases[end].sig)
542543
e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR)
543544
insert_node_here!(compact, NewInstruction(e, Union{}, line))
544545
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
@@ -1170,7 +1171,10 @@ function analyze_single_call!(
11701171
cases = InliningCase[]
11711172
local only_method = nothing # keep track of whether there is one matching method
11721173
local meth::MethodLookupResult
1173-
local fully_covered = true
1174+
local handled_all_cases = true
1175+
local any_covers_full = false
1176+
local revisit_idx = nothing
1177+
11741178
for i in 1:length(infos)
11751179
meth = infos[i].results
11761180
if meth.ambig
@@ -1179,7 +1183,7 @@ function analyze_single_call!(
11791183
return nothing
11801184
elseif length(meth) == 0
11811185
# No applicable methods; try next union split
1182-
fully_covered = false
1186+
handled_all_cases = false
11831187
continue
11841188
else
11851189
if length(meth) == 1 && only_method !== false
@@ -1192,12 +1196,38 @@ function analyze_single_call!(
11921196
only_method = false
11931197
end
11941198
end
1195-
for match in meth
1196-
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
1197-
fully_covered &= match.fully_covers
1199+
for (j, match) in enumerate(meth)
1200+
any_covers_full |= match.fully_covers
1201+
if !isdispatchtuple(match.spec_types)
1202+
if !match.fully_covers
1203+
handled_all_cases = false
1204+
continue
1205+
end
1206+
if revisit_idx === nothing
1207+
revisit_idx = (i, j)
1208+
else
1209+
handled_all_cases = false
1210+
revisit_idx = nothing
1211+
end
1212+
else
1213+
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
1214+
end
11981215
end
11991216
end
12001217

1218+
# If there's only one case that's not a dispatchtuple, we can
1219+
# still unionsplit by visiting all the other cases first.
1220+
# This is useful for code like:
1221+
# foo(x::Int) = 1
1222+
# foo(@nospecialize(x::Any)) = 2
1223+
# where we where only a small number of specific dispatchable
1224+
# cases are split off from an ::Any typed fallback.
1225+
if handled_all_cases && revisit_idx !== nothing
1226+
(i, j) = revisit_idx
1227+
match = infos[i].results[j]
1228+
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
1229+
end
1230+
12011231
# if the signature is fully covered and there is only one applicable method,
12021232
# we can try to inline it even if the signature is not a dispatch tuple
12031233
atype = argtypes_to_type(argtypes)
@@ -1213,10 +1243,10 @@ function analyze_single_call!(
12131243
item = analyze_method!(match, argtypes, flag, state)
12141244
item === nothing && return nothing
12151245
push!(cases, InliningCase(match.spec_types, item))
1216-
fully_covered = match.fully_covers
1246+
any_covers_full = handled_all_cases = match.fully_covers
12171247
end
12181248

1219-
handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo, state.params)
1249+
handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
12201250
end
12211251

12221252
# similar to `analyze_single_call!`, but with constant results
@@ -1227,7 +1257,8 @@ function handle_const_call!(
12271257
(; call, results) = cinfo
12281258
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
12291259
cases = InliningCase[]
1230-
local fully_covered = true
1260+
local handled_all_cases = true
1261+
local any_covers_full = false
12311262
local j = 0
12321263
for i in 1:length(infos)
12331264
meth = infos[i].results
@@ -1237,22 +1268,22 @@ function handle_const_call!(
12371268
return nothing
12381269
elseif length(meth) == 0
12391270
# No applicable methods; try next union split
1240-
fully_covered = false
1271+
handled_all_cases = false
12411272
continue
12421273
end
12431274
for match in meth
12441275
j += 1
12451276
result = results[j]
1277+
any_covers_full |= match.fully_covers
12461278
if isa(result, ConstResult)
12471279
case = const_result_item(result, state)
12481280
push!(cases, InliningCase(result.mi.specTypes, case))
12491281
elseif isa(result, InferenceResult)
1250-
fully_covered &= handle_inf_result!(result, argtypes, flag, state, cases)
1282+
handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases)
12511283
else
12521284
@assert result === nothing
1253-
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
1285+
handled_all_cases &= isdispatchtuple(match.spec_types) && handle_match!(match, argtypes, flag, state, cases)
12541286
end
1255-
fully_covered &= match.fully_covers
12561287
end
12571288
end
12581289

@@ -1265,17 +1296,16 @@ function handle_const_call!(
12651296
validate_sparams(mi.sparam_vals) || return nothing
12661297
item === nothing && return nothing
12671298
push!(cases, InliningCase(mi.specTypes, item))
1268-
fully_covered = atype <: mi.specTypes
1299+
any_covers_full = handled_all_cases = atype <: mi.specTypes
12691300
end
12701301

1271-
handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo, state.params)
1302+
handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
12721303
end
12731304

12741305
function handle_match!(
12751306
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
12761307
cases::Vector{InliningCase})
12771308
spec_types = match.spec_types
1278-
isdispatchtuple(spec_types) || return false
12791309
item = analyze_method!(match, argtypes, flag, state)
12801310
item === nothing && return false
12811311
_any(case->case.sig === spec_types, cases) && return true

base/compiler/typelattice.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,15 +314,17 @@ end
314314
@inline tchanged(@nospecialize(n), @nospecialize(o)) = o === NOT_FOUND || (n !== NOT_FOUND && !(n o))
315315
@inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n::VarState, o::VarState)))
316316

317-
widenconditional(@nospecialize typ) = typ
318-
function widenconditional(typ::AnyConditional)
319-
if typ.vtype === Union{}
320-
return Const(false)
321-
elseif typ.elsetype === Union{}
322-
return Const(true)
323-
else
324-
return Bool
317+
function widenconditional(@nospecialize typ)
318+
if isa(typ, AnyConditional)
319+
if typ.vtype === Union{}
320+
return Const(false)
321+
elseif typ.elsetype === Union{}
322+
return Const(true)
323+
else
324+
return Bool
325+
end
325326
end
327+
return typ
326328
end
327329
widenconditional(t::LimitedAccuracy) = error("unhandled LimitedAccuracy")
328330

test/compiler/inline.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,3 +1099,12 @@ end
10991099
let src = code_typed1(f44200)
11001100
@test count(x -> isa(x, Core.PiNode), src.code) == 0
11011101
end
1102+
1103+
# Test that peeling off one case from (::Any) doesn't introduce
1104+
# a dynamic dispatch.
1105+
@noinline f_peel(x::Int) = Base.inferencebarrier(1)
1106+
@noinline f_peel(@nospecialize(x::Any)) = Base.inferencebarrier(2)
1107+
g_call_peel(x) = f_peel(x)
1108+
let src = code_typed1(g_call_peel, Tuple{Any})
1109+
@test count(isinvoke(:f_peel), src.code) == 2
1110+
end

test/worlds.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ f_gen265(x::Type{Int}) = 3
191191
# intermediate worlds by later additions to the method table that
192192
# would have capped those specializations if they were still valid
193193
f26506(@nospecialize(x)) = 1
194-
g26506(x) = f26506(x[1])
194+
g26506(x) = Base.inferencebarrier(f26506)(x[1])
195195
z = Any["ABC"]
196196
f26506(x::Int) = 2
197197
g26506(z) # Places an entry for f26506(::String) in mt.name.cache

0 commit comments

Comments
 (0)