-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NewOptimizer] Better handling in the presence of select value #26969
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
The benchmarks contain code like this: ``` x::Union{Nothing, Int} result += ifelse(x === nothing, 0, x) ``` which, perhaps somewhat ironically is quite a bit harder on the new optimizer than an equivalent code sequence using ternary operators. The reason for this is that ifelse gets inferred as `Union{Int, Nothing}`, creating a phi node of that type, which then causes a union split + that the optimizer can't really get rid of easily. What this commit does is add some local improvements to help with the situation. First, it adds some minimal back inference during inlining. As a result, when inlining decides to unionsplit `ifelse(x === nothing, 0, x::Union{Nothing, Int})`, it looks back at the definition of `x === nothing`, realizes it's constrained by the union split and inserts the appropriate boolean constant. Next, a new `type_tightening_pass` goes back and annotates more precise types for the inlinined `select_value` and phi nodes. This is sufficient to get the above code to behave reasonably and should hopefully fix the performance regression on the various union sum benchmarks seen in #26795.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -346,6 +346,61 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector | |
return_value | ||
end | ||
|
||
# Constraints are generally small, so a linear search is the bets option | ||
function find_constraint(val, constraints) | ||
for i = 1:length(constraints) | ||
if val === constraints[i][1] | ||
return constraints[i][2] | ||
end | ||
end | ||
return nothing | ||
end | ||
|
||
# Performs minimal backwards inference to catch a couple of interesting, common cases | ||
function minimal_backinf(compact, constraints, unconstrained_types, argexprs) | ||
for i = 2:length(argexprs) | ||
isa(argexprs[i], SSAValue) || continue | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lift |
||
# Check if the argexpr is in the constraint list directly | ||
c = find_constraint(argexprs[i], constraints) | ||
if c !== nothing | ||
unconstrained_types[i] = c | ||
end | ||
# For boolean values check for type predicates on any of the constraints | ||
ut = unconstrained_types[i] | ||
if ut === Bool | ||
def = compact[argexprs[i]] | ||
isa(def, Expr) || continue | ||
if is_known_call(def, ===, compact) | ||
v1, v2, = def.args[2:3] | ||
c = find_constraint(v1, constraints) | ||
if c !== nothing | ||
refined = egal_tfunc(c, compact_exprtype(compact, v2)) | ||
if !(ut ⊑ refined) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. might be better to preserve the inference type, unless we've proven that we have a narrower bound (instead of checking if we're certain we don't have a wider bound): |
||
unconstrained_types[i] = refined | ||
end | ||
end | ||
c = find_constraint(v2, constraints) | ||
if c !== nothing | ||
refined = egal_tfunc(compact_exprtype(compact, v1), c) | ||
if !(ut ⊑ refined) | ||
unconstrained_types[i] = refined | ||
end | ||
end | ||
elseif is_known_call(def, isa, compact) | ||
v = def.args[2] | ||
c = find_constraint(v, constraints) | ||
if c !== nothing | ||
refined = isa_tfunc(c, compact_exprtype(compact, def.args[3])) | ||
if !(ut ⊑ refined) | ||
unconstrained_types[i] = refined | ||
end | ||
end | ||
end | ||
end | ||
end | ||
unconstrained_types | ||
end | ||
|
||
function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, | ||
argexprs::Vector{Any}, linetable::Vector{LineInfoNode}, | ||
item::UnionSplit, boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}}) | ||
|
@@ -379,11 +434,40 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, | |
insert_node_here!(compact, GotoIfNot(cond, next_cond_bb), Union{}, line) | ||
bb = next_cond_bb - 1 | ||
finish_current_bb!(compact) | ||
# Insert Pi nodes here | ||
if !isa(case, ConstantCase) | ||
argexprs′ = copy(argexprs) | ||
constraints = Pair{SSAValue, Any}[] | ||
unconstrained_types = Any[atype.parameters...] | ||
for i = 2:length(metharg.parameters) | ||
a, m = unconstrained_types[i], metharg.parameters[i] | ||
isa(argexprs[i], SSAValue) || continue | ||
if !(a <: m) | ||
push!(constraints, Pair{SSAValue, Any}(argexprs[i], m)) | ||
end | ||
end | ||
constrained_types = minimal_backinf(compact, constraints, unconstrained_types, argexprs) | ||
for i = 2:length(metharg.parameters) | ||
if !(atype.parameters[i] ⊑ constrained_types[i]) | ||
if isa(constrained_types[i], Const) | ||
argexprs′[i] = constrained_types[i].val | ||
else | ||
ct = widenconst(constrained_types[i]) | ||
if isa(ct, DataType) && isdefined(ct, :instance) | ||
argexprs′[i] = ct.instance | ||
else | ||
argexprs′[i] = insert_node_here!(compact, PiNode(argexprs′[i], constrained_types[i]), | ||
constrained_types[i], line) | ||
end | ||
end | ||
end | ||
end | ||
else | ||
argexprs′ = argexprs | ||
end | ||
if isa(case, InliningTodo) | ||
val = ir_inline_item!(compact, idx, argexprs, linetable, case, boundscheck, todo_bbs) | ||
val = ir_inline_item!(compact, idx, argexprs′, linetable, case, boundscheck, todo_bbs) | ||
elseif isa(case, MethodInstance) | ||
val = insert_node_here!(compact, Expr(:invoke, case, argexprs...), typ, line) | ||
val = insert_node_here!(compact, Expr(:invoke, case, argexprs′...), typ, line) | ||
else | ||
case = case::ConstantCase | ||
val = case.val | ||
|
@@ -854,6 +938,7 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv:: | |
# Now, if profitable union split the atypes into dispatch tuples and match the appropriate method | ||
nu = countunionsplit(atypes) | ||
if nu != 1 && nu <= sv.params.MAX_UNION_SPLITTING | ||
fully_covered = true | ||
for sig in UnionSplitSignature(atypes) | ||
metharg′ = argtypes_to_type(sig) | ||
if !isdispatchtuple(metharg′) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add type annotations