Skip to content

Commit 2a406b2

Browse files
Kenoaviatesk
andauthored
Add pattern matching for typeof into field type tparam (#50422)
* Add pattern matching for `typeof` into field type tparam This PR allows full elimination of the following, even in ill-typed code. ``` struct TParamTypeofTest{T} x::T @eval TParamTypeofTest(x) = $(Expr(:new, :(TParamTypeofTest{typeof(x)}), :x)) end function tparam_typeof_test_elim(x) TParamTypeofTest(x).x end ``` Before this PR, we would get: ``` julia> code_typed(tparam_typeof_test_elim, Tuple{Any}) 1-element Vector{Any}: CodeInfo( 1 ─ %1 = Main.typeof(x)::DataType β”‚ %2 = Core.apply_type(Main.TParamTypeofTest, %1)::Type{TParamTypeofTest{_A}} where _A β”‚ %new(%2, x)::TParamTypeofTest └── return x ``` Where the `new` is non-eliminable, because the compiler did not know that `x::_A`. Fix this by pattern matching this particular pattern (where the condition is guaranteed, because we computed `_A` by `typeof`). This is not particularly general, but this pattern comes up a lot, so it's surprisingly effective. * add test case for optimizing multiple abstract fields * improve robustness --------- Co-authored-by: Shuhei Kadowaki <aviatesk@gmail.com>
1 parent 46477cc commit 2a406b2

File tree

3 files changed

+110
-30
lines changed

3 files changed

+110
-30
lines changed

β€Žbase/compiler/optimize.jlβ€Ž

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,41 @@ is_stmt_inline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_INLINE β‰  0
215215
is_stmt_noinline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_NOINLINE β‰  0
216216
is_stmt_throw_block(stmt_flag::UInt8) = stmt_flag & IR_FLAG_THROW_BLOCK β‰  0
217217

218+
function new_expr_effect_flags(𝕃ₒ::AbstractLattice, args::Vector{Any}, src::Union{IRCode,IncrementalCompact}, pattern_match=nothing)
219+
Targ = args[1]
220+
atyp = argextype(Targ, src)
221+
# `Expr(:new)` of unknown type could raise arbitrary TypeError.
222+
typ, isexact = instanceof_tfunc(atyp)
223+
if !isexact
224+
atyp = unwrap_unionall(widenconst(atyp))
225+
if isType(atyp) && isTypeDataType(atyp.parameters[1])
226+
typ = atyp.parameters[1]
227+
else
228+
return (false, false, false)
229+
end
230+
isabstracttype(typ) && return (false, false, false)
231+
else
232+
isconcretedispatch(typ) || return (false, false, false)
233+
end
234+
typ = typ::DataType
235+
fcount = datatype_fieldcount(typ)
236+
fcount === nothing && return (false, false, false)
237+
fcount >= length(args) - 1 || return (false, false, false)
238+
for fidx in 1:(length(args) - 1)
239+
farg = args[fidx + 1]
240+
eT = argextype(farg, src)
241+
fT = fieldtype(typ, fidx)
242+
if !isexact && has_free_typevars(fT)
243+
if pattern_match !== nothing && pattern_match(src, typ, fidx, Targ, farg)
244+
continue
245+
end
246+
return (false, false, false)
247+
end
248+
βŠ‘(𝕃ₒ, eT, fT) || return (false, false, false)
249+
end
250+
return (false, true, true)
251+
end
252+
218253
"""
219254
stmt_effect_flags(stmt, rt, src::Union{IRCode,IncrementalCompact}) ->
220255
(consistent::Bool, effect_free_and_nothrow::Bool, nothrow::Bool)
@@ -264,36 +299,7 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
264299
nothrow = is_nothrow(effects)
265300
return (consistent, effect_free & nothrow, nothrow)
266301
elseif head === :new
267-
atyp = argextype(args[1], src)
268-
# `Expr(:new)` of unknown type could raise arbitrary TypeError.
269-
typ, isexact = instanceof_tfunc(atyp)
270-
if !isexact
271-
atyp = unwrap_unionall(widenconst(atyp))
272-
if isType(atyp) && isTypeDataType(atyp.parameters[1])
273-
typ = atyp.parameters[1]
274-
else
275-
return (false, false, false)
276-
end
277-
isabstracttype(typ) && return (false, false, false)
278-
else
279-
isconcretedispatch(typ) || return (false, false, false)
280-
end
281-
typ = typ::DataType
282-
fcount = datatype_fieldcount(typ)
283-
fcount === nothing && return (false, false, false)
284-
fcount >= length(args) - 1 || return (false, false, false)
285-
for fld_idx in 1:(length(args) - 1)
286-
eT = argextype(args[fld_idx + 1], src)
287-
fT = fieldtype(typ, fld_idx)
288-
# Currently, we cannot represent any type equality constraints
289-
# in the lattice, so if we see any type of type parameter,
290-
# there is very little we can say about it
291-
if !isexact && has_free_typevars(fT)
292-
return (false, false, false)
293-
end
294-
βŠ‘(𝕃ₒ, eT, fT) || return (false, false, false)
295-
end
296-
return (false, true, true)
302+
return new_expr_effect_flags(𝕃ₒ, args, src)
297303
elseif head === :foreigncall
298304
effects = foreigncall_effects(stmt) do @nospecialize x
299305
argextype(x, src)

β€Žbase/compiler/ssair/passes.jlβ€Ž

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,62 @@ end
908908
return nothing
909909
end
910910

911+
struct IsEgal <: Function
912+
x::Any
913+
IsEgal(@nospecialize(x)) = new(x)
914+
end
915+
(x::IsEgal)(@nospecialize(y)) = x.x === y
916+
917+
# This tries to match patterns of the form
918+
# %ft = typeof(%farg)
919+
# %Targ = apply_type(Foo, ft)
920+
# %x = new(%Targ, %farg)
921+
#
922+
# and if possible refines the nothrowness of the new expr based on it.
923+
function pattern_match_typeof(compact::IncrementalCompact, typ::DataType, fidx::Int,
924+
@nospecialize(Targ), @nospecialize(farg))
925+
isa(Targ, SSAValue) || return false
926+
927+
Tdef = compact[Targ][:inst]
928+
is_known_call(Tdef, Core.apply_type, compact) || return false
929+
length(Tdef.args) β‰₯ 2 || return false
930+
931+
applyT = argextype(Tdef.args[2], compact)
932+
isa(applyT, Const) || return false
933+
934+
applyT = applyT.val
935+
tvars = Any[]
936+
while isa(applyT, UnionAll)
937+
applyTvar = applyT.var
938+
applyT = applyT.body
939+
push!(tvars, applyTvar)
940+
end
941+
942+
applyT.name === typ.name || return false
943+
fT = fieldtype(applyT, fidx)
944+
idx = findfirst(IsEgal(fT), tvars)
945+
idx === nothing && return false
946+
checkbounds(Bool, Tdef.args, 2+idx) || return false
947+
valarg = Tdef.args[2+idx]
948+
isa(valarg, SSAValue) || return false
949+
valdef = compact[valarg][:inst]
950+
is_known_call(valdef, typeof, compact) || return false
951+
952+
return valdef.args[2] === farg
953+
end
954+
955+
function refine_new_effects!(𝕃ₒ::AbstractLattice, compact::IncrementalCompact, idx::Int, stmt::Expr)
956+
(consistent, effect_free_and_nothrow, nothrow) = new_expr_effect_flags(𝕃ₒ, stmt.args, compact, pattern_match_typeof)
957+
if consistent
958+
compact[SSAValue(idx)][:flag] |= IR_FLAG_CONSISTENT
959+
end
960+
if effect_free_and_nothrow
961+
compact[SSAValue(idx)][:flag] |= IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
962+
elseif nothrow
963+
compact[SSAValue(idx)][:flag] |= IR_FLAG_NOTHROW
964+
end
965+
end
966+
911967
# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
912968
# which can be very large sometimes, and program counters in question are often very sparse
913969
const SPCSet = IdSet{Int}
@@ -1037,6 +1093,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
10371093
lift_comparison!(===, compact, idx, stmt, lifting_cache, 𝕃ₒ)
10381094
elseif is_known_call(stmt, isa, compact)
10391095
lift_comparison!(isa, compact, idx, stmt, lifting_cache, 𝕃ₒ)
1096+
elseif isexpr(stmt, :new) && (compact[SSAValue(idx)][:flag] & IR_FLAG_NOTHROW) == 0x00
1097+
refine_new_effects!(𝕃ₒ, compact, idx, stmt)
10401098
end
10411099
continue
10421100
end

β€Žtest/compiler/irpasses.jlβ€Ž

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,3 +1355,19 @@ let src = code_typed1(mut50285, Tuple{Bool, Int, Float64})
13551355
@test count(isnew, src.code) == 0
13561356
@test count(iscall((src, typeassert)), src.code) == 0
13571357
end
1358+
1359+
# Test that we can eliminate new{typeof(x)}(x)
1360+
struct TParamTypeofTest1{T}
1361+
x::T
1362+
@eval TParamTypeofTest1(x) = $(Expr(:new, :(TParamTypeofTest1{typeof(x)}), :x))
1363+
end
1364+
tparam_typeof_test_elim1(x) = TParamTypeofTest1(x).x
1365+
@test fully_eliminated(tparam_typeof_test_elim1, Tuple{Any})
1366+
1367+
struct TParamTypeofTest2{S,T}
1368+
x::S
1369+
y::T
1370+
@eval TParamTypeofTest2(x, y) = $(Expr(:new, :(TParamTypeofTest2{typeof(x),typeof(y)}), :x, :y))
1371+
end
1372+
tparam_typeof_test_elim2(x, y) = TParamTypeofTest2(x, y).x
1373+
@test fully_eliminated(tparam_typeof_test_elim2, Tuple{Any,Any})

0 commit comments

Comments
Β (0)