Skip to content

Commit 2cf9e10

Browse files
Kenoaviatesk
authored andcommitted
Allow external lattice elements to properly union split (JuliaLang#49030)
Currently `MustAlias` is the only lattice element that is allowed to widen to union types. However, there are others in external packages. Expand the support we have for this in order to allow union splitting of lattice elements. Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com>
1 parent a0c0c64 commit 2cf9e10

File tree

6 files changed

+36
-27
lines changed

6 files changed

+36
-27
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
5959
# as we may want to concrete-evaluate this frame in cases when there are
6060
# no overlayed calls, try an additional effort now to check if this call
6161
# isn't overlayed rather than just handling it conservatively
62-
matches = find_matching_methods(arginfo.argtypes, atype, method_table(interp),
62+
matches = find_matching_methods(typeinf_lattice(interp), arginfo.argtypes, atype, method_table(interp),
6363
InferenceParams(interp).max_union_splitting, max_methods)
6464
if !isa(matches, FailedMethodMatch)
6565
nonoverlayed = matches.nonoverlayed
@@ -75,7 +75,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
7575
end
7676

7777
argtypes = arginfo.argtypes
78-
matches = find_matching_methods(argtypes, atype, method_table(interp),
78+
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
7979
InferenceParams(interp).max_union_splitting, max_methods)
8080
if isa(matches, FailedMethodMatch)
8181
add_remark!(interp, sv, matches.reason)
@@ -273,11 +273,12 @@ struct UnionSplitMethodMatches
273273
end
274274
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)
275275

276-
function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
276+
function find_matching_methods(𝕃::AbstractLattice,
277+
argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
277278
max_union_splitting::Int, max_methods::Int)
278279
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
279-
if 1 < unionsplitcost(argtypes) <= max_union_splitting
280-
split_argtypes = switchtupleunion(argtypes)
280+
if 1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting
281+
split_argtypes = switchtupleunion(𝕃, argtypes)
281282
infos = MethodMatchInfo[]
282283
applicable = Any[]
283284
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
@@ -1496,7 +1497,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
14961497
end
14971498
res = Union{}
14981499
nargs = length(aargtypes)
1499-
splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).max_apply_union_enum
1500+
splitunions = 1 < unionsplitcost(typeinf_lattice(interp), aargtypes) <= InferenceParams(interp).max_apply_union_enum
15001501
ctypes = [Any[aft]]
15011502
infos = Vector{MaybeAbstractIterationInfo}[MaybeAbstractIterationInfo[]]
15021503
effects = EFFECTS_TOTAL

base/compiler/abstractlattice.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,10 @@ has_mustalias(𝕃::AbstractLattice) = has_mustalias(widenlattice(𝕃))
293293
has_mustalias(::AnyMustAliasesLattice) = true
294294
has_mustalias(::JLTypeLattice) = false
295295

296+
has_extended_unionsplit(𝕃::AbstractLattice) = has_extended_unionsplit(widenlattice(𝕃))
297+
has_extended_unionsplit(::AnyMustAliasesLattice) = true
298+
has_extended_unionsplit(::JLTypeLattice) = false
299+
296300
# Curried versions
297301
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
298302
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)

base/compiler/tfuncs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2542,7 +2542,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
25422542
isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo())
25432543
argtypes = argtypes[2:end]
25442544
atype = argtypes_to_type(argtypes)
2545-
matches = find_matching_methods(argtypes, atype, method_table(interp),
2545+
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
25462546
InferenceParams(interp).max_union_splitting, max_methods)
25472547
if isa(matches, FailedMethodMatch)
25482548
rt = Bool # too many matches to analyze

base/compiler/typelattice.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ end
120120
MustAlias(var::SlotNumber, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp)) =
121121
MustAlias(slot_id(var), vartyp, fldidx, fldtyp)
122122

123+
_uniontypes(x::MustAlias, ts) = _uniontypes(widenconst(x), ts)
124+
123125
"""
124126
alias::InterMustAlias
125127

base/compiler/typeutils.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ function typesubtract(@nospecialize(a), @nospecialize(b), max_union_splitting::I
165165
if ub isa DataType
166166
if a.name === ub.name === Tuple.name &&
167167
length(a.parameters) == length(ub.parameters)
168-
if 1 < unionsplitcost(a.parameters) <= max_union_splitting
168+
if 1 < unionsplitcost(JLTypeLattice(), a.parameters) <= max_union_splitting
169169
ta = switchtupleunion(a)
170170
return typesubtract(Union{ta...}, b, 0)
171171
elseif b isa DataType
@@ -227,12 +227,11 @@ end
227227
# or outside of the Tuple/Union nesting, though somewhat more expensive to be
228228
# outside than inside because the representation is larger (because and it
229229
# informs the callee whether any splitting is possible).
230-
function unionsplitcost(argtypes::Union{SimpleVector,Vector{Any}})
230+
function unionsplitcost(𝕃::AbstractLattice, argtypes::Union{SimpleVector,Vector{Any}})
231231
nu = 1
232232
max = 2
233233
for ti in argtypes
234-
# TODO remove this to implement callsite refinement of MustAlias
235-
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
234+
if has_extended_unionsplit(𝕃) && !isvarargtype(ti)
236235
ti = widenconst(ti)
237236
end
238237
if isa(ti, Union)
@@ -252,12 +251,12 @@ end
252251
# and `Union{return...} == ty`
253252
function switchtupleunion(@nospecialize(ty))
254253
tparams = (unwrap_unionall(ty)::DataType).parameters
255-
return _switchtupleunion(Any[tparams...], length(tparams), [], ty)
254+
return _switchtupleunion(JLTypeLattice(), Any[tparams...], length(tparams), [], ty)
256255
end
257256

258-
switchtupleunion(argtypes::Vector{Any}) = _switchtupleunion(argtypes, length(argtypes), [], nothing)
257+
switchtupleunion(𝕃::AbstractLattice, argtypes::Vector{Any}) = _switchtupleunion(𝕃, argtypes, length(argtypes), [], nothing)
259258

260-
function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
259+
function _switchtupleunion(𝕃::AbstractLattice, t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
261260
if i == 0
262261
if origt === nothing
263262
push!(tunion, copy(t))
@@ -268,17 +267,20 @@ function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospeci
268267
else
269268
origti = ti = t[i]
270269
# TODO remove this to implement callsite refinement of MustAlias
271-
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
272-
ti = widenconst(ti)
273-
end
274270
if isa(ti, Union)
275-
for ty in uniontypes(ti::Union)
271+
for ty in uniontypes(ti)
272+
t[i] = ty
273+
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
274+
end
275+
t[i] = origti
276+
elseif has_extended_unionsplit(𝕃) && !isa(ti, Const) && !isvarargtype(ti) && isa(widenconst(ti), Union)
277+
for ty in uniontypes(ti)
276278
t[i] = ty
277-
_switchtupleunion(t, i - 1, tunion, origt)
279+
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
278280
end
279281
t[i] = origti
280282
else
281-
_switchtupleunion(t, i - 1, tunion, origt)
283+
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
282284
end
283285
end
284286
return tunion

test/compiler/inference.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2944,11 +2944,11 @@ end
29442944
# issue #28356
29452945
# unit test to make sure countunionsplit overflows gracefully
29462946
# we don't care what number is returned as long as it's large
2947-
@test Core.Compiler.unionsplitcost(Any[Union{Int32, Int64} for i=1:80]) > 100000
2948-
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}]) == 2
2949-
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
2950-
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
2951-
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6
2947+
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int32, Int64} for i=1:80]) > 100000
2948+
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}]) == 2
2949+
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
2950+
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
2951+
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6
29522952

29532953
# make sure compiler doesn't hang in union splitting
29542954

@@ -3949,13 +3949,13 @@ end
39493949

39503950
# argtypes
39513951
let
3952-
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Core.Const(nothing)])
3952+
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Core.Const(nothing)])
39533953
@test length(tunion) == 2
39543954
@test Any[Int32, Core.Const(nothing)] in tunion
39553955
@test Any[Int64, Core.Const(nothing)] in tunion
39563956
end
39573957
let
3958-
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
3958+
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
39593959
@test length(tunion) == 4
39603960
@test Any[Int32, Float32, Core.Const(nothing)] in tunion
39613961
@test Any[Int32, Float64, Core.Const(nothing)] in tunion

0 commit comments

Comments
 (0)