diff --git a/mypy/fixup.py b/mypy/fixup.py index 493e5d3ef1bc6..d48c554c5d97e 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -172,6 +172,12 @@ def visit_class_def(self, c: ClassDef) -> None: value.accept(self.type_fixer) v.upper_bound.accept(self.type_fixer) v.default.accept(self.type_fixer) + if isinstance(v, ParamSpecType): + v.upper_bound.accept(self.type_fixer) + v.default.accept(self.type_fixer) + if isinstance(v, TypeVarTupleType): + v.upper_bound.accept(self.type_fixer) + v.default.accept(self.type_fixer) def visit_type_var_expr(self, tv: TypeVarExpr) -> None: for value in tv.values: @@ -181,9 +187,11 @@ def visit_type_var_expr(self, tv: TypeVarExpr) -> None: def visit_paramspec_expr(self, p: ParamSpecExpr) -> None: p.upper_bound.accept(self.type_fixer) + p.default.accept(self.type_fixer) def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None: tv.upper_bound.accept(self.type_fixer) + tv.default.accept(self.type_fixer) def visit_var(self, v: Var) -> None: if self.current_info is not None: @@ -305,16 +313,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None: if tvt.values: for vt in tvt.values: vt.accept(self) - if tvt.upper_bound is not None: - tvt.upper_bound.accept(self) - if tvt.default is not None: - tvt.default.accept(self) + tvt.upper_bound.accept(self) + tvt.default.accept(self) def visit_param_spec(self, p: ParamSpecType) -> None: p.upper_bound.accept(self) + p.default.accept(self) def visit_type_var_tuple(self, t: TypeVarTupleType) -> None: t.upper_bound.accept(self) + t.default.accept(self) def visit_unpack_type(self, u: UnpackType) -> None: u.type.accept(self) diff --git a/mypy/indirection.py b/mypy/indirection.py index a8c8fea654965..00356d7a4ddbe 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -67,10 +67,10 @@ def visit_type_var(self, t: types.TypeVarType) -> set[str]: return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default) def visit_param_spec(self, t: types.ParamSpecType) -> set[str]: - return set() + return self._visit(t.upper_bound) | self._visit(t.default) def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]: - return self._visit(t.upper_bound) + return self._visit(t.upper_bound) | self._visit(t.default) def visit_unpack_type(self, t: types.UnpackType) -> set[str]: return t.type.accept(self) diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 2c205ba4d992f..c26d5ea4aeaa9 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -201,9 +201,19 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sym snapshot_optional_type(node.target), ) elif isinstance(node, ParamSpecExpr): - result[name] = ("ParamSpec", node.variance, snapshot_type(node.upper_bound)) + result[name] = ( + "ParamSpec", + node.variance, + snapshot_type(node.upper_bound), + snapshot_type(node.default), + ) elif isinstance(node, TypeVarTupleExpr): - result[name] = ("TypeVarTuple", node.variance, snapshot_type(node.upper_bound)) + result[name] = ( + "TypeVarTuple", + node.variance, + snapshot_type(node.upper_bound), + snapshot_type(node.default), + ) else: assert symbol.kind != UNBOUND_IMPORTED if node and get_prefix(node.fullname) != name_prefix: @@ -394,6 +404,7 @@ def visit_param_spec(self, typ: ParamSpecType) -> SnapshotItem: typ.id.meta_level, typ.flavor, snapshot_type(typ.upper_bound), + snapshot_type(typ.default), ) def visit_type_var_tuple(self, typ: TypeVarTupleType) -> SnapshotItem: @@ -402,6 +413,7 @@ def visit_type_var_tuple(self, typ: TypeVarTupleType) -> SnapshotItem: typ.id.raw_id, typ.id.meta_level, snapshot_type(typ.upper_bound), + snapshot_type(typ.default), ) def visit_unpack_type(self, typ: UnpackType) -> SnapshotItem: diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 7ff888f270393..f25dda0fe4832 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -252,6 +252,14 @@ def process_type_var_def(self, tv: TypeVarType) -> None: self.fixup_type(tv.upper_bound) self.fixup_type(tv.default) + def process_param_spec_def(self, tv: ParamSpecType) -> None: + self.fixup_type(tv.upper_bound) + self.fixup_type(tv.default) + + def process_type_var_tuple_def(self, tv: TypeVarTupleType) -> None: + self.fixup_type(tv.upper_bound) + self.fixup_type(tv.default) + def visit_assignment_stmt(self, node: AssignmentStmt) -> None: self.fixup_type(node.type) super().visit_assignment_stmt(node) @@ -478,14 +486,17 @@ def visit_type_type(self, typ: TypeType) -> None: def visit_type_var(self, typ: TypeVarType) -> None: typ.upper_bound.accept(self) + typ.default.accept(self) for value in typ.values: value.accept(self) def visit_param_spec(self, typ: ParamSpecType) -> None: - pass + typ.upper_bound.accept(self) + typ.default.accept(self) def visit_type_var_tuple(self, typ: TypeVarTupleType) -> None: typ.upper_bound.accept(self) + typ.default.accept(self) def visit_unpack_type(self, typ: UnpackType) -> None: typ.type.accept(self) diff --git a/mypy/server/deps.py b/mypy/server/deps.py index d0c1e43e5bc7d..1762c2a9e704b 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -1049,6 +1049,10 @@ def visit_param_spec(self, typ: ParamSpecType) -> list[str]: triggers = [] if typ.fullname: triggers.append(make_trigger(typ.fullname)) + if typ.upper_bound: + triggers.append(self.get_type_triggers(typ.upper_bound)) + if typ.default: + triggers.extend(self.get_type_triggers(typ.default)) triggers.extend(self.get_type_triggers(typ.upper_bound)) return triggers @@ -1056,6 +1060,10 @@ def visit_type_var_tuple(self, typ: TypeVarTupleType) -> list[str]: triggers = [] if typ.fullname: triggers.append(make_trigger(typ.fullname)) + if typ.upper_bound: + triggers.extend(self.get_type_triggers(typ.upper_bound)) + if typ.default: + triggers.extend(self.get_type_triggers(typ.default)) triggers.extend(self.get_type_triggers(typ.upper_bound)) return triggers diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 69d026da6c5f0..86d06f52523e4 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -352,7 +352,7 @@ def visit_param_spec(self, t: ParamSpecType) -> T: return self.query_types([t.default]) def visit_type_var_tuple(self, t: TypeVarTupleType) -> T: - return self.query_types([t.default]) + return self.query_types([t.upper_bound, t.default]) def visit_unpack_type(self, t: UnpackType) -> T: return self.query_types([t.type]) @@ -483,10 +483,10 @@ def visit_type_var(self, t: TypeVarType) -> bool: return self.query_types([t.upper_bound, t.default] + t.values) def visit_param_spec(self, t: ParamSpecType) -> bool: - return t.default.accept(self) + return self.query_types([t.upper_bound, t.default]) def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool: - return t.default.accept(self) + return self.query_types([t.upper_bound, t.default]) def visit_unpack_type(self, t: UnpackType) -> bool: return self.query_types([t.type]) diff --git a/mypy/types.py b/mypy/types.py index 1e14630056901..968c45ac8b430 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -521,6 +521,7 @@ class TypeVarLikeType(ProperType): fullname: str # Fully qualified name id: TypeVarId upper_bound: Type + default: Type def __init__( self, diff --git a/mypy/typeshed/stdlib/typing.pyi b/mypy/typeshed/stdlib/typing.pyi index b4d3947bbda51..efd61ad8bf438 100644 --- a/mypy/typeshed/stdlib/typing.pyi +++ b/mypy/typeshed/stdlib/typing.pyi @@ -138,12 +138,11 @@ Any = object() class TypeVar: __name__: str __bound__: Any | None - __default__: Any __constraints__: tuple[Any, ...] __covariant__: bool __contravariant__: bool def __init__( - self, name: str, *constraints: Any, bound: Any | None = None, default: Any | None = None, covariant: bool = False, contravariant: bool = False + self, name: str, *constraints: Any, bound: Any | None = None, covariant: bool = False, contravariant: bool = False ) -> None: ... if sys.version_info >= (3, 10): def __or__(self, right: Any) -> _SpecialForm: ... @@ -222,7 +221,7 @@ if sys.version_info >= (3, 10): __covariant__: bool __contravariant__: bool def __init__( - self, name: str, *, bound: Any | None = None, default: Any | None = None, contravariant: bool = False, covariant: bool = False + self, name: str, *, bound: Any | None = None, contravariant: bool = False, covariant: bool = False ) -> None: ... @property def args(self) -> ParamSpecArgs: ...