Skip to content

Commit 21bb0c7

Browse files
authored
Remove union penalties for inlining cost (#50429)
I added this code back in #27057, when I first made Union-full signatures inlineable. The justification was to try to encourage the union splitting to happen on the outside. However (and I believe this changed since this code was introduced), these days inference is in complete control of union splitting and we do not take inlineability or non-inlineability of the non-unionsplit function into account when deciding how to inline. As a result, the only effect of the union split penalties was to prevent inlining of functions that are not union-split eligible (e.g. `+(::Vararg{Union{Int, Missing}, 3})`), but are nevertheless cheap by our inlining metric. There is really no reason not to try to inline such functions, so delete this logic.
1 parent 0718995 commit 21bb0c7

File tree

3 files changed

+15
-31
lines changed

3 files changed

+15
-31
lines changed

base/compiler/optimize.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -406,18 +406,9 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
406406
opt.ir = ir
407407

408408
# determine and cache inlineability
409-
union_penalties = false
410409
if !force_noinline
411410
sig = unwrap_unionall(specTypes)
412-
if isa(sig, DataType) && sig.name === Tuple.name
413-
for P in sig.parameters
414-
P = unwrap_unionall(P)
415-
if isa(P, Union)
416-
union_penalties = true
417-
break
418-
end
419-
end
420-
else
411+
if !(isa(sig, DataType) && sig.name === Tuple.name)
421412
force_noinline = true
422413
end
423414
if !is_declared_inline(src) && result === Bottom
@@ -448,7 +439,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
448439
cost_threshold += 4*default
449440
end
450441
end
451-
src.inlining_cost = inline_cost(ir, params, union_penalties, cost_threshold)
442+
src.inlining_cost = inline_cost(ir, params, cost_threshold)
452443
end
453444
end
454445
return nothing
@@ -645,7 +636,7 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
645636
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))
646637

647638
function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState},
648-
union_penalties::Bool, params::OptimizationParams, error_path::Bool = false)
639+
params::OptimizationParams, error_path::Bool = false)
649640
head = ex.head
650641
if is_meta_expr_head(head)
651642
return 0
@@ -683,13 +674,6 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
683674
return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
684675
elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes)))
685676
return 1
686-
elseif f === Core.isa
687-
# If we're in a union context, we penalize type computations
688-
# on union types. In such cases, it is usually better to perform
689-
# union splitting on the outside.
690-
if union_penalties && isa(argextype(ex.args[2], src, sptypes), Union)
691-
return params.inline_nonleaf_penalty
692-
end
693677
end
694678
fidx = find_tfunc(f)
695679
if fidx === nothing
@@ -720,7 +704,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
720704
end
721705
a = ex.args[2]
722706
if a isa Expr
723-
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, union_penalties, params, error_path))
707+
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, params, error_path))
724708
end
725709
return cost
726710
elseif head === :copyast
@@ -736,11 +720,11 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
736720
end
737721

738722
function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState},
739-
union_penalties::Bool, params::OptimizationParams)
723+
params::OptimizationParams)
740724
thiscost = 0
741725
dst(tgt) = isa(src, IRCode) ? first(src.cfg.blocks[tgt].stmts) : tgt
742726
if stmt isa Expr
743-
thiscost = statement_cost(stmt, line, src, sptypes, union_penalties, params,
727+
thiscost = statement_cost(stmt, line, src, sptypes, params,
744728
is_stmt_throw_block(isa(src, IRCode) ? src.stmts.flag[line] : src.ssaflags[line]))::Int
745729
elseif stmt isa GotoNode
746730
# loops are generally always expensive
@@ -753,24 +737,24 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{Cod
753737
return thiscost
754738
end
755739

756-
function inline_cost(ir::IRCode, params::OptimizationParams, union_penalties::Bool=false,
740+
function inline_cost(ir::IRCode, params::OptimizationParams,
757741
cost_threshold::Integer=params.inline_cost_threshold)::InlineCostType
758742
bodycost::Int = 0
759743
for line = 1:length(ir.stmts)
760744
stmt = ir.stmts[line][:inst]
761-
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, union_penalties, params)
745+
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, params)
762746
bodycost = plus_saturate(bodycost, thiscost)
763747
bodycost > cost_threshold && return MAX_INLINE_COST
764748
end
765749
return inline_cost_clamp(bodycost)
766750
end
767751

768-
function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState}, unionpenalties::Bool, params::OptimizationParams)
752+
function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState}, params::OptimizationParams)
769753
maxcost = 0
770754
for line = 1:length(body)
771755
stmt = body[line]
772756
thiscost = statement_or_branch_cost(stmt, line, src, sptypes,
773-
unionpenalties, params)
757+
params)
774758
cost[line] = thiscost
775759
if thiscost > maxcost
776760
maxcost = thiscost

base/reflection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1669,7 +1669,7 @@ function print_statement_costs(io::IO, @nospecialize(tt::Type);
16691669
empty!(cst)
16701670
resize!(cst, length(code.code))
16711671
sptypes = Core.Compiler.VarState[Core.Compiler.VarState(sp, false) for sp in match.sparams]
1672-
maxcost = Core.Compiler.statement_costs!(cst, code.code, code, sptypes, false, params)
1672+
maxcost = Core.Compiler.statement_costs!(cst, code.code, code, sptypes, params)
16731673
nd = ndigits(maxcost)
16741674
irshow_config = IRShow.IRShowConfig() do io, linestart, idx
16751675
print(io, idx > 0 ? lpad(cst[idx], nd+1) : " "^(nd+1), " ")

test/offsetarray.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -627,15 +627,15 @@ end
627627
B = OffsetArray(reshape(1:24, 4, 3, 2), -5, 6, -7)
628628
for R in (fill(0, -4:-1), fill(0, -4:-1, 7:7), fill(0, -4:-1, 7:7, -6:-6))
629629
@test @inferred(maximum!(R, B)) == reshape(maximum(B, dims=(2,3)), axes(R)) == reshape(21:24, axes(R))
630-
@test @allocated(maximum!(R, B)) <= 800
630+
@test @allocated(maximum!(R, B)) <= 1300
631631
@test @inferred(minimum!(R, B)) == reshape(minimum(B, dims=(2,3)), axes(R)) == reshape(1:4, axes(R))
632-
@test @allocated(minimum!(R, B)) <= 800
632+
@test @allocated(minimum!(R, B)) <= 1300
633633
end
634634
for R in (fill(0, -4:-4, 7:9), fill(0, -4:-4, 7:9, -6:-6))
635635
@test @inferred(maximum!(R, B)) == reshape(maximum(B, dims=(1,3)), axes(R)) == reshape(16:4:24, axes(R))
636-
@test @allocated(maximum!(R, B)) <= 800
636+
@test @allocated(maximum!(R, B)) <= 1300
637637
@test @inferred(minimum!(R, B)) == reshape(minimum(B, dims=(1,3)), axes(R)) == reshape(1:4:9, axes(R))
638-
@test @allocated(minimum!(R, B)) <= 800
638+
@test @allocated(minimum!(R, B)) <= 1300
639639
end
640640
@test_throws DimensionMismatch maximum!(fill(0, -4:-1, 7:7, -6:-6, 1:1), B)
641641
@test_throws DimensionMismatch minimum!(fill(0, -4:-1, 7:7, -6:-6, 1:1), B)

0 commit comments

Comments
 (0)