Skip to content

Commit 68f71be

Browse files
committed
inference: form PartialStruct for extra type information propagation
This commit forms `PartialStruct` whenever there is any type-level refinement available about a field, even if it's not "constant" information. In Julia "definitions" are allowed to be abstract whereas "usages" (i.e. callsites) are often concrete. The basic idea is to allow inference to make more use of such precise callsite type information by encoding it as `PartialStruct`. This may increase optimization possibilities of "unidiomatic" Julia code, which may contain poorly-typed definitions, like this very contrived example: ```julia struct Problem n; s; c; t end function main(args...) prob = Problem(args...) s = 0 for i in 1:prob.n m = mod(i, 3) s += m == 0 ? sin(prob.s) : m == 1 ? cos(prob.c) : tan(prob.t) end return prob, s end main(10000, 1, 2, 3) ``` One of the obvious limitation is that this extra type information can be propagated inter-procedurally only as a const-propagation. I'm not sure this kind of "just a type-level" refinement can often make constant-prop' successful (i.e. shape-up a method body and allow it to be inlined, encoding the extra type information into the generated code), thus I didn't not modify any part of const-prop' heuristics. So the improvements from this change is almost for local analysis, and for very simple inter-procedural calls.
1 parent e38d3b7 commit 68f71be

File tree

4 files changed

+50
-27
lines changed

4 files changed

+50
-27
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,22 +1542,26 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15421542
if isconcretetype(t) && !ismutabletype(t)
15431543
args = Vector{Any}(undef, length(e.args)-1)
15441544
ats = Vector{Any}(undef, length(e.args)-1)
1545-
anyconst = false
1546-
allconst = true
1545+
local anyconst = anyrefine = false
1546+
local allconst = true
15471547
for i = 2:length(e.args)
15481548
at = widenconditional(abstract_eval_value(interp, e.args[i], vtypes, sv))
15491549
if !anyconst
1550-
anyconst = has_nontrivial_const_info(at)
1550+
if has_nontrivial_const_info(at)
1551+
anyconst = true
1552+
elseif !anyrefine
1553+
anyrefine = at fieldtype(t, i - 1)
1554+
end
15511555
end
15521556
ats[i-1] = at
15531557
if at === Bottom
15541558
t = Bottom
1555-
allconst = anyconst = false
1559+
anyconst = anyrefine = allconst = false
15561560
break
15571561
elseif at isa Const
15581562
if !(at.val isa fieldtype(t, i - 1))
15591563
t = Bottom
1560-
allconst = anyconst = false
1564+
anyconst = anyrefine = allconst = false
15611565
break
15621566
end
15631567
args[i-1] = at.val
@@ -1569,7 +1573,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15691573
if t !== Bottom && fieldcount(t) == length(ats)
15701574
if allconst
15711575
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
1572-
elseif anyconst
1576+
elseif anyconst || anyrefine
15731577
t = PartialStruct(t, ats)
15741578
end
15751579
end
@@ -1741,17 +1745,21 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s
17411745
isa(rt, Type) && return rt
17421746
if isa(rt, PartialStruct)
17431747
fields = copy(rt.fields)
1744-
haveconst = false
1748+
local anyconst = anyrefine = false
17451749
for i in 1:length(fields)
17461750
a = fields[i]
17471751
a = isvarargtype(a) ? a : widenreturn(a, bestguess, nslots, slottypes, changes)
1748-
if !haveconst && has_const_info(a)
1749-
# TODO: consider adding && const_prop_profitable(a) here?
1750-
haveconst = true
1752+
if !anyconst
1753+
if has_const_info(a)
1754+
# TODO: consider adding && const_prop_profitable(a) here?
1755+
anyconst = true
1756+
elseif !anyrefine
1757+
anyrefine = a fieldtype(rt.typ, i)
1758+
end
17511759
end
17521760
fields[i] = a
17531761
end
1754-
haveconst && return PartialStruct(rt.typ, fields)
1762+
(anyconst || anyrefine) && return PartialStruct(rt.typ, fields)
17551763
end
17561764
if isa(rt, PartialOpaque)
17571765
return rt # XXX: this case was missed in #39512

base/compiler/typelattice.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ function ⊑(@nospecialize(a), @nospecialize(b))
239239
return a === b
240240
end
241241
end
242+
(@nospecialize(a), @nospecialize(b)) = !(b, a)
242243

243244
# Check if two lattice elements are partial order equivalent. This is basically
244245
# `a ⊑ b && b ⊑ a` but with extra performance optimizations.

test/compiler/inference.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3645,3 +3645,26 @@ end
36453645

36463646
# issue #42646
36473647
@test only(Base.return_types(getindex, (Array{undef}, Int))) >: Union{} # check that it does not throw
3648+
3649+
# form PartialStruct for extra type information propagation
3650+
struct FieldTypeRefinement{S,T}
3651+
s::S
3652+
t::T
3653+
end
3654+
@test Base.return_types((Int,)) do s
3655+
o = FieldTypeRefinement{Any,Int}(s, s)
3656+
o.s
3657+
end |> only == Int
3658+
@test Base.return_types((Int,)) do s
3659+
o = FieldTypeRefinement{Int,Any}(s, s)
3660+
o.t
3661+
end |> only == Int
3662+
@test Base.return_types((Int,)) do s
3663+
o = FieldTypeRefinement{Any,Any}(s, s)
3664+
o.s, o.t
3665+
end |> only == Tuple{Int,Int}
3666+
@test Base.return_types((Int,)) do a
3667+
s1 = Some{Any}(a)
3668+
s2 = Some{Any}(s1)
3669+
s2.value.value
3670+
end |> only == Int

test/compiler/irpasses.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -426,31 +426,22 @@ let # `getfield_elim_pass!` should work with constant globals
426426
end
427427
end
428428

429-
let # `typeassert_elim_pass!`
429+
let
430+
# `typeassert` elimination after SROA
431+
# NOTE we can remove this optimization once inference is able to reason about memory-effects
430432
src = @eval Module() begin
431-
struct Foo; x; end
433+
mutable struct Foo; x; end
432434

433435
code_typed((Int,)) do a
434436
x1 = Foo(a)
435437
x2 = Foo(x1)
436-
x3 = Foo(x2)
437-
438-
r1 = (x2.x::Foo).x
439-
r2 = (x2.x::Foo).x::Int
440-
r3 = (x2.x::Foo).x::Integer
441-
r4 = ((x3.x::Foo).x::Foo).x
442-
443-
return r1, r2, r3, r4
438+
return typeassert(x2.x, Foo).x
444439
end |> only |> first
445440
end
446-
# eliminate `typeassert(f2.a, Foo)`
447-
@test all(src.code) do @nospecialize(stmt)
441+
# eliminate `typeassert(x2.x, Foo)`
442+
@test all(src.code) do @nospecialize stmt
448443
Meta.isexpr(stmt, :call) || return true
449444
ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes)
450445
return Core.Compiler.widenconst(ft) !== typeof(typeassert)
451446
end
452-
# succeeding simple DCE will eliminate `Foo(a)`
453-
@test all(src.code) do @nospecialize(stmt)
454-
return !Meta.isexpr(stmt, :new)
455-
end
456447
end

0 commit comments

Comments
 (0)