Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subtyping and inference of user defined variadic types #16076

Merged
merged 14 commits into from
Sep 13, 2023
231 changes: 109 additions & 122 deletions mypy/constraints.py

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,16 @@ def visit_deleted_type(self, t: DeletedType) -> ProperType:
return t

def visit_instance(self, t: Instance) -> ProperType:
return Instance(t.type, [AnyType(TypeOfAny.special_form)] * len(t.args), t.line)
args: list[Type] = []
for tv in t.type.defn.type_vars:
# Valid erasure for *Ts is *tuple[Any, ...], not just Any.
if isinstance(tv, TypeVarTupleType):
args.append(
tv.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])
)
else:
args.append(AnyType(TypeOfAny.special_form))
return Instance(t.type, args, t.line)

def visit_type_var(self, t: TypeVarType) -> ProperType:
return AnyType(TypeOfAny.special_form)
Expand Down
3 changes: 2 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
variables=[*t.prefix.variables, *repl.variables],
)
else:
# TODO: replace this with "assert False"
# We could encode Any as trivial parameters etc., but it would be too verbose.
# TODO: assert this is a trivial type, like Any, Never, or object.
return repl

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
Expand Down
17 changes: 9 additions & 8 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,17 @@ def visit_type_info(self, info: TypeInfo) -> None:
info.update_tuple_type(info.tuple_type)
if info.special_alias:
info.special_alias.alias_tvars = list(info.defn.type_vars)
for i, t in enumerate(info.defn.type_vars):
if isinstance(t, TypeVarTupleType):
info.special_alias.tvar_tuple_index = i
if info.typeddict_type:
info.typeddict_type.accept(self.type_fixer)
info.update_typeddict_type(info.typeddict_type)
if info.special_alias:
info.special_alias.alias_tvars = list(info.defn.type_vars)
for i, t in enumerate(info.defn.type_vars):
if isinstance(t, TypeVarTupleType):
info.special_alias.tvar_tuple_index = i
if info.declared_metaclass:
info.declared_metaclass.accept(self.type_fixer)
if info.metaclass_type:
Expand Down Expand Up @@ -166,11 +172,7 @@ def visit_decorator(self, d: Decorator) -> None:

def visit_class_def(self, c: ClassDef) -> None:
for v in c.type_vars:
if isinstance(v, TypeVarType):
for value in v.values:
value.accept(self.type_fixer)
v.upper_bound.accept(self.type_fixer)
v.default.accept(self.type_fixer)
v.accept(self.type_fixer)

def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
for value in tv.values:
Expand All @@ -184,6 +186,7 @@ def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:

def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
tv.upper_bound.accept(self.type_fixer)
tv.tuple_fallback.accept(self.type_fixer)
tv.default.accept(self.type_fixer)

def visit_var(self, v: Var) -> None:
Expand Down Expand Up @@ -314,6 +317,7 @@ def visit_param_spec(self, p: ParamSpecType) -> None:
p.default.accept(self)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.tuple_fallback.accept(self)
t.upper_bound.accept(self)
t.default.accept(self)

Expand All @@ -336,9 +340,6 @@ def visit_union_type(self, ut: UnionType) -> None:
for it in ut.items:
it.accept(self)

def visit_void(self, o: Any) -> None:
pass # Nothing to descend into.

def visit_type_type(self, t: TypeType) -> None:
t.item.accept(self)

Expand Down
148 changes: 142 additions & 6 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
UninhabitedType,
UnionType,
UnpackType,
find_unpack_in_list,
get_proper_type,
get_proper_types,
split_with_prefix_and_suffix,
)


Expand All @@ -67,7 +69,25 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
args: list[Type] = []
# N.B: We use zip instead of indexing because the lengths might have
# mismatches during daemon reprocessing.
for ta, sa, type_var in zip(t.args, s.args, t.type.defn.type_vars):
if t.type.has_type_var_tuple_type:
# We handle joins of variadic instances by simply creating correct mapping
# for type arguments and compute the individual joins same as for regular
# instances. All the heavy lifting is done in the join of tuple types.
assert s.type.type_var_tuple_prefix is not None
assert s.type.type_var_tuple_suffix is not None
prefix = s.type.type_var_tuple_prefix
suffix = s.type.type_var_tuple_suffix
tvt = s.type.defn.type_vars[prefix]
assert isinstance(tvt, TypeVarTupleType)
fallback = tvt.tuple_fallback
s_prefix, s_middle, s_suffix = split_with_prefix_and_suffix(s.args, prefix, suffix)
t_prefix, t_middle, t_suffix = split_with_prefix_and_suffix(t.args, prefix, suffix)
s_args = s_prefix + (TupleType(list(s_middle), fallback),) + s_suffix
t_args = t_prefix + (TupleType(list(t_middle), fallback),) + t_suffix
else:
t_args = t.args
s_args = s.args
for ta, sa, type_var in zip(t_args, s_args, t.type.defn.type_vars):
ta_proper = get_proper_type(ta)
sa_proper = get_proper_type(sa)
new_type: Type | None = None
Expand All @@ -93,6 +113,18 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
# If the types are different but equivalent, then an Any is involved
# so using a join in the contravariant case is also OK.
new_type = join_types(ta, sa, self)
elif isinstance(type_var, TypeVarTupleType):
new_type = get_proper_type(join_types(ta, sa, self))
# Put the joined arguments back into instance in the normal form:
# a) Tuple[X, Y, Z] -> [X, Y, Z]
# b) tuple[X, ...] -> [*tuple[X, ...]]
if isinstance(new_type, Instance):
assert new_type.type.fullname == "builtins.tuple"
new_type = UnpackType(new_type)
else:
assert isinstance(new_type, TupleType)
args.extend(new_type.items)
continue
else:
# ParamSpec type variables behave the same, independent of variance
if not is_equivalent(ta, sa):
Expand Down Expand Up @@ -440,6 +472,112 @@ def visit_overloaded(self, t: Overloaded) -> ProperType:
return join_types(t, call)
return join_types(t.fallback, s)

def join_tuples(self, s: TupleType, t: TupleType) -> list[Type] | None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about adding more unit tests for join_tuples? This is pretty tricky but should be easy to unit test. This can happen in a follow-up PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a good idea after all. Tests caught a silly bug (missing early return in one place), and also gave me an idea of how to handle some more edge cases essentially for free.

"""Join two tuple types while handling variadic entries.

This is surprisingly tricky, and we don't handle some tricky corner cases.
Most of the trickiness comes from the variadic tuple items like *tuple[X, ...]
since they can have arbitrary partial overlaps (while *Ts can't be split).
"""
s_unpack_index = find_unpack_in_list(s.items)
t_unpack_index = find_unpack_in_list(t.items)
if s_unpack_index is None and t_unpack_index is None:
if s.length() == t.length():
items: list[Type] = []
for i in range(t.length()):
items.append(join_types(t.items[i], s.items[i]))
return items
return None
if s_unpack_index is not None and t_unpack_index is not None:
# The most complex case: both tuples have an upack item.
s_unpack = s.items[s_unpack_index]
assert isinstance(s_unpack, UnpackType)
s_unpacked = get_proper_type(s_unpack.type)
t_unpack = t.items[t_unpack_index]
assert isinstance(t_unpack, UnpackType)
t_unpacked = get_proper_type(t_unpack.type)
if s.length() == t.length() and s_unpack_index == t_unpack_index:
# We can handle a case where arity is perfectly aligned, e.g.
# join(Tuple[X1, *tuple[Y1, ...], Z1], Tuple[X2, *tuple[Y2, ...], Z2]).
# We can essentially perform the join elementwise.
prefix_len = t_unpack_index
suffix_len = t.length() - t_unpack_index - 1
items = []
for si, ti in zip(s.items[:prefix_len], t.items[:prefix_len]):
items.append(join_types(si, ti))
joined = join_types(s_unpacked, t_unpacked)
if isinstance(joined, TypeVarTupleType):
items.append(UnpackType(joined))
elif isinstance(joined, Instance) and joined.type.fullname == "builtins.tuple":
items.append(UnpackType(joined))
else:
if isinstance(t_unpacked, Instance):
assert t_unpacked.type.fullname == "builtins.tuple"
tuple_instance = t_unpacked
else:
assert isinstance(t_unpacked, TypeVarTupleType)
tuple_instance = t_unpacked.tuple_fallback
items.append(
UnpackType(
tuple_instance.copy_modified(
args=[object_from_instance(tuple_instance)]
)
)
)
if suffix_len:
for si, ti in zip(s.items[-suffix_len:], t.items[-suffix_len:]):
items.append(join_types(si, ti))
if s.length() == 1 or t.length() == 1:
# Another case we can handle is when one of tuple is purely variadic
# (i.e. a non-normalized form of tuple[X, ...]), in this case the join
# will be again purely variadic.
if not (isinstance(s_unpacked, Instance) and isinstance(t_unpacked, Instance)):
return None
assert s_unpacked.type.fullname == "builtins.tuple"
assert t_unpacked.type.fullname == "builtins.tuple"
mid_joined = join_types(s_unpacked.args[0], t_unpacked.args[0])
t_other = [a for i, a in enumerate(t.items) if i != t_unpack_index]
s_other = [a for i, a in enumerate(s.items) if i != s_unpack_index]
other_joined = join_type_list(s_other + t_other)
mid_joined = join_types(mid_joined, other_joined)
return [UnpackType(s_unpacked.copy_modified(args=[mid_joined]))]
# TODO: are there other case we can handle?
return None
if s_unpack_index is not None:
variadic = s
unpack_index = s_unpack_index
fixed = t
else:
assert t_unpack_index is not None
variadic = t
unpack_index = t_unpack_index
fixed = s
# Case where one tuple has variadic item and the other one doesn't. The join will
# be variadic, since fixed tuple is a subtype of variadic, but not vice versa.
unpack = variadic.items[unpack_index]
assert isinstance(unpack, UnpackType)
unpacked = get_proper_type(unpack.type)
if not isinstance(unpacked, Instance):
return None
if fixed.length() < variadic.length() - 1:
# There are no non-trivial types that are supertype of both.
return None
prefix_len = unpack_index
suffix_len = variadic.length() - prefix_len - 1
prefix, middle, suffix = split_with_prefix_and_suffix(
tuple(fixed.items), prefix_len, suffix_len
)
items = []
for fi, vi in zip(prefix, variadic.items[:prefix_len]):
items.append(join_types(fi, vi))
mid_joined = join_type_list(list(middle))
mid_joined = join_types(mid_joined, unpacked.args[0])
items.append(UnpackType(unpacked.copy_modified(args=[mid_joined])))
if suffix_len:
for fi, vi in zip(suffix, variadic.items[-suffix_len:]):
items.append(join_types(fi, vi))
return items

def visit_tuple_type(self, t: TupleType) -> ProperType:
# When given two fixed-length tuples:
# * If they have the same length, join their subtypes item-wise:
Expand All @@ -452,17 +590,15 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
# Tuple[int, bool] + Tuple[bool, ...] becomes Tuple[int, ...]
# * Joining with any Sequence also returns a Sequence:
# Tuple[int, bool] + List[bool] becomes Sequence[int]
if isinstance(self.s, TupleType) and self.s.length() == t.length():
if isinstance(self.s, TupleType):
if self.instance_joiner is None:
self.instance_joiner = InstanceJoiner()
fallback = self.instance_joiner.join_instances(
mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t)
)
assert isinstance(fallback, Instance)
if self.s.length() == t.length():
items: list[Type] = []
for i in range(t.length()):
items.append(join_types(t.items[i], self.s.items[i]))
items = self.join_tuples(self.s, t)
if items is not None:
return TupleType(items, fallback)
else:
return fallback
Expand Down
Loading