Skip to content

Commit 54004ba

Browse files
committed
optimizer: enhance SROA, handle partially-initialized allocations
During adding more test cases for our SROA pass, I found our SROA doesn't handle allocation sites with uninitialized fields at all. This commit is based on #42833 and tries to handle such "unsafe" allocations, if there are safe `setfield!` definitions. For example, this commit allows the allocation `r = Ref{Int}()` to be eliminated in the following example (adapted from <https://hackmd.io/bZz8k6SHQQuNUW-Vs7rqfw?view>): ```julia julia> code_typed() do r = Ref{Int}() r[] = 42 b = sin(r[]) return b end |> only ``` This commit comes with a plenty of basic test cases for our SROA pass also.
1 parent 4c6696e commit 54004ba

File tree

3 files changed

+214
-38
lines changed

3 files changed

+214
-38
lines changed

base/compiler/optimize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ function run_passes(ci::CodeInfo, sv::OptimizationState)
325325
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
326326
# @timeit "verify 2" verify_ir(ir)
327327
@timeit "compact 2" ir = compact!(ir)
328-
@timeit "SROA" ir = getfield_elim_pass!(ir)
328+
@timeit "SROA" ir = sroa_pass!(ir)
329329
@timeit "ADCE" ir = adce_pass!(ir)
330330
@timeit "type lift" ir = type_lift_pass!(ir)
331331
@timeit "compact 3" ir = compact!(ir)

base/compiler/ssair/passes.jl

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,22 @@ function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector
7575
end
7676

7777
function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use_idx::Int)
78-
# Find the first dominating def
78+
def, stmtblock, curblock = find_def_for_use(ir, domtree, allblocks, du, use_idx)
79+
if def == 0
80+
if !haskey(phinodes, curblock)
81+
# If this happens, we need to search the predecessors for defs. Which
82+
# one doesn't matter - if it did, we'd have had a phinode
83+
return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
84+
end
85+
# The use is the phinode
86+
return phinodes[curblock]
87+
else
88+
return val_for_def_expr(ir, def, fidx)
89+
end
90+
end
91+
92+
# find the first dominating def for the given use
93+
function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, use_idx::Int)
7994
stmtblock = block_for_inst(ir.cfg, use_idx)
8095
curblock = find_curblock(domtree, allblocks, stmtblock)
8196
local def = 0
@@ -90,17 +105,7 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
90105
end
91106
end
92107
end
93-
if def == 0
94-
if !haskey(phinodes, curblock)
95-
# If this happens, we need to search the predecessors for defs. Which
96-
# one doesn't matter - if it did, we'd have had a phinode
97-
return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
98-
end
99-
# The use is the phinode
100-
return phinodes[curblock]
101-
else
102-
return val_for_def_expr(ir, def, fidx)
103-
end
108+
return def, stmtblock, curblock
104109
end
105110

106111
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@@ -538,7 +543,7 @@ function perform_lifting!(compact::IncrementalCompact,
538543
end
539544

540545
"""
541-
getfield_elim_pass!(ir::IRCode) -> newir::IRCode
546+
sroa_pass!(ir::IRCode) -> newir::IRCode
542547
543548
`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization.
544549
@@ -555,7 +560,7 @@ its argument).
555560
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
556561
a result of dead code elimination.
557562
"""
558-
function getfield_elim_pass!(ir::IRCode)
563+
function sroa_pass!(ir::IRCode)
559564
compact = IncrementalCompact(ir)
560565
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
561566
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
@@ -784,7 +789,6 @@ function getfield_elim_pass!(ir::IRCode)
784789
typ = typ::DataType
785790
# Partition defuses by field
786791
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
787-
ok = true
788792
for use in defuse.uses
789793
stmt = ir[SSAValue(use)]
790794
# We may have discovered above that this use is dead
@@ -793,47 +797,52 @@ function getfield_elim_pass!(ir::IRCode)
793797
# the use in that case.
794798
stmt === nothing && continue
795799
field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
796-
field === nothing && (ok = false; break)
800+
field === nothing && @goto skip
797801
push!(fielddefuse[field].uses, use)
798802
end
799-
ok || continue
800803
for use in defuse.defs
801804
field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
802-
field === nothing && (ok = false; break)
805+
field === nothing && @goto skip
803806
push!(fielddefuse[field].defs, use)
804807
end
805-
ok || continue
806808
# Check that the defexpr has defined values for all the fields
807809
# we're accessing. In the future, we may want to relax this,
808810
# but we should come up with semantics for well defined semantics
809811
# for uninitialized fields first.
810-
for (fidx, du) in pairs(fielddefuse)
812+
ndefuse = length(fielddefuse)
813+
blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# Vector{Int}}}(undef, ndefuse)
814+
for fidx in 1:ndefuse
815+
du = fielddefuse[fidx]
811816
isempty(du.uses) && continue
817+
push!(du.defs, idx)
818+
ldu = compute_live_ins(ir.cfg, du)
819+
phiblocks = Int[]
820+
if !isempty(ldu.live_in_bbs)
821+
phiblocks = idf(ir.cfg, ldu, domtree)
822+
end
823+
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
824+
blocks[fidx] = phiblocks, allblocks
812825
if fidx + 1 > length(defexpr.args)
813-
ok = false
814-
break
826+
for use in du.uses
827+
def = find_def_for_use(ir, domtree, allblocks, du, use)[1]
828+
(def == 0 || def == idx) && @goto skip
829+
end
815830
end
816831
end
817-
ok || continue
818832
preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
819833
# Everything accounted for. Go field by field and perform idf
820-
for (fidx, du) in pairs(fielddefuse)
834+
for fidx in 1:ndefuse
835+
du = fielddefuse[fidx]
821836
ftyp = fieldtype(typ, fidx)
822837
if !isempty(du.uses)
823-
push!(du.defs, idx)
824-
ldu = compute_live_ins(ir.cfg, du)
825-
phiblocks = Int[]
826-
if !isempty(ldu.live_in_bbs)
827-
phiblocks = idf(ir.cfg, ldu, domtree)
828-
end
838+
phiblocks, allblocks = blocks[fidx]
829839
phinodes = IdDict{Int, SSAValue}()
830840
for b in phiblocks
831841
n = PhiNode()
832842
phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
833843
NewInstruction(n, ftyp))
834844
end
835845
# Now go through all uses and rewrite them
836-
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
837846
for stmt in du.uses
838847
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
839848
end
@@ -855,7 +864,6 @@ function getfield_elim_pass!(ir::IRCode)
855864
stmt == idx && continue
856865
ir[SSAValue(stmt)] = nothing
857866
end
858-
continue
859867
end
860868
isempty(defuse.ccall_preserve_uses) && continue
861869
push!(intermediaries, idx)
@@ -870,6 +878,8 @@ function getfield_elim_pass!(ir::IRCode)
870878
old_preserves..., new_preserves...)
871879
ir[SSAValue(use)] = new_expr
872880
end
881+
882+
@label skip
873883
end
874884

875885
return ir
@@ -919,14 +929,14 @@ In addition to a simple DCE for unused values and allocations,
919929
this pass also nullifies `typeassert` calls that can be proved to be no-op,
920930
in order to allow LLVM to emit simpler code down the road.
921931
922-
Note that this pass is more effective after SROA optimization (i.e. `getfield_elim_pass!`),
932+
Note that this pass is more effective after SROA optimization (i.e. `sroa_pass!`),
923933
since SROA often allows this pass to:
924934
- eliminate allocation of object whose field references are all replaced with scalar values, and
925935
- nullify `typeassert` call whose first operand has been replaced with a scalar value
926936
(, which may have introduced new type information that inference did not understand)
927937
928-
Also note that currently this pass _needs_ to run after `getfield_elim_pass!`, because
929-
the `typeassert` elimination depends on the transformation within `getfield_elim_pass!`
938+
Also note that currently this pass _needs_ to run after `sroa_pass!`, because
939+
the `typeassert` elimination depends on the transformation within `sroa_pass!`
930940
which redirects references of `typeassert`ed value to the corresponding `PiNode`.
931941
"""
932942
function adce_pass!(ir::IRCode)

test/compiler/irpasses.jl

Lines changed: 168 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,172 @@ end
6969

7070
# Tests for SROA
7171

72+
import Core.Compiler: argextype, singleton_type
73+
const EMPTY_SPTYPES = Core.Compiler.EMPTY_SLOTTYPES
74+
75+
code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo
76+
get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code
77+
78+
# check if `x` is a statement with a given `head`
79+
isnew(@nospecialize x) = Meta.isexpr(x, :new)
80+
81+
# check if `x` is a dynamic call of a given function
82+
iscall(y) = @nospecialize(x) -> iscall(y, x)
83+
function iscall((src, f)::Tuple{Core.CodeInfo,Function}, @nospecialize(x))
84+
return iscall(x) do @nospecialize x
85+
singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f
86+
end
87+
end
88+
iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])
89+
90+
struct ImmutableXYZ; x; y; z; end
91+
mutable struct MutableXYZ; x; y; z; end
92+
93+
# should optimize away very basic cases
94+
let src = code_typed1((Any,Any,Any)) do x, y, z
95+
xyz = ImmutableXYZ(x, y, z)
96+
xyz.x, xyz.y, xyz.z
97+
end
98+
@test !any(isnew, src.code)
99+
end
100+
let src = code_typed1((Any,Any,Any)) do x, y, z
101+
xyz = MutableXYZ(x, y, z)
102+
xyz.x, xyz.y, xyz.z
103+
end
104+
@test !any(isnew, src.code)
105+
end
106+
107+
# should handle simple mutabilities
108+
let src = code_typed1((Any,Any,Any)) do x, y, z
109+
xyz = MutableXYZ(x, y, z)
110+
xyz.y = 42
111+
xyz.x, xyz.y, xyz.z
112+
end
113+
@test !any(isnew, src.code)
114+
@test any(src.code) do @nospecialize x
115+
iscall((src, tuple), x) &&
116+
x.args[2:end] == Any[#=x=# Core.Argument(2), 42, #=x=# Core.Argument(4)]
117+
end
118+
end
119+
let src = code_typed1((Any,Any,Any)) do x, y, z
120+
xyz = MutableXYZ(x, y, z)
121+
xyz.x, xyz.z = xyz.z, xyz.x
122+
xyz.x, xyz.y, xyz.z
123+
end
124+
@test !any(isnew, src.code)
125+
@test any(src.code) do @nospecialize x
126+
iscall((src, tuple), x) &&
127+
x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)]
128+
end
129+
end
130+
# circumvent uninitialized fields as far as there is a solid `setfield!` definition
131+
let src = code_typed1() do
132+
r = Ref{Any}()
133+
r[] = 42
134+
return r[]
135+
end
136+
@test !any(isnew, src.code)
137+
end
138+
let src = code_typed1((Bool,)) do cond
139+
r = Ref{Any}()
140+
if cond
141+
r[] = 42
142+
return r[]
143+
else
144+
r[] = 32
145+
return r[]
146+
end
147+
end
148+
@test !any(isnew, src.code)
149+
end
150+
# FIXME to handle this case, we need a more strong alias analysis
151+
let src = code_typed1((Bool,)) do cond
152+
r = Ref{Any}()
153+
if cond
154+
r[] = 42
155+
else
156+
r[] = 32
157+
end
158+
return r[]
159+
end
160+
@test_broken !any(isnew, src.code)
161+
end
162+
let
163+
src = code_typed1((Bool,)) do cond
164+
r = Ref{Any}()
165+
if cond
166+
r[] = 42
167+
end
168+
return r[]
169+
end
170+
# N.B. `r` should be allocated since `cond` might be `false` and then it will be thrown
171+
@test any(isnew, src.code)
172+
end
173+
174+
# should include a simple alias analysis
175+
struct ImmutableOuter{T}; x::T; y::T; z::T; end
176+
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
177+
let src = code_typed1((Any,Any,Any)) do x, y, z
178+
xyz = ImmutableXYZ(x, y, z)
179+
outer = ImmutableOuter(xyz, xyz, xyz)
180+
outer.x.x, outer.y.y, outer.z.z
181+
end
182+
@test !any(src.code) do @nospecialize x
183+
Meta.isexpr(x, :new)
184+
end
185+
@test any(src.code) do @nospecialize x
186+
iscall((src, tuple), x) &&
187+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
188+
end
189+
end
190+
let src = code_typed1((Any,Any,Any)) do x, y, z
191+
xyz = ImmutableXYZ(x, y, z)
192+
# #42831 forms ::PartialStruct(ImmutableOuter{Any}, Any[ImmutableXYZ, ImmutableXYZ, ImmutableXYZ])
193+
# so the succeeding `getproperty`s are type stable and inlined
194+
outer = ImmutableOuter{Any}(xyz, xyz, xyz)
195+
outer.x.x, outer.y.y, outer.z.z
196+
end
197+
@test !any(isnew, src.code)
198+
@test any(src.code) do @nospecialize x
199+
iscall((src, tuple), x) &&
200+
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
201+
end
202+
end
203+
# FIXME our analysis isn't yet so powerful at this moment, e.g. it can't handle nested mutable objects
204+
let src = code_typed1((Any,Any,Any)) do x, y, z
205+
xyz = MutableXYZ(x, y, z)
206+
outer = ImmutableOuter(xyz, xyz, xyz)
207+
outer.x.x, outer.y.y, outer.z.z
208+
end
209+
@test_broken !any(isnew, src.code)
210+
end
211+
let src = code_typed1((Any,Any,Any)) do x, y, z
212+
xyz = ImmutableXYZ(x, y, z)
213+
outer = MutableOuter(xyz, xyz, xyz)
214+
outer.x.x, outer.y.y, outer.z.z
215+
end
216+
@test_broken !any(isnew, src.code)
217+
end
218+
219+
# should work nicely with inlining to optimize away a complicated case
220+
# adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B
221+
struct Point
222+
x::Float64
223+
y::Float64
224+
end
225+
#=@inline=# add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y)
226+
function compute()
227+
a = Point(1.5, 2.5)
228+
b = Point(2.25, 4.75)
229+
for i in 0:(100000000-1)
230+
a = add(add(a, b), b)
231+
end
232+
a.x, a.y
233+
end
234+
let src = code_typed1(compute)
235+
@test !any(isnew, src.code)
236+
end
237+
72238
mutable struct Foo30594; x::Float64; end
73239
Base.copy(x::Foo30594) = Foo30594(x.x)
74240
function add!(p::Foo30594, off::Foo30594)
@@ -180,7 +346,7 @@ let m = Meta.@lower 1 + 1
180346
src.ssaflags = fill(Int32(0), nstmts)
181347
ir = Core.Compiler.inflate_ir(src, Any[], Any[Any, Any])
182348
@test Core.Compiler.verify_ir(ir) === nothing
183-
ir = @test_nowarn Core.Compiler.getfield_elim_pass!(ir)
349+
ir = @test_nowarn Core.Compiler.sroa_pass!(ir)
184350
@test Core.Compiler.verify_ir(ir) === nothing
185351
end
186352

@@ -384,7 +550,7 @@ exc39508 = ErrorException("expected")
384550
end
385551
@test test39508() === exc39508
386552

387-
let # `getfield_elim_pass!` should work with constant globals
553+
let # `sroa_pass!` should work with constant globals
388554
# immutable pass
389555
src = @eval Module() begin
390556
const REF_FLD = :x

0 commit comments

Comments
 (0)