Skip to content

Make SROA pass more aggressive #44494

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 92 additions & 55 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ end
function block_for_inst(index::Vector{Int}, inst::Int)
return searchsortedfirst(index, inst, lt=(<=))
end

function block_for_inst(index::Vector{BasicBlock}, inst::Int)
return searchsortedfirst(index, BasicBlock(StmtRange(inst, inst)), by=x->first(x.stmts), lt=(<=))-1
end

block_for_inst(cfg::CFG, inst::Int) = block_for_inst(cfg.index, inst)

function basic_blocks_starts(stmts::Vector{Any})
Expand Down Expand Up @@ -553,6 +558,7 @@ mutable struct IncrementalCompact
new_nodes_idx::Int
# This supports insertion while compacting
new_new_nodes::NewNodeStream # New nodes that were before the compaction point at insertion time
new_new_used_ssas::Vector{Int}
# TODO: Switch these two to a min-heap of some sort
pending_nodes::NewNodeStream # New nodes that were after the compaction point at insertion time
pending_perm::Vector{Int}
Expand All @@ -573,6 +579,7 @@ mutable struct IncrementalCompact
new_len = length(code.stmts) + length(code.new_nodes)
result = InstructionStream(new_len)
used_ssas = fill(0, new_len)
new_new_used_ssas = Vector{Int}()
blocks = code.cfg.blocks
if allow_cfg_transforms
bb_rename = Vector{Int}(undef, length(blocks))
Expand Down Expand Up @@ -615,7 +622,7 @@ mutable struct IncrementalCompact
pending_nodes = NewNodeStream()
pending_perm = Int[]
return new(code, result, result_bbs, ssa_rename, bb_rename, bb_rename, used_ssas, late_fixup, perm, 1,
new_new_nodes, pending_nodes, pending_perm,
new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm,
1, 1, 1, false, allow_cfg_transforms, allow_cfg_transforms)
end

Expand All @@ -625,6 +632,7 @@ mutable struct IncrementalCompact
new_len = length(code.stmts) + length(code.new_nodes)
ssa_rename = Any[SSAValue(i) for i = 1:new_len]
used_ssas = fill(0, new_len)
new_new_used_ssas = Vector{Int}()
late_fixup = Vector{Int}()
bb_rename = Vector{Int}()
new_new_nodes = NewNodeStream()
Expand All @@ -633,7 +641,7 @@ mutable struct IncrementalCompact
return new(code, parent.result,
parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas,
late_fixup, perm, 1,
new_new_nodes, pending_nodes, pending_perm,
new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm,
1, result_offset, parent.active_result_bb, false, false, false)
end
end
Expand Down Expand Up @@ -673,16 +681,27 @@ function getindex(compact::IncrementalCompact, ssa::NewSSAValue)
return compact.new_new_nodes.stmts[ssa.id][:inst]
end

function block_for_inst(compact::IncrementalCompact, idx::SSAValue)
if idx.id < compact.result_idx
return block_for_inst(compact.result_bbs, idx.id)
else
return block_for_inst(compact.ir.cfg, idx.idx)
end
end

function count_added_node!(compact::IncrementalCompact, @nospecialize(v))
needs_late_fixup = isa(v, NewSSAValue)
if isa(v, SSAValue)
compact.used_ssas[v.id] += 1
elseif isa(v, NewSSAValue)
compact.new_new_used_ssas[v.id] += 1
else
for ops in userefs(v)
val = ops[]
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
elseif isa(val, NewSSAValue)
compact.new_new_used_ssas[val.id] += 1
needs_late_fixup = true
end
end
Expand All @@ -705,6 +724,7 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction,
count_added_node!(compact, inst.stmt)
line = something(inst.line, compact.result[before.id][:line])
node = add!(compact.new_new_nodes, before.id, attach_after)
push!(compact.new_new_used_ssas, 0)
node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag
return NewSSAValue(node.idx)
else
Expand All @@ -723,6 +743,7 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction,
count_added_node!(compact, inst.stmt)
line = something(inst.line, compact.result[renamed.id][:line])
node = add!(compact.new_new_nodes, renamed.id, attach_after)
push!(compact.new_new_used_ssas, 0)
node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag
return NewSSAValue(node.idx)
else
Expand All @@ -744,6 +765,7 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction,
line = something(inst.line, compact.new_new_nodes.stmts[before.id][:line])
new_entry = add!(compact.new_new_nodes, before_entry.pos, attach_after)
new_entry[:inst], new_entry[:type], new_entry[:line], new_entry[:flag] = inst.stmt, inst.type, line, inst.flag
push!(compact.new_new_used_ssas, 0)
return NewSSAValue(new_entry.idx)
else
error("Unsupported")
Expand Down Expand Up @@ -803,6 +825,9 @@ function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue)
if isa(val, SSAValue)
@assert compact.used_ssas[val.id] >= 1
compact.used_ssas[val.id] -= 1
elseif isa(val, NewSSAValue)
@assert compact.new_new_used_ssas[val.id] >= 1
compact.new_new_used_ssas[val.id] -= 1
end
end
compact.result[idx.id][:inst] = v
Expand Down Expand Up @@ -853,6 +878,7 @@ end
function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int},
processed_idx::Int, result_idx::Int,
ssa_rename::Vector{Any}, used_ssas::Vector{Int},
new_new_used_ssas::Vector{Int},
do_rename_ssa::Bool)
values = Vector{Any}(undef, length(old_values))
for i = 1:length(old_values)
Expand All @@ -864,7 +890,7 @@ function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int}
push!(late_fixup, result_idx)
val = OldSSAValue(val.id)
else
val = renumber_ssa2(val, ssa_rename, used_ssas, do_rename_ssa)
val = renumber_ssa2(val, ssa_rename, used_ssas, new_new_used_ssas, do_rename_ssa)
end
else
used_ssas[val.id] += 1
Expand All @@ -874,17 +900,19 @@ function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int}
push!(late_fixup, result_idx)
else
# Always renumber these. do_rename_ssa applies only to actual SSAValues
val = renumber_ssa2(SSAValue(val.id), ssa_rename, used_ssas, true)
val = renumber_ssa2(SSAValue(val.id), ssa_rename, used_ssas, new_new_used_ssas, true)
end
elseif isa(val, NewSSAValue)
push!(late_fixup, result_idx)
new_new_used_ssas[val.id] += 1
end
values[i] = val
end
return values
end

function renumber_ssa2(val::SSAValue, ssanums::Vector{Any}, used_ssas::Vector{Int}, do_rename_ssa::Bool)
function renumber_ssa2(val::SSAValue, ssanums::Vector{Any}, used_ssas::Vector{Int},
new_new_used_ssas::Vector{Int}, do_rename_ssa::Bool)
id = val.id
if id > length(ssanums)
return val
Expand All @@ -893,22 +921,26 @@ function renumber_ssa2(val::SSAValue, ssanums::Vector{Any}, used_ssas::Vector{In
val = ssanums[id]
end
if isa(val, SSAValue)
if used_ssas !== nothing
used_ssas[val.id] += 1
end
used_ssas[val.id] += 1
end
return val
end

function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssas::Vector{Int}, late_fixup::Vector{Int}, result_idx::Int, do_rename_ssa::Bool)
function renumber_ssa2(val::NewSSAValue, ssanums::Vector{Any}, used_ssas::Vector{Int},
new_new_used_ssas::Vector{Int}, do_rename_ssa::Bool)
new_new_used_ssas[val.id] += 1
return val
end

function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssas::Vector{Int}, new_new_used_ssas::Vector{Int}, late_fixup::Vector{Int}, result_idx::Int, do_rename_ssa::Bool)
urs = userefs(stmt)
for op in urs
val = op[]
if isa(val, OldSSAValue) || isa(val, NewSSAValue)
push!(late_fixup, result_idx)
end
if isa(val, SSAValue)
val = renumber_ssa2(val, ssanums, used_ssas, do_rename_ssa)
if isa(val, Union{SSAValue, NewSSAValue})
val = renumber_ssa2(val, ssanums, used_ssas, new_new_used_ssas, do_rename_ssa)
end
if isa(val, OldSSAValue) || isa(val, NewSSAValue)
push!(late_fixup, result_idx)
Expand Down Expand Up @@ -992,6 +1024,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
ssa_rename = compact.ssa_rename
late_fixup = compact.late_fixup
used_ssas = compact.used_ssas
new_new_used_ssas = compact.new_new_used_ssas
ssa_rename[idx] = SSAValue(result_idx)
if stmt === nothing
ssa_rename[idx] = stmt
Expand All @@ -1008,7 +1041,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
result[result_idx][:inst] = stmt
result_idx += 1
elseif isa(stmt, GotoIfNot) && compact.cfg_transforms_enabled
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::GotoIfNot
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa)::GotoIfNot
result[result_idx][:inst] = stmt
cond = stmt.cond
if compact.fold_constant_branches
Expand All @@ -1033,7 +1066,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
result_idx += 1
end
elseif isa(stmt, Expr)
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::Expr
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa)::Expr
if compact.cfg_transforms_enabled && isexpr(stmt, :enter)
stmt.args[1] = compact.bb_rename_succ[stmt.args[1]::Int]
end
Expand All @@ -1043,7 +1076,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
# As an optimization, we eliminate any trivial pinodes. For performance, we use ===
# type equality. We may want to consider using == in either a separate pass or if
# performance turns out ok
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::PiNode
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa)::PiNode
pi_val = stmt.val
if isa(pi_val, SSAValue)
if stmt.typ === compact.result[pi_val.id][:type]
Expand All @@ -1065,7 +1098,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
result[result_idx][:inst] = stmt
result_idx += 1
elseif isa(stmt, ReturnNode) || isa(stmt, UpsilonNode) || isa(stmt, GotoIfNot)
result[result_idx][:inst] = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)
result[result_idx][:inst] = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa)
result_idx += 1
elseif isa(stmt, PhiNode)
if compact.cfg_transforms_enabled
Expand Down Expand Up @@ -1102,7 +1135,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
values = stmt.values
end

values = process_phinode_values(values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa)
values = process_phinode_values(values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, new_new_used_ssas, do_rename_ssa)
# Don't remove the phi node if it is before the definition of its value
# because doing so can create forward references. This should only
# happen with dead loops, but can cause problems when optimization
Expand All @@ -1121,7 +1154,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
result_idx += 1
end
elseif isa(stmt, PhiCNode)
result[result_idx][:inst] = PhiCNode(process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa))
result[result_idx][:inst] = PhiCNode(process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, new_new_used_ssas, do_rename_ssa))
result_idx += 1
elseif isa(stmt, SSAValue)
# identity assign, replace uses of this ssa value with its result
Expand Down Expand Up @@ -1323,31 +1356,39 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}=
end

function maybe_erase_unused!(
extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int,
extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int, in_worklist::Bool,
callback = null_dce_callback)
stmt = compact.result[idx][:inst]

inst = idx <= length(compact.result) ? compact.result[idx] :
compact.new_new_nodes.stmts[idx - length(compact.result)]
stmt = inst[:inst]
stmt === nothing && return false
if argextype(SSAValue(idx), compact) === Bottom
if inst[:type] === Bottom
effect_free = false
else
effect_free = compact.result[idx][:flag] & IR_FLAG_EFFECT_FREE != 0
effect_free = inst[:flag] & IR_FLAG_EFFECT_FREE != 0
end
function kill_ssa_value(val::SSAValue)
if compact.used_ssas[val.id] == 1
if val.id < idx || in_worklist
push!(extra_worklist, val.id)
end
end
compact.used_ssas[val.id] -= 1
callback(val)
end
if effect_free
for ops in userefs(stmt)
val = ops[]
# If the pass we ran inserted new nodes, it's possible for those
# to be outside our used_ssas count.
if isa(val, SSAValue) && val.id <= length(compact.used_ssas)
if compact.used_ssas[val.id] == 1
if val.id < idx
push!(extra_worklist, val.id)
end
if isa(stmt, SSAValue)
kill_ssa_value(stmt)
else
for ops in userefs(stmt)
val = ops[]
if isa(val, SSAValue)
kill_ssa_value(val)
end
compact.used_ssas[val.id] -= 1
callback(val)
end
end
compact.result[idx][:inst] = nothing
inst[:inst] = nothing
return true
end
return false
Expand All @@ -1358,13 +1399,8 @@ function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{A
for i = 1:length(old_values)
isassigned(old_values, i) || continue
val = old_values[i]
if isa(val, OldSSAValue)
val = compact.ssa_rename[val.id]
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
end
elseif isa(val, NewSSAValue)
val = SSAValue(length(compact.result) + val.id)
if isa(val, Union{OldSSAValue, NewSSAValue})
val = fixup_node(compact, val)
end
values[i] = val
end
Expand All @@ -1379,29 +1415,30 @@ function fixup_node(compact::IncrementalCompact, @nospecialize(stmt))
elseif isa(stmt, NewSSAValue)
return SSAValue(length(compact.result) + stmt.id)
elseif isa(stmt, OldSSAValue)
return compact.ssa_rename[stmt.id]
val = compact.ssa_rename[stmt.id]
if isa(val, SSAValue)
# If `val.id` is greater than the length of `compact.result` or
# `compact.used_ssas`, this SSA value is in `new_new_nodes`, so
# don't count the use
compact.used_ssas[val.id] += 1
end
return val
else
urs = userefs(stmt)
for ur in urs
val = ur[]
if isa(val, NewSSAValue)
val = SSAValue(length(compact.result) + val.id)
elseif isa(val, OldSSAValue)
val = compact.ssa_rename[val.id]
end
if isa(val, SSAValue) && val.id <= length(compact.used_ssas)
# If `val.id` is greater than the length of `compact.result` or
# `compact.used_ssas`, this SSA value is in `new_new_nodes`, so
# don't count the use
compact.used_ssas[val.id] += 1
if isa(val, Union{NewSSAValue, OldSSAValue})
ur[] = fixup_node(compact, val)
end
ur[] = val
end
return urs[]
end
end

function just_fixup!(compact::IncrementalCompact)
resize!(compact.used_ssas, length(compact.result))
append!(compact.used_ssas, compact.new_new_used_ssas)
empty!(compact.new_new_used_ssas)
for idx in compact.late_fixup
stmt = compact.result[idx][:inst]
new_stmt = fixup_node(compact, stmt)
Expand All @@ -1419,14 +1456,14 @@ end

function simple_dce!(compact::IncrementalCompact, callback = null_dce_callback)
# Perform simple DCE for unused values
@assert isempty(compact.new_new_used_ssas) # just_fixup! wasn't run?
extra_worklist = Int[]
for (idx, nused) in Iterators.enumerate(compact.used_ssas)
idx >= compact.result_idx && break
nused == 0 || continue
maybe_erase_unused!(extra_worklist, compact, idx, callback)
maybe_erase_unused!(extra_worklist, compact, idx, false, callback)
end
while !isempty(extra_worklist)
maybe_erase_unused!(extra_worklist, compact, pop!(extra_worklist), callback)
maybe_erase_unused!(extra_worklist, compact, pop!(extra_worklist), true, callback)
end
end

Expand Down
Loading