Skip to content

Commit 72715ee

Browse files
committed
optimizer: enhance SROA, handle partially-initialized allocations
During adding more test cases for our SROA pass, I found our SROA doesn't handle allocation sites with uninitialized fields at all. This commit is based on #42833 and tries to handle such "unsafe" allocations, if there are safe `setfield!` definitions. For example, this commit allows the allocation `r = Ref{Int}()` to be eliminated in the following example (adapted from <https://hackmd.io/bZz8k6SHQQuNUW-Vs7rqfw?view>): ```julia julia> code_typed() do r = Ref{Int}() r[] = 42 b = sin(r[]) return b end |> only ``` This commit comes with a plenty of basic test cases for our SROA pass also.
1 parent 4c3f77a commit 72715ee

File tree

4 files changed

+230
-40
lines changed

4 files changed

+230
-40
lines changed

base/compiler/optimize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ function run_passes(ci::CodeInfo, sv::OptimizationState)
324324
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
325325
# @timeit "verify 2" verify_ir(ir)
326326
@timeit "compact 2" ir = compact!(ir)
327-
@timeit "SROA" ir = getfield_elim_pass!(ir)
327+
@timeit "SROA" ir = sroa_pass!(ir)
328328
@timeit "ADCE" ir = adce_pass!(ir)
329329
@timeit "type lift" ir = type_lift_pass!(ir)
330330
@timeit "compact 3" ir = compact!(ir)

base/compiler/ssair/passes.jl

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,22 @@ function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector
7575
end
7676

7777
function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use_idx::Int)
78-
# Find the first dominating def
78+
def, stmtblock, curblock = find_def_for_use(ir, domtree, allblocks, du, use_idx)
79+
if def == 0
80+
if !haskey(phinodes, curblock)
81+
# If this happens, we need to search the predecessors for defs. Which
82+
# one doesn't matter - if it did, we'd have had a phinode
83+
return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
84+
end
85+
# The use is the phinode
86+
return phinodes[curblock]
87+
else
88+
return val_for_def_expr(ir, def, fidx)
89+
end
90+
end
91+
92+
# find the first dominating def for the given use
93+
function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, use_idx::Int)
7994
stmtblock = block_for_inst(ir.cfg, use_idx)
8095
curblock = find_curblock(domtree, allblocks, stmtblock)
8196
local def = 0
@@ -90,17 +105,7 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
90105
end
91106
end
92107
end
93-
if def == 0
94-
if !haskey(phinodes, curblock)
95-
# If this happens, we need to search the predecessors for defs. Which
96-
# one doesn't matter - if it did, we'd have had a phinode
97-
return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
98-
end
99-
# The use is the phinode
100-
return phinodes[curblock]
101-
else
102-
return val_for_def_expr(ir, def, fidx)
103-
end
108+
return def, stmtblock, curblock
104109
end
105110

106111
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@@ -538,7 +543,7 @@ function perform_lifting!(compact::IncrementalCompact,
538543
end
539544

540545
"""
541-
getfield_elim_pass!(ir::IRCode) -> newir::IRCode
546+
sroa_pass!(ir::IRCode) -> newir::IRCode
542547
543548
`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization.
544549
@@ -555,7 +560,7 @@ its argument).
555560
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
556561
a result of dead code elimination.
557562
"""
558-
function getfield_elim_pass!(ir::IRCode)
563+
function sroa_pass!(ir::IRCode)
559564
compact = IncrementalCompact(ir)
560565
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
561566
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
@@ -784,7 +789,6 @@ function getfield_elim_pass!(ir::IRCode)
784789
typ = typ::DataType
785790
# Partition defuses by field
786791
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
787-
ok = true
788792
for use in defuse.uses
789793
stmt = ir[SSAValue(use)]
790794
# We may have discovered above that this use is dead
@@ -793,47 +797,52 @@ function getfield_elim_pass!(ir::IRCode)
793797
# the use in that case.
794798
stmt === nothing && continue
795799
field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
796-
field === nothing && (ok = false; break)
800+
field === nothing && @goto skip
797801
push!(fielddefuse[field].uses, use)
798802
end
799-
ok || continue
800803
for use in defuse.defs
801804
field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
802-
field === nothing && (ok = false; break)
805+
field === nothing && @goto skip
803806
push!(fielddefuse[field].defs, use)
804807
end
805-
ok || continue
806808
# Check that the defexpr has defined values for all the fields
807809
# we're accessing. In the future, we may want to relax this,
808810
# but we should come up with semantics for well defined semantics
809811
# for uninitialized fields first.
810-
for (fidx, du) in pairs(fielddefuse)
812+
ndefuse = length(fielddefuse)
813+
blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# Vector{Int}}}(undef, ndefuse)
814+
for fidx in 1:ndefuse
815+
du = fielddefuse[fidx]
811816
isempty(du.uses) && continue
817+
push!(du.defs, idx)
818+
ldu = compute_live_ins(ir.cfg, du)
819+
phiblocks = Int[]
820+
if !isempty(ldu.live_in_bbs)
821+
phiblocks = idf(ir.cfg, ldu, domtree)
822+
end
823+
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
824+
blocks[fidx] = phiblocks, allblocks
812825
if fidx + 1 > length(defexpr.args)
813-
ok = false
814-
break
826+
for use in du.uses
827+
def = find_def_for_use(ir, domtree, allblocks, du, use)[1]
828+
(def == 0 || def == idx) && @goto skip
829+
end
815830
end
816831
end
817-
ok || continue
818832
preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
819833
# Everything accounted for. Go field by field and perform idf
820-
for (fidx, du) in pairs(fielddefuse)
834+
for fidx in 1:ndefuse
835+
du = fielddefuse[fidx]
821836
ftyp = fieldtype(typ, fidx)
822837
if !isempty(du.uses)
823-
push!(du.defs, idx)
824-
ldu = compute_live_ins(ir.cfg, du)
825-
phiblocks = Int[]
826-
if !isempty(ldu.live_in_bbs)
827-
phiblocks = idf(ir.cfg, ldu, domtree)
828-
end
838+
phiblocks, allblocks = blocks[fidx]
829839
phinodes = IdDict{Int, SSAValue}()
830840
for b in phiblocks
831841
n = PhiNode()
832842
phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
833843
NewInstruction(n, ftyp))
834844
end
835845
# Now go through all uses and rewrite them
836-
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
837846
for stmt in du.uses
838847
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
839848
end
@@ -855,7 +864,6 @@ function getfield_elim_pass!(ir::IRCode)
855864
stmt == idx && continue
856865
ir[SSAValue(stmt)] = nothing
857866
end
858-
continue
859867
end
860868
isempty(defuse.ccall_preserve_uses) && continue
861869
push!(intermediaries, idx)
@@ -870,6 +878,8 @@ function getfield_elim_pass!(ir::IRCode)
870878
old_preserves..., new_preserves...)
871879
ir[SSAValue(use)] = new_expr
872880
end
881+
882+
@label skip
873883
end
874884

875885
return ir
@@ -919,14 +929,14 @@ In addition to a simple DCE for unused values and allocations,
919929
this pass also nullifies `typeassert` calls that can be proved to be no-op,
920930
in order to allow LLVM to emit simpler code down the road.
921931
922-
Note that this pass is more effective after SROA optimization (i.e. `getfield_elim_pass!`),
932+
Note that this pass is more effective after SROA optimization (i.e. `sroa_pass!`),
923933
since SROA often allows this pass to:
924934
- eliminate allocation of object whose field references are all replaced with scalar values, and
925935
- nullify `typeassert` call whose first operand has been replaced with a scalar value
926936
(, which may have introduced new type information that inference did not understand)
927937
928-
Also note that currently this pass _needs_ to run after `getfield_elim_pass!`, because
929-
the `typeassert` elimination depends on the transformation within `getfield_elim_pass!`
938+
Also note that currently this pass _needs_ to run after `sroa_pass!`, because
939+
the `typeassert` elimination depends on the transformation within `sroa_pass!`
930940
which redirects references of `typeassert`ed value to the corresponding `PiNode`.
931941
"""
932942
function adce_pass!(ir::IRCode)

test/compiler/inline.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ using Base.Experimental: @opaque
380380
f_oc_getfield(x) = (@opaque ()->x)()
381381
@test fully_eliminated(f_oc_getfield, Tuple{Int})
382382

383-
import Core.Compiler: argextype
383+
import Core.Compiler: argextype, singleton_type
384384
const EMPTY_SPTYPES = Core.Compiler.EMPTY_SLOTTYPES
385385

386386
code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo
@@ -389,7 +389,7 @@ get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code
389389
# check if `x` is a dynamic call of a given function
390390
function iscall((src, f)::Tuple{Core.CodeInfo,Function}, @nospecialize(x))
391391
return iscall(x) do @nospecialize x
392-
argextype(x, src, EMPTY_SPTYPES) === typeof(f)
392+
singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f
393393
end
394394
end
395395
iscall(pred, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])

0 commit comments

Comments
 (0)