Skip to content

Commit d0b15c2

Browse files
authored
lattice: Thread through lattice argument for getfield_tfunc (#47097)
Like `tuple`, `getfield` needs some lattice awareness to give the correct answer in the presence of extended lattices. Refactor to split and thread through the lattice argument through _getfield_tfunc so external lattices can provide `getfield` tfuncs for their custom elements.
1 parent 25e3809 commit d0b15c2

File tree

4 files changed

+76
-35
lines changed

4 files changed

+76
-35
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,13 +1326,13 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
13261326
if !isa(stateordonet_widened, DataType) || !(stateordonet_widened <: Tuple) || isvatuple(stateordonet_widened) || length(stateordonet_widened.parameters) != 2
13271327
break
13281328
end
1329-
nstatetype = getfield_tfunc(stateordonet, Const(2))
1329+
nstatetype = getfield_tfunc(typeinf_lattice(interp), stateordonet, Const(2))
13301330
# If there's no new information in this statetype, don't bother continuing,
13311331
# the iterator won't be finite.
13321332
if (typeinf_lattice(interp), nstatetype, statetype)
13331333
return Any[Bottom], nothing
13341334
end
1335-
valtype = getfield_tfunc(stateordonet, Const(1))
1335+
valtype = getfield_tfunc(typeinf_lattice(interp), stateordonet, Const(1))
13361336
push!(ret, valtype)
13371337
statetype = nstatetype
13381338
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv)

base/compiler/tfuncs.jl

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -854,25 +854,33 @@ function getfield_nothrow(@nospecialize(s00), @nospecialize(name), boundscheck::
854854
return false
855855
end
856856

857-
function getfield_tfunc(s00, name, boundscheck_or_order)
858-
@nospecialize
857+
function getfield_tfunc(@specialize(lattice::AbstractLattice), @nospecialize(s00),
858+
@nospecialize(name), @nospecialize(boundscheck_or_order))
859859
t = isvarargtype(boundscheck_or_order) ? unwrapva(boundscheck_or_order) :
860860
widenconst(boundscheck_or_order)
861861
hasintersect(t, Symbol) || hasintersect(t, Bool) || return Bottom
862-
return getfield_tfunc(s00, name)
862+
return getfield_tfunc(lattice, s00, name)
863863
end
864-
function getfield_tfunc(s00, name, order, boundscheck)
865-
@nospecialize
864+
function getfield_tfunc(@nospecialize(s00), name, boundscheck_or_order)
865+
return getfield_tfunc(fallback_lattice, s00, name, boundscheck_or_order)
866+
end
867+
function getfield_tfunc(@specialize(lattice::AbstractLattice), @nospecialize(s00),
868+
@nospecialize(name), @nospecialize(order), @nospecialize(boundscheck))
866869
hasintersect(widenconst(order), Symbol) || return Bottom
867870
if isvarargtype(boundscheck)
868871
t = unwrapva(boundscheck)
869872
hasintersect(t, Symbol) || hasintersect(t, Bool) || return Bottom
870873
else
871874
hasintersect(widenconst(boundscheck), Bool) || return Bottom
872875
end
873-
return getfield_tfunc(s00, name)
876+
return getfield_tfunc(lattice, s00, name)
874877
end
875-
getfield_tfunc(@nospecialize(s00), @nospecialize(name)) = _getfield_tfunc(s00, name, false)
878+
function getfield_tfunc(@nospecialize(s00), @nospecialize(name), @nospecialize(order), @nospecialize(boundscheck))
879+
return getfield_tfunc(fallback_lattice, s00, name, order, boundscheck)
880+
end
881+
getfield_tfunc(@nospecialize(s00), @nospecialize(name)) = _getfield_tfunc(fallback_lattice, s00, name, false)
882+
getfield_tfunc(@specialize(lattice::AbstractLattice), @nospecialize(s00), @nospecialize(name)) = _getfield_tfunc(lattice, s00, name, false)
883+
876884

877885
function _getfield_fieldindex(@nospecialize(s), name::Const)
878886
nv = name.val
@@ -902,10 +910,46 @@ function _getfield_tfunc_const(@nospecialize(sv), name::Const, setfield::Bool)
902910
return nothing
903911
end
904912

905-
function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool)
906-
if isa(s00, Conditional)
913+
function _getfield_tfunc(@specialize(lattice::InferenceLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
914+
if isa(s00, LimitedAccuracy)
915+
# This will error, but it's better than duplicating the error here
916+
s00 = widenconst(s00)
917+
end
918+
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
919+
end
920+
921+
function _getfield_tfunc(@specialize(lattice::OptimizerLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
922+
# If undef, that's a Union, but that doesn't affect the rt when tmerged
923+
# into the unwrapped result.
924+
isa(s00, MaybeUndef) && (s00 = s00.typ)
925+
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
926+
end
927+
928+
function _getfield_tfunc(@specialize(lattice::AnyConditionalsLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
929+
if isa(s00, AnyConditional)
907930
return Bottom # Bool has no fields
908-
elseif isa(s00, Const)
931+
end
932+
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
933+
end
934+
935+
function _getfield_tfunc(@specialize(lattice::PartialsLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
936+
if isa(s00, PartialStruct)
937+
s = widenconst(s00)
938+
sty = unwrap_unionall(s)::DataType
939+
if isa(name, Const)
940+
nv = _getfield_fieldindex(sty, name)
941+
if isa(nv, Int) && 1 <= nv <= length(s00.fields)
942+
return unwrapva(s00.fields[nv])
943+
end
944+
end
945+
s00 = s
946+
end
947+
948+
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
949+
end
950+
951+
function _getfield_tfunc(lattice::ConstsLattice, @nospecialize(s00), @nospecialize(name), setfield::Bool)
952+
if isa(s00, Const)
909953
sv = s00.val
910954
if isa(name, Const)
911955
nv = name.val
@@ -919,30 +963,24 @@ function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool
919963
r = _getfield_tfunc_const(sv, name, setfield)
920964
r !== nothing && return r
921965
end
922-
s = typeof(sv)
923-
elseif isa(s00, PartialStruct)
924-
s = widenconst(s00)
925-
sty = unwrap_unionall(s)::DataType
926-
if isa(name, Const)
927-
nv = _getfield_fieldindex(sty, name)
928-
if isa(nv, Int) && 1 <= nv <= length(s00.fields)
929-
return unwrapva(s00.fields[nv])
930-
end
931-
end
932-
else
933-
s = unwrap_unionall(s00)
966+
s00 = widenconst(s00)
934967
end
968+
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
969+
end
970+
971+
function _getfield_tfunc(lattice::JLTypeLattice, @nospecialize(s00), @nospecialize(name), setfield::Bool)
972+
s = unwrap_unionall(s00)
935973
if isa(s, Union)
936-
return tmerge(_getfield_tfunc(rewrap_unionall(s.a, s00), name, setfield),
937-
_getfield_tfunc(rewrap_unionall(s.b, s00), name, setfield))
974+
return tmerge(_getfield_tfunc(lattice, rewrap_unionall(s.a, s00), name, setfield),
975+
_getfield_tfunc(lattice, rewrap_unionall(s.b, s00), name, setfield))
938976
end
939977
if isType(s)
940978
if isconstType(s)
941979
sv = s00.parameters[1]
942-
if isa(name, Const)
980+
if isa(name, Const)
943981
r = _getfield_tfunc_const(sv, name, setfield)
944982
r !== nothing && return r
945-
end
983+
end
946984
s = typeof(sv)
947985
else
948986
sv = s.parameters[1]
@@ -982,7 +1020,7 @@ function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool
9821020
if !(_ts <: Tuple)
9831021
return Any
9841022
end
985-
return _getfield_tfunc(_ts, name, setfield)
1023+
return _getfield_tfunc(lattice, _ts, name, setfield)
9861024
end
9871025
ftypes = datatype_fieldtypes(s)
9881026
nf = length(ftypes)
@@ -1090,7 +1128,7 @@ end
10901128
function setfield!_tfunc(o, f, v)
10911129
@nospecialize
10921130
mutability_errorcheck(o) || return Bottom
1093-
ft = _getfield_tfunc(o, f, true)
1131+
ft = _getfield_tfunc(fallback_lattice, o, f, true)
10941132
ft === Bottom && return Bottom
10951133
hasintersect(widenconst(v), widenconst(ft)) || return Bottom
10961134
return v
@@ -1168,7 +1206,7 @@ function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any
11681206
# as well as compute the info for the method matches
11691207
op = unwrapva(argtypes[4])
11701208
v = unwrapva(argtypes[5])
1171-
TF = getfield_tfunc(o, f)
1209+
TF = getfield_tfunc(typeinf_lattice(interp), o, f)
11721210
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), StmtInfo(true), sv, #=max_methods=# 1)
11731211
TF2 = tmeet(callinfo.rt, widenconst(TF))
11741212
if TF2 === Bottom
@@ -2118,6 +2156,9 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
21182156
# wrong # of args
21192157
return Bottom
21202158
end
2159+
if f === getfield
2160+
return getfield_tfunc(typeinf_lattice(interp), argtypes...)
2161+
end
21212162
return tf[3](argtypes...)
21222163
end
21232164

base/compiler/typelattice.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ function tmeet(lattice::PartialsLattice, @nospecialize(v), @nospecialize(t::Type
419419
if isvarargtype(vfi)
420420
new_fields[i] = vfi
421421
else
422-
new_fields[i] = tmeet(lattice, vfi, widenconst(getfield_tfunc(t, Const(i))))
422+
new_fields[i] = tmeet(lattice, vfi, widenconst(getfield_tfunc(lattice, t, Const(i))))
423423
if new_fields[i] === Bottom
424424
return Bottom
425425
end

base/compiler/typelimits.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ function issimplertype(lattice::AbstractLattice, @nospecialize(typea), @nospecia
321321
bi = (tni.val::Core.TypeName).wrapper
322322
is_lattice_equal(lattice, ai, bi) && continue
323323
end
324-
bi = getfield_tfunc(typeb, Const(i))
324+
bi = getfield_tfunc(lattice, typeb, Const(i))
325325
is_lattice_equal(lattice, ai, bi) && continue
326326
# It is not enough for ai to be simpler than bi: it must exactly equal
327327
# (for this, an invariant struct field, by contrast to
@@ -490,8 +490,8 @@ function tmerge(lattice::PartialsLattice, @nospecialize(typea), @nospecialize(ty
490490
fields = Vector{Any}(undef, type_nfields)
491491
anyrefine = false
492492
for i = 1:type_nfields
493-
ai = getfield_tfunc(typea, Const(i))
494-
bi = getfield_tfunc(typeb, Const(i))
493+
ai = getfield_tfunc(lattice, typea, Const(i))
494+
bi = getfield_tfunc(lattice, typeb, Const(i))
495495
ft = fieldtype(aty, i)
496496
if is_lattice_equal(lattice, ai, bi) || is_lattice_equal(lattice, ai, ft)
497497
# Since ai===bi, the given type has no restrictions on complexity.

0 commit comments

Comments
 (0)