Skip to content

Commit

Permalink
Add initial support for new style TypeVar defaults (PEP 696) (#17985)
Browse files Browse the repository at this point in the history
Add initial support for TypeVar defaults using the new syntax. Similar
to the old syntax, it doesn't fully work yet for ParamSpec, TypeVarTuple
and recursive TypeVar defaults.

Refs: #14851
  • Loading branch information
cdce8p authored Oct 19, 2024
1 parent c9d4c61 commit 603a365
Show file tree
Hide file tree
Showing 10 changed files with 456 additions and 85 deletions.
15 changes: 15 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,7 @@ def check_func_def(
) -> None:
"""Type check a function definition."""
# Expand type variables with value restrictions to ordinary types.
self.check_typevar_defaults(typ.variables)
expanded = self.expand_typevars(defn, typ)
original_typ = typ
for item, typ in expanded:
Expand Down Expand Up @@ -2483,6 +2484,8 @@ def visit_class_def(self, defn: ClassDef) -> None:
context=defn,
code=codes.TYPE_VAR,
)
if typ.defn.type_vars:
self.check_typevar_defaults(typ.defn.type_vars)

if typ.is_protocol and typ.defn.type_vars:
self.check_protocol_variance(defn)
Expand Down Expand Up @@ -2546,6 +2549,15 @@ def check_init_subclass(self, defn: ClassDef) -> None:
# all other bases have already been checked.
break

def check_typevar_defaults(self, tvars: Sequence[TypeVarLikeType]) -> None:
for tv in tvars:
if not (isinstance(tv, TypeVarType) and tv.has_default()):
continue
if not is_subtype(tv.default, tv.upper_bound):
self.fail("TypeVar default must be a subtype of the bound type", tv)
if tv.values and not any(tv.default == value for value in tv.values):
self.fail("TypeVar default must be one of the constraint types", tv)

def check_enum(self, defn: ClassDef) -> None:
assert defn.info.is_enum
if defn.info.fullname not in ENUM_BASES:
Expand Down Expand Up @@ -5365,6 +5377,9 @@ def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var,
del type_map[expr]

def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
if o.alias_node:
self.check_typevar_defaults(o.alias_node.alias_tvars)

with self.msg.filter_errors():
self.expr_checker.accept(o.value)

Expand Down
22 changes: 11 additions & 11 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,19 +1196,17 @@ def validate_type_param(self, type_param: ast_TypeVar) -> None:
def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
explicit_type_params = []
for p in type_params:
bound = None
bound: Type | None = None
values: list[Type] = []
if sys.version_info >= (3, 13) and p.default_value is not None:
self.fail(
message_registry.TYPE_PARAM_DEFAULT_NOT_SUPPORTED,
p.lineno,
p.col_offset,
blocker=False,
)
default: Type | None = None
if sys.version_info >= (3, 13):
default = TypeConverter(self.errors, line=p.lineno).visit(p.default_value)
if isinstance(p, ast_ParamSpec): # type: ignore[misc]
explicit_type_params.append(TypeParam(p.name, PARAM_SPEC_KIND, None, []))
explicit_type_params.append(TypeParam(p.name, PARAM_SPEC_KIND, None, [], default))
elif isinstance(p, ast_TypeVarTuple): # type: ignore[misc]
explicit_type_params.append(TypeParam(p.name, TYPE_VAR_TUPLE_KIND, None, []))
explicit_type_params.append(
TypeParam(p.name, TYPE_VAR_TUPLE_KIND, None, [], default)
)
else:
if isinstance(p.bound, ast3.Tuple):
if len(p.bound.elts) < 2:
Expand All @@ -1224,7 +1222,9 @@ def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
elif p.bound is not None:
self.validate_type_param(p)
bound = TypeConverter(self.errors, line=p.lineno).visit(p.bound)
explicit_type_params.append(TypeParam(p.name, TYPE_VAR_KIND, bound, values))
explicit_type_params.append(
TypeParam(p.name, TYPE_VAR_KIND, bound, values, default)
)
return explicit_type_params

# Return(expr? value)
Expand Down
5 changes: 0 additions & 5 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,3 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
TYPE_ALIAS_WITH_AWAIT_EXPRESSION: Final = ErrorMessage(
"Await expression cannot be used within a type alias", codes.SYNTAX
)

TYPE_PARAM_DEFAULT_NOT_SUPPORTED: Final = ErrorMessage(
"Type parameter default types not supported when using Python 3.12 type parameter syntax",
codes.MISC,
)
10 changes: 7 additions & 3 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,19 +670,21 @@ def set_line(


class TypeParam:
__slots__ = ("name", "kind", "upper_bound", "values")
__slots__ = ("name", "kind", "upper_bound", "values", "default")

def __init__(
self,
name: str,
kind: int,
upper_bound: mypy.types.Type | None,
values: list[mypy.types.Type],
default: mypy.types.Type | None,
) -> None:
self.name = name
self.kind = kind
self.upper_bound = upper_bound
self.values = values
self.default = default


FUNCITEM_FLAGS: Final = FUNCBASE_FLAGS + [
Expand Down Expand Up @@ -782,7 +784,7 @@ class FuncDef(FuncItem, SymbolNode, Statement):
"deco_line",
"is_trivial_body",
"is_mypy_only",
# Present only when a function is decorated with @typing.datasclass_transform or similar
# Present only when a function is decorated with @typing.dataclass_transform or similar
"dataclass_transform_spec",
"docstring",
"deprecated",
Expand Down Expand Up @@ -1657,21 +1659,23 @@ def accept(self, visitor: StatementVisitor[T]) -> T:


class TypeAliasStmt(Statement):
__slots__ = ("name", "type_args", "value", "invalid_recursive_alias")
__slots__ = ("name", "type_args", "value", "invalid_recursive_alias", "alias_node")

__match_args__ = ("name", "type_args", "value")

name: NameExpr
type_args: list[TypeParam]
value: LambdaExpr # Return value will get translated into a type
invalid_recursive_alias: bool
alias_node: TypeAlias | None

def __init__(self, name: NameExpr, type_args: list[TypeParam], value: LambdaExpr) -> None:
super().__init__()
self.name = name
self.type_args = type_args
self.value = value
self.invalid_recursive_alias = False
self.alias_node = None

def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_type_alias_stmt(self)
Expand Down
132 changes: 79 additions & 53 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1808,7 +1808,26 @@ def analyze_type_param(
upper_bound = self.named_type("builtins.tuple", [self.object_type()])
else:
upper_bound = self.object_type()
default = AnyType(TypeOfAny.from_omitted_generics)
if type_param.default:
default = self.anal_type(
type_param.default,
allow_placeholder=True,
allow_unbound_tvars=True,
report_invalid_types=False,
allow_param_spec_literals=type_param.kind == PARAM_SPEC_KIND,
allow_tuple_literal=type_param.kind == PARAM_SPEC_KIND,
allow_unpack=type_param.kind == TYPE_VAR_TUPLE_KIND,
)
if default is None:
default = PlaceholderType(None, [], context.line)
elif type_param.kind == TYPE_VAR_KIND:
default = self.check_typevar_default(default, type_param.default)
elif type_param.kind == PARAM_SPEC_KIND:
default = self.check_paramspec_default(default, type_param.default)
elif type_param.kind == TYPE_VAR_TUPLE_KIND:
default = self.check_typevartuple_default(default, type_param.default)
else:
default = AnyType(TypeOfAny.from_omitted_generics)
if type_param.kind == TYPE_VAR_KIND:
values = []
if type_param.values:
Expand Down Expand Up @@ -2243,21 +2262,7 @@ class Foo(Bar, Generic[T]): ...
# grained incremental mode.
defn.removed_base_type_exprs.append(defn.base_type_exprs[i])
del base_type_exprs[i]
tvar_defs: list[TypeVarLikeType] = []
last_tvar_name_with_default: str | None = None
for name, tvar_expr in declared_tvars:
tvar_expr.default = tvar_expr.default.accept(
TypeVarDefaultTranslator(self, tvar_expr.name, context)
)
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
if last_tvar_name_with_default is not None and not tvar_def.has_default():
self.msg.tvar_without_default_type(
tvar_def.name, last_tvar_name_with_default, context
)
tvar_def.default = AnyType(TypeOfAny.from_error)
elif tvar_def.has_default():
last_tvar_name_with_default = tvar_def.name
tvar_defs.append(tvar_def)
tvar_defs = self.tvar_defs_from_tvars(declared_tvars, context)
return base_type_exprs, tvar_defs, is_protocol

def analyze_class_typevar_declaration(self, base: Type) -> tuple[TypeVarLikeList, bool] | None:
Expand Down Expand Up @@ -2358,6 +2363,26 @@ def get_all_bases_tvars(
tvars.extend(base_tvars)
return remove_dups(tvars)

def tvar_defs_from_tvars(
self, tvars: TypeVarLikeList, context: Context
) -> list[TypeVarLikeType]:
tvar_defs: list[TypeVarLikeType] = []
last_tvar_name_with_default: str | None = None
for name, tvar_expr in tvars:
tvar_expr.default = tvar_expr.default.accept(
TypeVarDefaultTranslator(self, tvar_expr.name, context)
)
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
if last_tvar_name_with_default is not None and not tvar_def.has_default():
self.msg.tvar_without_default_type(
tvar_def.name, last_tvar_name_with_default, context
)
tvar_def.default = AnyType(TypeOfAny.from_error)
elif tvar_def.has_default():
last_tvar_name_with_default = tvar_def.name
tvar_defs.append(tvar_def)
return tvar_defs

def get_and_bind_all_tvars(self, type_exprs: list[Expression]) -> list[TypeVarLikeType]:
"""Return all type variable references in item type expressions.
Expand Down Expand Up @@ -3833,21 +3858,8 @@ def analyze_alias(
tvar_defs: list[TypeVarLikeType] = []
namespace = self.qualified_name(name)
alias_type_vars = found_type_vars if declared_type_vars is None else declared_type_vars
last_tvar_name_with_default: str | None = None
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):
for name, tvar_expr in alias_type_vars:
tvar_expr.default = tvar_expr.default.accept(
TypeVarDefaultTranslator(self, tvar_expr.name, typ)
)
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
if last_tvar_name_with_default is not None and not tvar_def.has_default():
self.msg.tvar_without_default_type(
tvar_def.name, last_tvar_name_with_default, typ
)
tvar_def.default = AnyType(TypeOfAny.from_error)
elif tvar_def.has_default():
last_tvar_name_with_default = tvar_def.name
tvar_defs.append(tvar_def)
tvar_defs = self.tvar_defs_from_tvars(alias_type_vars, typ)

if python_3_12_type_alias:
with self.allow_unbound_tvars_set():
Expand Down Expand Up @@ -4615,6 +4627,40 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool:
self.add_symbol(name, call.analyzed, s)
return True

def check_typevar_default(self, default: Type, context: Context) -> Type:
typ = get_proper_type(default)
if isinstance(typ, AnyType) and typ.is_from_error:
self.fail(
message_registry.TYPEVAR_ARG_MUST_BE_TYPE.format("TypeVar", "default"), context
)
return default

def check_paramspec_default(self, default: Type, context: Context) -> Type:
typ = get_proper_type(default)
if isinstance(typ, Parameters):
for i, arg_type in enumerate(typ.arg_types):
arg_ptype = get_proper_type(arg_type)
if isinstance(arg_ptype, AnyType) and arg_ptype.is_from_error:
self.fail(f"Argument {i} of ParamSpec default must be a type", context)
elif (
isinstance(typ, AnyType)
and typ.is_from_error
or not isinstance(typ, (AnyType, UnboundType))
):
self.fail(
"The default argument to ParamSpec must be a list expression, ellipsis, or a ParamSpec",
context,
)
default = AnyType(TypeOfAny.from_error)
return default

def check_typevartuple_default(self, default: Type, context: Context) -> Type:
typ = get_proper_type(default)
if not isinstance(typ, UnpackType):
self.fail("The default argument to TypeVarTuple must be an Unpacked tuple", context)
default = AnyType(TypeOfAny.from_error)
return default

def check_typevarlike_name(self, call: CallExpr, name: str, context: Context) -> bool:
"""Checks that the name of a TypeVar or ParamSpec matches its variable."""
name = unmangle(name)
Expand Down Expand Up @@ -4822,23 +4868,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool:
report_invalid_typevar_arg=False,
)
default = tv_arg or AnyType(TypeOfAny.from_error)
if isinstance(tv_arg, Parameters):
for i, arg_type in enumerate(tv_arg.arg_types):
typ = get_proper_type(arg_type)
if isinstance(typ, AnyType) and typ.is_from_error:
self.fail(
f"Argument {i} of ParamSpec default must be a type", param_value
)
elif (
isinstance(default, AnyType)
and default.is_from_error
or not isinstance(default, (AnyType, UnboundType))
):
self.fail(
"The default argument to ParamSpec must be a list expression, ellipsis, or a ParamSpec",
param_value,
)
default = AnyType(TypeOfAny.from_error)
default = self.check_paramspec_default(default, param_value)
else:
# ParamSpec is different from a regular TypeVar:
# arguments are not semantically valid. But, allowed in runtime.
Expand Down Expand Up @@ -4899,12 +4929,7 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool:
allow_unpack=True,
)
default = tv_arg or AnyType(TypeOfAny.from_error)
if not isinstance(default, UnpackType):
self.fail(
"The default argument to TypeVarTuple must be an Unpacked tuple",
param_value,
)
default = AnyType(TypeOfAny.from_error)
default = self.check_typevartuple_default(default, param_value)
else:
self.fail(f'Unexpected keyword argument "{param_name}" for "TypeVarTuple"', s)

Expand Down Expand Up @@ -5503,6 +5528,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
eager=eager,
python_3_12_type_alias=True,
)
s.alias_node = alias_node

if (
existing
Expand Down
2 changes: 2 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ def type_param(self, p: mypy.nodes.TypeParam) -> list[Any]:
a.append(p.upper_bound)
if p.values:
a.append(("Values", p.values))
if p.default:
a.append(("Default", [p.default]))
return [("TypeParam", a)]

# Expressions
Expand Down
4 changes: 4 additions & 0 deletions mypy/test/testparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class ParserSuite(DataSuite):
files.remove("parse-python310.test")
if sys.version_info < (3, 12):
files.remove("parse-python312.test")
if sys.version_info < (3, 13):
files.remove("parse-python313.test")

def run_case(self, testcase: DataDrivenTestCase) -> None:
test_parser(testcase)
Expand All @@ -43,6 +45,8 @@ def test_parser(testcase: DataDrivenTestCase) -> None:
options.python_version = (3, 10)
elif testcase.file.endswith("python312.test"):
options.python_version = (3, 12)
elif testcase.file.endswith("python313.test"):
options.python_version = (3, 13)
else:
options.python_version = defaults.PYTHON3_VERSION

Expand Down
Loading

0 comments on commit 603a365

Please sign in to comment.