Skip to content

Refactor the meaning of NewSSAValue in IncrementalCompact #45610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 76 additions & 40 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,24 @@ struct OldSSAValue
id::Int
end

# SSA values that are in `new_new_nodes` of an `IncrementalCompact` and are to
# be actually inserted next time (they become `new_nodes` next time)
## TODO: This description currently omits the use of NewSSAValue during slot2ssa,
## which doesn't use IncrementalCompact, but does something similar and also uses
## NewSSAValue to refer to new_nodes. Ideally that use of NewSSAValue would go away
## during a refactor.
"""
struct NewSSAValue

`NewSSAValue`s occur in the context of IncrementalCompact. Their meaning depends
on where they appear:

1. In already-compacted nodes,
i. a `NewSSAValue` with positive `id` has the same meaning as a regular SSAValue.
ii. a `NewSSAValue` with negative `id` refers to post-compaction `new_node` node.

2. In non-compacted nodes,
i. a `NewSSAValue` with positive `id` refers to the index of an already-compacted instructions.
ii. a `NewSSAValue` with negative `id` has the same meaning as in compacted nodes.
"""
struct NewSSAValue
id::Int
end
Expand Down Expand Up @@ -618,38 +634,33 @@ struct TypesView{T}
end
types(ir::Union{IRCode, IncrementalCompact}) = TypesView(ir)

# TODO We can be a bit better about access here by using a pattern similar to InstructionStream
function getindex(compact::IncrementalCompact, idx::Int)
if idx < compact.result_idx
return compact.result[idx][:inst]
else
return compact.ir.stmts[idx][:inst]
end
end

function getindex(compact::IncrementalCompact, ssa::SSAValue)
@assert ssa.id < compact.result_idx
return compact.result[ssa.id][:inst]
return compact.result[ssa.id]
end

function getindex(compact::IncrementalCompact, ssa::OldSSAValue)
id = ssa.id
if id < compact.idx
new_idx = compact.ssa_rename[id]
return compact.result[new_idx][:inst]
return compact.result[new_idx]
elseif id <= length(compact.ir.stmts)
return compact.ir.stmts[id][:inst]
return compact.ir.stmts[id]
end
id -= length(compact.ir.stmts)
if id <= length(compact.ir.new_nodes)
return compact.ir.new_nodes.stmts[id][:inst]
return compact.ir.new_nodes.stmts[id]
end
id -= length(compact.ir.new_nodes)
return compact.pending_nodes.stmts[id][:inst]
return compact.pending_nodes.stmts[id]
end

function getindex(compact::IncrementalCompact, ssa::NewSSAValue)
return compact.new_new_nodes.stmts[ssa.id][:inst]
if ssa.id < 0
return compact.new_new_nodes.stmts[-ssa.id]
else
return compact[SSAValue(ssa.id)]
end
end

function block_for_inst(compact::IncrementalCompact, idx::SSAValue)
Expand All @@ -671,7 +682,12 @@ function block_for_inst(compact::IncrementalCompact, idx::OldSSAValue)
end

function block_for_inst(compact::IncrementalCompact, idx::NewSSAValue)
block_for_inst(compact, SSAValue(compact.new_new_nodes.info[idx.id].pos))
if idx.id > 0
@assert idx.id < compact.result_idx
return block_for_inst(compact, SSAValue(idx.id))
else
return block_for_inst(compact, SSAValue(compact.new_new_nodes.info[-idx.id].pos))
end
end

function dominates_ssa(compact::IncrementalCompact, domtree::DomTree, x::AnySSAValue, y::AnySSAValue)
Expand All @@ -682,16 +698,24 @@ function dominates_ssa(compact::IncrementalCompact, domtree::DomTree, x::AnySSAV
if isa(x, OldSSAValue)
x′ = compact.ssa_rename[x.id]::SSAValue
elseif isa(x, NewSSAValue)
xinfo = compact.new_new_nodes.info[x.id]
x′ = SSAValue(xinfo.pos)
if x.id > 0
x′ = SSAValue(x.id)
else
xinfo = compact.new_new_nodes.info[-x.id]
x′ = SSAValue(xinfo.pos)
end
else
x′ = x
end
if isa(y, OldSSAValue)
y′ = compact.ssa_rename[y.id]::SSAValue
elseif isa(y, NewSSAValue)
yinfo = compact.new_new_nodes.info[y.id]
y′ = SSAValue(yinfo.pos)
if y.id > 0
y′ = SSAValue(y.id)
else
yinfo = compact.new_new_nodes.info[-y.id]
y′ = SSAValue(yinfo.pos)
end
else
y′ = y
end
Expand Down Expand Up @@ -719,7 +743,8 @@ function count_added_node!(compact::IncrementalCompact, @nospecialize(v))
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
elseif isa(val, NewSSAValue)
compact.new_new_used_ssas[val.id] += 1
@assert val.id < 0 # Newly added nodes should be canonicalized
compact.new_new_used_ssas[-val.id] += 1
needs_late_fixup = true
end
end
Expand All @@ -743,7 +768,7 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction,
node = add!(compact.new_new_nodes, before.id, attach_after)
push!(compact.new_new_used_ssas, 0)
node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag
return NewSSAValue(node.idx)
return NewSSAValue(-node.idx)
else
line = something(inst.line, compact.ir.stmts[before.id][:line])
node = add_pending!(compact, before.id, attach_after)
Expand All @@ -762,7 +787,7 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction,
node = add!(compact.new_new_nodes, renamed.id, attach_after)
push!(compact.new_new_used_ssas, 0)
node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag
return NewSSAValue(node.idx)
return NewSSAValue(-node.idx)
else
if pos > length(compact.ir.stmts)
#@assert attach_after
Expand All @@ -778,12 +803,13 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction,
return os
end
elseif isa(before, NewSSAValue)
before_entry = compact.new_new_nodes.info[before.id]
line = something(inst.line, compact.new_new_nodes.stmts[before.id][:line])
# TODO: This is incorrect and does not maintain ordering among the new nodes
before_entry = compact.new_new_nodes.info[-before.id]
line = something(inst.line, compact.new_new_nodes.stmts[-before.id][:line])
new_entry = add!(compact.new_new_nodes, before_entry.pos, attach_after)
new_entry[:inst], new_entry[:type], new_entry[:line], new_entry[:flag] = inst.stmt, inst.type, line, inst.flag
push!(compact.new_new_used_ssas, 0)
return NewSSAValue(new_entry.idx)
return NewSSAValue(-new_entry.idx)
else
error("Unsupported")
end
Expand Down Expand Up @@ -838,8 +864,9 @@ function kill_current_uses(compact::IncrementalCompact, @nospecialize(stmt))
@assert compact.used_ssas[val.id] >= 1
compact.used_ssas[val.id] -= 1
elseif isa(val, NewSSAValue)
@assert compact.new_new_used_ssas[val.id] >= 1
compact.new_new_used_ssas[val.id] -= 1
@assert val.id < 0
@assert compact.new_new_used_ssas[-val.id] >= 1
compact.new_new_used_ssas[-val.id] -= 1
end
end
end
Expand Down Expand Up @@ -929,11 +956,7 @@ function getindex(view::TypesView, idx::Int)
end

function getindex(view::TypesView, idx::NewSSAValue)
if isa(view.ir, IncrementalCompact)
return view.ir.new_new_nodes.stmts[idx.id][:type]
else
return view.ir.new_nodes.stmts[idx.id][:type]
end
return view.ir[idx][:type]
end

function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int},
Expand Down Expand Up @@ -964,8 +987,13 @@ function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int}
val = renumber_ssa2(SSAValue(val.id), ssa_rename, used_ssas, new_new_used_ssas, true)
end
elseif isa(val, NewSSAValue)
push!(late_fixup, result_idx)
new_new_used_ssas[val.id] += 1
if val.id < 0
push!(late_fixup, result_idx)
new_new_used_ssas[-val.id] += 1
else
@assert do_rename_ssa
val = SSAValue(val.id)
end
end
values[i] = val
end
Expand All @@ -989,8 +1017,13 @@ end

function renumber_ssa2(val::NewSSAValue, ssanums::Vector{Any}, used_ssas::Vector{Int},
new_new_used_ssas::Vector{Int}, do_rename_ssa::Bool)
new_new_used_ssas[val.id] += 1
return val
if val.id < 0
new_new_used_ssas[-val.id] += 1
return val
else
used_ssas[val.id] += 1
return SSAValue(val.id)
end
end

function renumber_ssa2!(@nospecialize(stmt), ssanums::Vector{Any}, used_ssas::Vector{Int}, new_new_used_ssas::Vector{Int}, late_fixup::Vector{Int}, result_idx::Int, do_rename_ssa::Bool)
Expand Down Expand Up @@ -1225,6 +1258,8 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
stmt = ssa_rename[stmt.id]
end
ssa_rename[idx] = stmt
elseif isa(stmt, NewSSAValue)
ssa_rename[idx] = SSAValue(stmt.id)
else
# Constant assign, replace uses of this ssa value with its result
ssa_rename[idx] = stmt
Expand Down Expand Up @@ -1466,7 +1501,8 @@ function fixup_node(compact::IncrementalCompact, @nospecialize(stmt))
elseif isa(stmt, PhiCNode)
return PhiCNode(fixup_phinode_values!(compact, stmt.values))
elseif isa(stmt, NewSSAValue)
return SSAValue(length(compact.result) + stmt.id)
@assert stmt.id < 0
return SSAValue(length(compact.result) - stmt.id)
elseif isa(stmt, OldSSAValue)
val = compact.ssa_rename[stmt.id]
if isa(val, SSAValue)
Expand Down
20 changes: 10 additions & 10 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
return rename
end
end
def = compact[defssa]
def = compact[defssa][:inst]
if isa(def, PiNode)
if callback(def, defssa)
return defssa
Expand Down Expand Up @@ -246,7 +246,7 @@ Starting at `val` walk use-def chains to get all the leaves feeding into this `v
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
visited_phinodes = AnySSAValue[]
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
def = compact[defssa]
def = compact[defssa][:inst]
isa(def, PhiNode) || return Any[defssa], visited_phinodes
visited_constraints = IdDict{AnySSAValue, Any}()
worklist_defs = AnySSAValue[]
Expand All @@ -258,7 +258,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
defssa = pop!(worklist_defs)
typeconstraint = pop!(worklist_constraints)
visited_constraints[defssa] = typeconstraint
def = compact[defssa]
def = compact[defssa][:inst]
if isa(def, PhiNode)
push!(visited_phinodes, defssa)
possible_predecessors = Int[]
Expand Down Expand Up @@ -479,12 +479,12 @@ function walk_to_def(compact::IncrementalCompact, @nospecialize(leaf))
leaf = simple_walk(compact, leaf)
end
if isa(leaf, AnySSAValue)
def = compact[leaf]
def = compact[leaf][:inst]
else
def = leaf
end
elseif isa(leaf, AnySSAValue)
def = compact[leaf]
def = compact[leaf][:inst]
else
def = leaf
end
Expand Down Expand Up @@ -653,7 +653,7 @@ function perform_lifting!(compact::IncrementalCompact,
cached = false
if cached
ssa = lifting_cache[ckey]
push!(lifted_phis, LiftedPhi(ssa, compact[ssa]::PhiNode, false))
push!(lifted_phis, LiftedPhi(ssa, compact[ssa][:inst]::PhiNode, false))
continue
end
n = PhiNode()
Expand All @@ -664,7 +664,7 @@ function perform_lifting!(compact::IncrementalCompact,

# Fix up arguments
for (old_node_ssa, lf) in zip(visited_phinodes, lifted_phis)
old_node = compact[old_node_ssa]::PhiNode
old_node = compact[old_node_ssa][:inst]::PhiNode
new_node = lf.node
lf.need_argupdate || continue
for i = 1:length(old_node.edges)
Expand Down Expand Up @@ -790,7 +790,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing, InliningState} = nothin
def = simple_walk(compact, preserved_arg, callback)
isa(def, SSAValue) || continue
defidx = def.id
def = compact[defidx]
def = compact[def][:inst]
if is_known_call(def, tuple, compact)
record_immutable_preserve!(new_preserves, def, compact)
push!(preserved, preserved_arg.id)
Expand Down Expand Up @@ -1289,7 +1289,7 @@ function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::SPCSet, phi::I
for ur in userefs(compact.result[phi][:inst])
val = ur[]
isa(val, SSAValue) || continue
isa(compact[val], PhiNode) || continue
isa(compact[val][:inst], PhiNode) || continue
(val.id in safe_phis) && continue
push!(worklist, val.id)
end
Expand Down Expand Up @@ -1399,7 +1399,7 @@ function adce_pass!(ir::IRCode)
continue
end
to_drop = Int[]
stmt = compact[phi]
stmt = compact[SSAValue(phi)][:inst]
stmt === nothing && continue
stmt = stmt::PhiNode
for i = 1:length(stmt.values)
Expand Down
12 changes: 6 additions & 6 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ function compute_live_ins(cfg::CFG, defs::Vector{Int}, uses::Vector{Int})
BlockLiveness(bb_defs, bb_uses)
end

function recompute_type(node::Union{PhiNode, PhiCNode}, ci::CodeInfo, ir::IRCode, sptypes::Vector{Any}, slottypes::Vector{Any})
function recompute_type(node::Union{PhiNode, PhiCNode}, ci::CodeInfo, ir::IRCode, sptypes::Vector{Any}, slottypes::Vector{Any}, nstmts::Int)
new_typ = Union{}
for i = 1:length(node.values)
if isa(node, PhiNode) && !isassigned(node.values, i)
Expand All @@ -579,7 +579,7 @@ function recompute_type(node::Union{PhiNode, PhiCNode}, ci::CodeInfo, ir::IRCode
end
@assert !isa(typ, MaybeUndef)
while isa(typ, DelayedTyp)
typ = types(ir)[typ.phi::NewSSAValue]
typ = types(ir)[new_to_regular(typ.phi::NewSSAValue, nstmts)]
end
new_typ = tmerge(new_typ, was_maybe_undef ? MaybeUndef(typ) : typ)
end
Expand Down Expand Up @@ -856,7 +856,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree,
orig_typ = typ = typ_for_val(phic_values[i], ci, ir.sptypes, -1, slottypes)
@assert !isa(typ, MaybeUndef)
while isa(typ, DelayedTyp)
typ = types(ir)[typ.phi::NewSSAValue]
typ = types(ir)[new_to_regular(typ.phi::NewSSAValue, nstmts)]
end
new_typ = tmerge(new_typ, typ)
end
Expand All @@ -871,7 +871,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree,
changed = false
for new_idx in type_refine_phi
node = new_nodes.stmts[new_idx]
new_typ = recompute_type(node[:inst]::Union{PhiNode,PhiCNode}, ci, ir, ir.sptypes, slottypes)
new_typ = recompute_type(node[:inst]::Union{PhiNode,PhiCNode}, ci, ir, ir.sptypes, slottypes, nstmts)
if !(node[:type] ⊑ new_typ) || !(new_typ ⊑ node[:type])
node[:type] = new_typ
changed = true
Expand All @@ -881,14 +881,14 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree,
for i in 1:length(result_types)
rt_i = result_types[i]
if rt_i isa DelayedTyp
result_types[i] = types(ir)[rt_i.phi::NewSSAValue]
result_types[i] = types(ir)[new_to_regular(rt_i.phi::NewSSAValue, nstmts)]
end
end
for i = 1:length(new_nodes)
local node = new_nodes.stmts[i]
local typ = node[:type]
if isa(typ, DelayedTyp)
node[:type] = types(ir)[typ.phi::NewSSAValue]
node[:type] = types(ir)[new_to_regular(typ.phi::NewSSAValue, nstmts)]
end
end
# Renumber SSA values
Expand Down