Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TypeAliasType #16926

Merged
merged 18 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 116 additions & 16 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

from contextlib import contextmanager
from typing import Any, Callable, Collection, Final, Iterable, Iterator, List, TypeVar, cast
from typing_extensions import TypeAlias as _TypeAlias
from typing_extensions import TypeAlias as _TypeAlias, TypeGuard

from mypy import errorcodes as codes, message_registry
from mypy.constant_fold import constant_fold_expr
Expand Down Expand Up @@ -2018,34 +2018,35 @@ def analyze_class_typevar_declaration(self, base: Type) -> tuple[TypeVarLikeList

def analyze_unbound_tvar(self, t: Type) -> tuple[str, TypeVarLikeExpr] | None:
if isinstance(t, UnpackType) and isinstance(t.type, UnboundType):
return self.analyze_unbound_tvar_impl(t.type, allow_tvt=True)
return self.analyze_unbound_tvar_impl(t.type, is_unpacked=True)
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
if sym and sym.fullname in ("typing.Unpack", "typing_extensions.Unpack"):
inner_t = t.args[0]
if isinstance(inner_t, UnboundType):
return self.analyze_unbound_tvar_impl(inner_t, allow_tvt=True)
return self.analyze_unbound_tvar_impl(inner_t, is_unpacked=True)
return None
return self.analyze_unbound_tvar_impl(t)
return None

def analyze_unbound_tvar_impl(
self, t: UnboundType, allow_tvt: bool = False
self, t: UnboundType, is_unpacked: bool = False, is_typealias_param: bool = False
) -> tuple[str, TypeVarLikeExpr] | None:
assert not is_unpacked or not is_typealias_param, "Mutually exclusive conditions"
sym = self.lookup_qualified(t.name, t)
if sym and isinstance(sym.node, PlaceholderNode):
self.record_incomplete_ref()
if not allow_tvt and sym and isinstance(sym.node, ParamSpecExpr):
if not is_unpacked and sym and isinstance(sym.node, ParamSpecExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
return t.name, sym.node
if allow_tvt and sym and isinstance(sym.node, TypeVarTupleExpr):
if (is_unpacked or is_typealias_param) and sym and isinstance(sym.node, TypeVarTupleExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
return t.name, sym.node
if sym is None or not isinstance(sym.node, TypeVarExpr) or allow_tvt:
if sym is None or not isinstance(sym.node, TypeVarExpr) or is_unpacked:
return None
elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
Expand Down Expand Up @@ -3515,7 +3516,11 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Typ
return typ

def analyze_alias(
self, name: str, rvalue: Expression, allow_placeholder: bool = False
self,
name: str,
rvalue: Expression,
allow_placeholder: bool = False,
declared_type_vars: TypeVarLikeList | None = None,
) -> tuple[Type | None, list[TypeVarLikeType], set[str], list[str], bool]:
"""Check if 'rvalue' is a valid type allowed for aliasing (e.g. not a type variable).

Expand All @@ -3540,9 +3545,10 @@ def analyze_alias(
found_type_vars = self.find_type_var_likes(typ)
tvar_defs: list[TypeVarLikeType] = []
namespace = self.qualified_name(name)
alias_type_vars = found_type_vars if declared_type_vars is None else declared_type_vars
last_tvar_name_with_default: str | None = None
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):
for name, tvar_expr in found_type_vars:
for name, tvar_expr in alias_type_vars:
tvar_expr.default = tvar_expr.default.accept(
TypeVarDefaultTranslator(self, tvar_expr.name, typ)
)
Expand All @@ -3567,6 +3573,7 @@ def analyze_alias(
in_dynamic_func=dynamic,
global_scope=global_scope,
allowed_alias_tvars=tvar_defs,
has_type_params=declared_type_vars is not None,
)

# There can be only one variadic variable at most, the error is reported elsewhere.
Expand All @@ -3579,7 +3586,7 @@ def analyze_alias(
variadic = True
new_tvar_defs.append(td)

qualified_tvars = [node.fullname for _name, node in found_type_vars]
qualified_tvars = [node.fullname for _name, node in alias_type_vars]
empty_tuple_index = typ.empty_tuple_index if isinstance(typ, UnboundType) else False
return analyzed, new_tvar_defs, depends_on, qualified_tvars, empty_tuple_index

Expand Down Expand Up @@ -3612,7 +3619,19 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# unless using PEP 613 `cls: TypeAlias = A`
return False

if isinstance(s.rvalue, CallExpr) and s.rvalue.analyzed:
# It can be `A = TypeAliasType('A', ...)` call, in this case,
# we just take the second argument and analyze it:
type_params: TypeVarLikeList | None
if self.check_type_alias_type_call(s.rvalue, name=lvalue.name):
rvalue = s.rvalue.args[1]
pep_695 = True
type_params = self.analyze_type_alias_type_params(s.rvalue)
else:
rvalue = s.rvalue
pep_695 = False
type_params = None

if isinstance(rvalue, CallExpr) and rvalue.analyzed:
return False

existing = self.current_symbol_table().get(lvalue.name)
Expand All @@ -3638,7 +3657,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
return False

non_global_scope = self.type or self.is_func_scope()
if not pep_613 and isinstance(s.rvalue, RefExpr) and non_global_scope:
if not pep_613 and isinstance(rvalue, RefExpr) and non_global_scope:
# Fourth rule (special case): Non-subscripted right hand side creates a variable
# at class and function scopes. For example:
#
Expand All @@ -3650,8 +3669,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# without this rule, this typical use case will require a lot of explicit
# annotations (see the second rule).
return False
rvalue = s.rvalue
if not pep_613 and not self.can_be_type_alias(rvalue):
if not pep_613 and not pep_695 and not self.can_be_type_alias(rvalue):
return False

if existing and not isinstance(existing.node, (PlaceholderNode, TypeAlias)):
Expand All @@ -3668,7 +3686,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
else:
tag = self.track_incomplete_refs()
res, alias_tvars, depends_on, qualified_tvars, empty_tuple_index = self.analyze_alias(
lvalue.name, rvalue, allow_placeholder=True
lvalue.name, rvalue, allow_placeholder=True, declared_type_vars=type_params
)
if not res:
return False
Expand Down Expand Up @@ -3698,13 +3716,15 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# so we need to replace it with non-explicit Anys.
res = make_any_non_explicit(res)
# Note: with the new (lazy) type alias representation we only need to set no_args to True
# if the expected number of arguments is non-zero, so that aliases like A = List work.
# if the expected number of arguments is non-zero, so that aliases like `A = List` work
# but not aliases like `A = TypeAliasType("A", List)` as these need explicit type params.
# However, eagerly expanding aliases like Text = str is a nice performance optimization.
no_args = (
isinstance(res, ProperType)
and isinstance(res, Instance)
and not res.args
and not empty_tuple_index
and not pep_695
)
if isinstance(res, ProperType) and isinstance(res, Instance):
if not validate_instance(res, self.fail, empty_tuple_index):
Expand Down Expand Up @@ -3771,6 +3791,80 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
self.note("Use variable annotation syntax to define protocol members", s)
return True

def check_type_alias_type_call(self, rvalue: Expression, *, name: str) -> TypeGuard[CallExpr]:
if not isinstance(rvalue, CallExpr):
return False

names = ["typing_extensions.TypeAliasType"]
if self.options.python_version >= (3, 12):
names.append("typing.TypeAliasType")
if not refers_to_fullname(rvalue.callee, tuple(names)):
return False

return self.check_typevarlike_name(rvalue, name, rvalue)

def analyze_type_alias_type_params(self, rvalue: CallExpr) -> TypeVarLikeList:
if "type_params" in rvalue.arg_names:
type_params_arg = rvalue.args[rvalue.arg_names.index("type_params")]
if not isinstance(type_params_arg, TupleExpr):
self.fail(
"Tuple literal expected as the type_params argument to TypeAliasType",
type_params_arg,
)
return []
type_params = type_params_arg.items
else:
type_params = []

declared_tvars: TypeVarLikeList = []
have_type_var_tuple = False
for tp_expr in type_params:
if isinstance(tp_expr, StarExpr):
tp_expr.valid = False
self.analyze_type_expr(tp_expr)
try:
base = self.expr_to_unanalyzed_type(tp_expr)
except TypeTranslationError:
continue
if not isinstance(base, UnboundType):
continue

tag = self.track_incomplete_refs()
tvar = self.analyze_unbound_tvar_impl(base, is_typealias_param=True)
if tvar:
if isinstance(tvar[1], TypeVarTupleExpr):
if have_type_var_tuple:
self.fail(
"Can only use one TypeVarTuple in type_params argument to TypeAliasType",
base,
code=codes.TYPE_VAR,
)
have_type_var_tuple = True
continue
have_type_var_tuple = True
elif not self.found_incomplete_ref(tag):
self.fail(
"Free type variable expected in type_params argument to TypeAliasType",
base,
code=codes.TYPE_VAR,
)
sym = self.lookup_qualified(base.name, base)
if sym and sym.fullname in ("typing.Unpack", "typing_extensions.Unpack"):
self.note(
"Don't Unpack type variables in type_params", base, code=codes.TYPE_VAR
)
continue
if tvar in declared_tvars:
self.fail(
f'Duplicate type variable "{tvar[0]}" in type_params argument to TypeAliasType',
base,
code=codes.TYPE_VAR,
)
continue
if tvar:
declared_tvars.append(tvar)
return declared_tvars

def disable_invalid_recursive_aliases(
self, s: AssignmentStmt, current_node: TypeAlias
) -> None:
Expand Down Expand Up @@ -5187,6 +5281,12 @@ def visit_call_expr(self, expr: CallExpr) -> None:
expr.analyzed = OpExpr("divmod", expr.args[0], expr.args[1])
expr.analyzed.line = expr.line
expr.analyzed.accept(self)
elif refers_to_fullname(
expr.callee, ("typing.TypeAliasType", "typing_extensions.TypeAliasType")
):
with self.allow_unbound_tvars_set():
for a in expr.args:
a.accept(self)
else:
# Normal call expression.
for a in expr.args:
Expand Down
49 changes: 37 additions & 12 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def analyze_type_alias(
in_dynamic_func: bool = False,
global_scope: bool = True,
allowed_alias_tvars: list[TypeVarLikeType] | None = None,
has_type_params: bool = False,
) -> tuple[Type, set[str]]:
"""Analyze r.h.s. of a (potential) type alias definition.

Expand All @@ -158,6 +159,7 @@ def analyze_type_alias(
allow_placeholder=allow_placeholder,
prohibit_self_type="type alias target",
allowed_alias_tvars=allowed_alias_tvars,
has_type_params=has_type_params,
)
analyzer.in_dynamic_func = in_dynamic_func
analyzer.global_scope = global_scope
Expand Down Expand Up @@ -210,6 +212,7 @@ def __init__(
prohibit_self_type: str | None = None,
allowed_alias_tvars: list[TypeVarLikeType] | None = None,
allow_type_any: bool = False,
has_type_params: bool = False,
) -> None:
self.api = api
self.fail_func = api.fail
Expand All @@ -231,6 +234,7 @@ def __init__(
if allowed_alias_tvars is None:
allowed_alias_tvars = []
self.allowed_alias_tvars = allowed_alias_tvars
self.has_type_params = has_type_params
# If false, record incomplete ref if we generate PlaceholderType.
self.allow_placeholder = allow_placeholder
# Are we in a context where Required[] is allowed?
Expand Down Expand Up @@ -325,7 +329,11 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
if tvar_def is None:
if self.allow_unbound_tvars:
return t
self.fail(f'ParamSpec "{t.name}" is unbound', t, code=codes.VALID_TYPE)
if self.defining_alias and self.has_type_params:
msg = f'ParamSpec "{t.name}" is not included in type_params'
else:
msg = f'ParamSpec "{t.name}" is unbound'
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
assert isinstance(tvar_def, ParamSpecType)
if len(t.args) > 0:
Expand All @@ -349,11 +357,11 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
and not defining_literal
and (tvar_def is None or tvar_def not in self.allowed_alias_tvars)
):
self.fail(
f'Can\'t use bound type variable "{t.name}" to define generic alias',
t,
code=codes.VALID_TYPE,
)
if self.has_type_params:
msg = f'Type variable "{t.name}" is not included in type_params'
else:
msg = f'Can\'t use bound type variable "{t.name}" to define generic alias'
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
if isinstance(sym.node, TypeVarExpr) and tvar_def is not None:
assert isinstance(tvar_def, TypeVarType)
Expand All @@ -368,17 +376,21 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
and self.defining_alias
and tvar_def not in self.allowed_alias_tvars
):
self.fail(
f'Can\'t use bound type variable "{t.name}" to define generic alias',
t,
code=codes.VALID_TYPE,
)
if self.has_type_params:
msg = f'Type variable "{t.name}" is not included in type_params'
else:
msg = f'Can\'t use bound type variable "{t.name}" to define generic alias'
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
if isinstance(sym.node, TypeVarTupleExpr):
if tvar_def is None:
if self.allow_unbound_tvars:
return t
self.fail(f'TypeVarTuple "{t.name}" is unbound', t, code=codes.VALID_TYPE)
if self.defining_alias and self.has_type_params:
msg = f'TypeVarTuple "{t.name}" is not included in type_params'
else:
msg = f'TypeVarTuple "{t.name}" is unbound'
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
assert isinstance(tvar_def, TypeVarTupleType)
if not self.allow_type_var_tuple:
Expand Down Expand Up @@ -1267,6 +1279,19 @@ def analyze_callable_args_for_paramspec(
AnyType(TypeOfAny.explicit), ret_type=ret_type, fallback=fallback
)
return None
elif (
self.defining_alias
and self.has_type_params
and tvar_def not in self.allowed_alias_tvars
):
self.fail(
f'ParamSpec "{callable_args.name}" is not included in type_params',
callable_args,
code=codes.VALID_TYPE,
)
return callable_with_ellipsis(
AnyType(TypeOfAny.special_form), ret_type=ret_type, fallback=fallback
)

return CallableType(
[
Expand Down
Loading
Loading