Skip to content

optimizer: improve general code quality #42357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,7 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}=
compact.result[old_result_idx][:inst]), (compact.idx, active_bb)
end

function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing)
function maybe_erase_unused!(extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int, callback = x::SSAValue->nothing)
stmt = compact.result[idx][:inst]
stmt === nothing && return false
if compact_exprtype(compact, SSAValue(idx)) === Bottom
Expand Down
90 changes: 59 additions & 31 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
end
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), pi_callback=(pi, idx)->false)
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
while true
if isa(defssa, OldSSAValue)
if already_inserted(compact, defssa)
Expand All @@ -124,7 +125,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
def = compact[defssa]
if isa(def, PiNode)
if pi_callback(def, defssa)
if callback(def, defssa)
return defssa
end
def = def.val
Expand All @@ -135,7 +136,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
defssa = def
elseif isa(def, AnySSAValue)
pi_callback(def, defssa)
callback(def, defssa)
if isa(def, SSAValue)
is_old(compact, defssa) && (def = OldSSAValue(def.id))
end
Expand All @@ -148,12 +149,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defidx), @nospecialize(typeconstraint) = types(compact)[defidx])
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@nospecialize(typeconstraint) = types(compact)[defssa])
callback = function (@nospecialize(pi), @nospecialize(idx))
isa(pi, PiNode) && (typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ)))
if isa(pi, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
end
return false
end
def = simple_walk(compact, defidx, callback)
def = simple_walk(compact, defssa, callback)
return Pair{Any, Any}(def, typeconstraint)
end

Expand Down Expand Up @@ -273,8 +277,10 @@ function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact)
return oc ⊑ Core.OpaqueClosure
end

function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
@nospecialize(result_t), field::Int, leaves::Vector{Any})
# try to compute lifted values that can replace `getfield(x, field)` call
# where `x` is an immutable struct that are defined at any of `leaves`
function lift_leaves(compact::IncrementalCompact,
@nospecialize(result_t), field::Int, leaves::Vector{Any})
# For every leaf, the lifted value
lifted_leaves = IdDict{Any, Any}()
maybe_undef = false
Expand Down Expand Up @@ -396,13 +402,13 @@ function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
elseif isa(leaf, Union{Argument, Expr})
return nothing
end
!ismutable(leaf) || return nothing
ismutable(leaf) && return nothing
isdefined(leaf, field) || return nothing
val = getfield(leaf, field)
is_inlineable_constant(val) || return nothing
lifted_leaves[leaf_key] = RefValue{Any}(quoted(val))
end
lifted_leaves, maybe_undef
return lifted_leaves, maybe_undef
end

make_MaybeUndef(@nospecialize(typ)) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ)
Expand All @@ -415,13 +421,11 @@ function lift_comparison!(compact::IncrementalCompact, idx::Int,
typeconstraint = widenconst(c2)
val = stmt.args[3]
else
cmp = c2
cmp = c2::Const
typeconstraint = widenconst(c1)
val = stmt.args[2]
end

is_type_only = isdefined(typeof(cmp), :instance)

if isa(val, Union{OldSSAValue, SSAValue})
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
end
Expand Down Expand Up @@ -497,7 +501,7 @@ function perform_lifting!(compact::IncrementalCompact,
if is_old(compact, old_node_ssa) && isa(val, SSAValue)
val = OldSSAValue(val.id)
end
if isa(val, Union{NewSSAValue, SSAValue, OldSSAValue})
if isa(val, AnySSAValue)
val = simple_walk(compact, val)
end
if val in keys(lifted_leaves)
Expand All @@ -508,11 +512,12 @@ function perform_lifting!(compact::IncrementalCompact,
continue
end
lifted_val = lifted_val.x
if isa(lifted_val, Union{NewSSAValue, SSAValue, OldSSAValue})
lifted_val = simple_walk(compact, lifted_val, (pi, idx)->true)
if isa(lifted_val, AnySSAValue)
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
lifted_val = simple_walk(compact, lifted_val, callback)
end
push!(new_node.values, lifted_val)
elseif isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) && val in keys(reverse_mapping)
elseif isa(val, AnySSAValue) && val in keys(reverse_mapping)
push!(new_node.edges, edge)
push!(new_node.values, lifted_phis[reverse_mapping[val]].ssa)
else
Expand All @@ -532,14 +537,31 @@ function perform_lifting!(compact::IncrementalCompact,

if stmt_val in keys(lifted_leaves)
stmt_val = lifted_leaves[stmt_val]
elseif isa(stmt_val, Union{NewSSAValue, SSAValue, OldSSAValue}) && stmt_val in keys(reverse_mapping)
elseif isa(stmt_val, AnySSAValue) && stmt_val in keys(reverse_mapping)
stmt_val = RefValue{Any}(lifted_phis[reverse_mapping[stmt_val]].ssa)
end

return stmt_val
end

assertion_counter = 0
"""
getfield_elim_pass!(ir::IRCode) -> newir::IRCode

`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization.

This pass is based on a local alias analysis that collects field information by def-use chain walking.
It looks for struct allocation sites ("definitions"), and `getfield` calls as well as
`:foreigncall`s that preserve the structs ("usages"). If "definitions" have enough information,
then this pass will replace corresponding usages with lifted values.
`mutable struct`s require additional cares and need to be handled separately from immutables.
For `mutable struct`s, `setfield!` calls account for "definitions" also, and the pass should
give up the lifting conservatively when there are any "intermediate usages" that may escape
the mutable struct (e.g. non-inlined generic function call that takes the mutable struct as
its argument).

In a case when all usages are fully eliminated, `struct` allocation may also be erased as
a result of dead code elimination.
"""
function getfield_elim_pass!(ir::IRCode)
compact = IncrementalCompact(ir)
insertions = Vector{Any}()
Expand All @@ -554,7 +576,6 @@ function getfield_elim_pass!(ir::IRCode)
result_t = compact_exprtype(compact, SSAValue(idx))
is_getfield = is_setfield = false
field_ordering = :unspecified
is_ccall = false
# Step 1: Check whether the statement we're looking at is a getfield/setfield!
if is_known_call(stmt, setfield!, compact)
is_setfield = true
Expand Down Expand Up @@ -610,8 +631,8 @@ function getfield_elim_pass!(ir::IRCode)
old_preserves = stmt.args[(6+nccallargs):end]
for (pidx, preserved_arg) in enumerate(old_preserves)
isa(preserved_arg, SSAValue) || continue
let intermediaries = IdSet()
callback = function(@nospecialize(pi), ssa::AnySSAValue)
let intermediaries = IdSet{Int}()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand Down Expand Up @@ -670,8 +691,8 @@ function getfield_elim_pass!(ir::IRCode)

if ismutabletype(struct_typ)
isa(def, SSAValue) || continue
let intermediaries = IdSet()
callback = function(@nospecialize(pi), ssa::AnySSAValue)
let intermediaries = IdSet{Int}()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand All @@ -691,6 +712,8 @@ function getfield_elim_pass!(ir::IRCode)
continue
end

# perform SROA on immutable structs here on

if isa(def, Union{OldSSAValue, SSAValue})
def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint)
end
Expand All @@ -703,7 +726,7 @@ function getfield_elim_pass!(ir::IRCode)
field = try_compute_fieldidx(struct_typ, field)
field === nothing && continue

r = lift_leaves(compact, stmt, result_t, field, leaves)
r = lift_leaves(compact, result_t, field, leaves)
r === nothing && continue
lifted_leaves, any_undef = r

Expand Down Expand Up @@ -736,14 +759,13 @@ function getfield_elim_pass!(ir::IRCode)
@assert val !== nothing
end

global assertion_counter
assertion_counter::Int += 1
# global assertion_counter
# assertion_counter::Int += 1
Comment on lines +762 to +763
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙀

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need another optimizer that optimizes our optimizer :p

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit seriously, we may want to put more efforts on integrating JET with Julia base.

#insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
#continue
compact[idx] = val === nothing ? nothing : val.x
end


non_dce_finish!(compact)
# Copy the use count, `simple_dce!` may modify it and for our predicate
# below we need it consistent with the state of the IR here (after tracking
Expand Down Expand Up @@ -874,11 +896,12 @@ function getfield_elim_pass!(ir::IRCode)
end
ir
end
# assertion_counter = 0

function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int)
# return whether this made a change
if isa(compact.result[idx][:inst], PhiNode)
return maybe_erase_unused!(extra_worklist, compact, idx, val -> phi_uses[val.id] -= 1)
return maybe_erase_unused!(extra_worklist, compact, idx, val::SSAValue -> phi_uses[val.id] -= 1)
else
return maybe_erase_unused!(extra_worklist, compact, idx)
end
Expand All @@ -893,7 +916,7 @@ function count_uses(@nospecialize(stmt), uses::Vector{Int})
end
end

function mark_phi_cycles(compact::IncrementalCompact, safe_phis::BitSet, phi::Int)
function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::BitSet, phi::Int)
worklist = Int[]
push!(worklist, phi)
while !isempty(worklist)
Expand All @@ -909,6 +932,11 @@ function mark_phi_cycles(compact::IncrementalCompact, safe_phis::BitSet, phi::In
end
end

"""
adce_pass!(ir::IRCode) -> newir::IRCode

Aggressive Dead Code Elimination pass to eliminate code.
"""
function adce_pass!(ir::IRCode)
phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes))
all_phis = Int[]
Expand Down Expand Up @@ -940,7 +968,7 @@ function adce_pass!(ir::IRCode)
for phi in all_phis
# Save any phi cycles that have non-phi uses
if compact.used_ssas[phi] - phi_uses[phi] != 0
mark_phi_cycles(compact, safe_phis, phi)
mark_phi_cycles!(compact, safe_phis, phi)
end
end
for phi in all_phis
Expand Down