Skip to content

Commit 1c6bf3b

Browse files
committed
Remove union penalties for inlining cost
I added this code back in #27057, when I first made Union-full signatures inlineable. The justification was to try to encourage the union splitting to happen on the outside. However (and I believe this changed since this code was introduced), these days inference is in complete control of union splitting and we do not take inlineability or non-inlineability of the non-unionsplit function into account when deciding how to inline. As a result, the only effect of the union split penalties was to prevent inlining of functions that are not union-split eligible (e.g. `+(::Vararg{Union{Int, Missing}, 3})`), but are nevertheless cheap by our inlining metric. There is really no reason not to try to inline such functions, so delete this logic.
1 parent 23c0418 commit 1c6bf3b

File tree

2 files changed

+11
-27
lines changed

2 files changed

+11
-27
lines changed

base/compiler/optimize.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -400,18 +400,9 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
400400
opt.ir = ir
401401

402402
# determine and cache inlineability
403-
union_penalties = false
404403
if !force_noinline
405404
sig = unwrap_unionall(specTypes)
406-
if isa(sig, DataType) && sig.name === Tuple.name
407-
for P in sig.parameters
408-
P = unwrap_unionall(P)
409-
if isa(P, Union)
410-
union_penalties = true
411-
break
412-
end
413-
end
414-
else
405+
if !(isa(sig, DataType) && sig.name === Tuple.name)
415406
force_noinline = true
416407
end
417408
if !is_declared_inline(src) && result === Bottom
@@ -442,7 +433,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
442433
cost_threshold += 4*default
443434
end
444435
end
445-
src.inlining_cost = inline_cost(ir, params, union_penalties, cost_threshold)
436+
src.inlining_cost = inline_cost(ir, params, cost_threshold)
446437
end
447438
end
448439
return nothing
@@ -639,7 +630,7 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
639630
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))
640631

641632
function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState},
642-
union_penalties::Bool, params::OptimizationParams, error_path::Bool = false)
633+
params::OptimizationParams, error_path::Bool = false)
643634
head = ex.head
644635
if is_meta_expr_head(head)
645636
return 0
@@ -677,13 +668,6 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
677668
return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
678669
elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes)))
679670
return 1
680-
elseif f === Core.isa
681-
# If we're in a union context, we penalize type computations
682-
# on union types. In such cases, it is usually better to perform
683-
# union splitting on the outside.
684-
if union_penalties && isa(argextype(ex.args[2], src, sptypes), Union)
685-
return params.inline_nonleaf_penalty
686-
end
687671
end
688672
fidx = find_tfunc(f)
689673
if fidx === nothing
@@ -714,7 +698,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
714698
end
715699
a = ex.args[2]
716700
if a isa Expr
717-
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, union_penalties, params, error_path))
701+
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, params, error_path))
718702
end
719703
return cost
720704
elseif head === :copyast
@@ -730,11 +714,11 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
730714
end
731715

732716
function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState},
733-
union_penalties::Bool, params::OptimizationParams)
717+
params::OptimizationParams)
734718
thiscost = 0
735719
dst(tgt) = isa(src, IRCode) ? first(src.cfg.blocks[tgt].stmts) : tgt
736720
if stmt isa Expr
737-
thiscost = statement_cost(stmt, line, src, sptypes, union_penalties, params,
721+
thiscost = statement_cost(stmt, line, src, sptypes, params,
738722
is_stmt_throw_block(isa(src, IRCode) ? src.stmts.flag[line] : src.ssaflags[line]))::Int
739723
elseif stmt isa GotoNode
740724
# loops are generally always expensive
@@ -747,24 +731,24 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{Cod
747731
return thiscost
748732
end
749733

750-
function inline_cost(ir::IRCode, params::OptimizationParams, union_penalties::Bool=false,
734+
function inline_cost(ir::IRCode, params::OptimizationParams,
751735
cost_threshold::Integer=params.inline_cost_threshold)::InlineCostType
752736
bodycost::Int = 0
753737
for line = 1:length(ir.stmts)
754738
stmt = ir.stmts[line][:inst]
755-
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, union_penalties, params)
739+
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, params)
756740
bodycost = plus_saturate(bodycost, thiscost)
757741
bodycost > cost_threshold && return MAX_INLINE_COST
758742
end
759743
return inline_cost_clamp(bodycost)
760744
end
761745

762-
function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState}, unionpenalties::Bool, params::OptimizationParams)
746+
function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState}, params::OptimizationParams)
763747
maxcost = 0
764748
for line = 1:length(body)
765749
stmt = body[line]
766750
thiscost = statement_or_branch_cost(stmt, line, src, sptypes,
767-
unionpenalties, params)
751+
params)
768752
cost[line] = thiscost
769753
if thiscost > maxcost
770754
maxcost = thiscost

base/reflection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1669,7 +1669,7 @@ function print_statement_costs(io::IO, @nospecialize(tt::Type);
16691669
empty!(cst)
16701670
resize!(cst, length(code.code))
16711671
sptypes = Core.Compiler.VarState[Core.Compiler.VarState(sp, false) for sp in match.sparams]
1672-
maxcost = Core.Compiler.statement_costs!(cst, code.code, code, sptypes, false, params)
1672+
maxcost = Core.Compiler.statement_costs!(cst, code.code, code, sptypes, params)
16731673
nd = ndigits(maxcost)
16741674
irshow_config = IRShow.IRShowConfig() do io, linestart, idx
16751675
print(io, idx > 0 ? lpad(cst[idx], nd+1) : " "^(nd+1), " ")

0 commit comments

Comments
 (0)