Skip to content

Commit

Permalink
Misc
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Mar 7, 2023
1 parent a8212dd commit e43b1f8
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 15 deletions.
16 changes: 12 additions & 4 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ def visit_class_def(self, c: ClassDef) -> None:
value.accept(self.type_fixer)
v.upper_bound.accept(self.type_fixer)
v.default.accept(self.type_fixer)
if isinstance(v, ParamSpecType):
v.upper_bound.accept(self.type_fixer)
v.default.accept(self.type_fixer)
if isinstance(v, TypeVarTupleType):
v.upper_bound.accept(self.type_fixer)
v.default.accept(self.type_fixer)

def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
for value in tv.values:
Expand All @@ -181,9 +187,11 @@ def visit_type_var_expr(self, tv: TypeVarExpr) -> None:

def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
p.upper_bound.accept(self.type_fixer)
p.default.accept(self.type_fixer)

def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
tv.upper_bound.accept(self.type_fixer)
tv.default.accept(self.type_fixer)

def visit_var(self, v: Var) -> None:
if self.current_info is not None:
Expand Down Expand Up @@ -305,16 +313,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None:
if tvt.values:
for vt in tvt.values:
vt.accept(self)
if tvt.upper_bound is not None:
tvt.upper_bound.accept(self)
if tvt.default is not None:
tvt.default.accept(self)
tvt.upper_bound.accept(self)
tvt.default.accept(self)

def visit_param_spec(self, p: ParamSpecType) -> None:
p.upper_bound.accept(self)
p.default.accept(self)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.upper_bound.accept(self)
t.default.accept(self)

def visit_unpack_type(self, u: UnpackType) -> None:
u.type.accept(self)
Expand Down
4 changes: 2 additions & 2 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def visit_type_var(self, t: types.TypeVarType) -> set[str]:
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)

def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
return set()
return self._visit(t.upper_bound) | self._visit(t.default)

def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
return self._visit(t.upper_bound)
return self._visit(t.upper_bound) | self._visit(t.default)

def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
return t.type.accept(self)
Expand Down
16 changes: 14 additions & 2 deletions mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,19 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sym
snapshot_optional_type(node.target),
)
elif isinstance(node, ParamSpecExpr):
result[name] = ("ParamSpec", node.variance, snapshot_type(node.upper_bound))
result[name] = (
"ParamSpec",
node.variance,
snapshot_type(node.upper_bound),
snapshot_type(node.default),
)
elif isinstance(node, TypeVarTupleExpr):
result[name] = ("TypeVarTuple", node.variance, snapshot_type(node.upper_bound))
result[name] = (
"TypeVarTuple",
node.variance,
snapshot_type(node.upper_bound),
snapshot_type(node.default),
)
else:
assert symbol.kind != UNBOUND_IMPORTED
if node and get_prefix(node.fullname) != name_prefix:
Expand Down Expand Up @@ -394,6 +404,7 @@ def visit_param_spec(self, typ: ParamSpecType) -> SnapshotItem:
typ.id.meta_level,
typ.flavor,
snapshot_type(typ.upper_bound),
snapshot_type(typ.default),
)

def visit_type_var_tuple(self, typ: TypeVarTupleType) -> SnapshotItem:
Expand All @@ -402,6 +413,7 @@ def visit_type_var_tuple(self, typ: TypeVarTupleType) -> SnapshotItem:
typ.id.raw_id,
typ.id.meta_level,
snapshot_type(typ.upper_bound),
snapshot_type(typ.default),
)

def visit_unpack_type(self, typ: UnpackType) -> SnapshotItem:
Expand Down
13 changes: 12 additions & 1 deletion mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,14 @@ def process_type_var_def(self, tv: TypeVarType) -> None:
self.fixup_type(tv.upper_bound)
self.fixup_type(tv.default)

def process_param_spec_def(self, tv: ParamSpecType) -> None:
self.fixup_type(tv.upper_bound)
self.fixup_type(tv.default)

def process_type_var_tuple_def(self, tv: TypeVarTupleType) -> None:
self.fixup_type(tv.upper_bound)
self.fixup_type(tv.default)

def visit_assignment_stmt(self, node: AssignmentStmt) -> None:
self.fixup_type(node.type)
super().visit_assignment_stmt(node)
Expand Down Expand Up @@ -478,14 +486,17 @@ def visit_type_type(self, typ: TypeType) -> None:

def visit_type_var(self, typ: TypeVarType) -> None:
typ.upper_bound.accept(self)
typ.default.accept(self)
for value in typ.values:
value.accept(self)

def visit_param_spec(self, typ: ParamSpecType) -> None:
pass
typ.upper_bound.accept(self)
typ.default.accept(self)

def visit_type_var_tuple(self, typ: TypeVarTupleType) -> None:
typ.upper_bound.accept(self)
typ.default.accept(self)

def visit_unpack_type(self, typ: UnpackType) -> None:
typ.type.accept(self)
Expand Down
8 changes: 8 additions & 0 deletions mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,13 +1049,21 @@ def visit_param_spec(self, typ: ParamSpecType) -> list[str]:
triggers = []
if typ.fullname:
triggers.append(make_trigger(typ.fullname))
if typ.upper_bound:
triggers.append(self.get_type_triggers(typ.upper_bound))
if typ.default:
triggers.extend(self.get_type_triggers(typ.default))
triggers.extend(self.get_type_triggers(typ.upper_bound))
return triggers

def visit_type_var_tuple(self, typ: TypeVarTupleType) -> list[str]:
triggers = []
if typ.fullname:
triggers.append(make_trigger(typ.fullname))
if typ.upper_bound:
triggers.extend(self.get_type_triggers(typ.upper_bound))
if typ.default:
triggers.extend(self.get_type_triggers(typ.default))
triggers.extend(self.get_type_triggers(typ.upper_bound))
return triggers

Expand Down
6 changes: 3 additions & 3 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def visit_param_spec(self, t: ParamSpecType) -> T:
return self.query_types([t.default])

def visit_type_var_tuple(self, t: TypeVarTupleType) -> T:
return self.query_types([t.default])
return self.query_types([t.upper_bound, t.default])

def visit_unpack_type(self, t: UnpackType) -> T:
return self.query_types([t.type])
Expand Down Expand Up @@ -483,10 +483,10 @@ def visit_type_var(self, t: TypeVarType) -> bool:
return self.query_types([t.upper_bound, t.default] + t.values)

def visit_param_spec(self, t: ParamSpecType) -> bool:
return t.default.accept(self)
return self.query_types([t.upper_bound, t.default])

def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
return t.default.accept(self)
return self.query_types([t.upper_bound, t.default])

def visit_unpack_type(self, t: UnpackType) -> bool:
return self.query_types([t.type])
Expand Down
1 change: 1 addition & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ class TypeVarLikeType(ProperType):
fullname: str # Fully qualified name
id: TypeVarId
upper_bound: Type
default: Type

def __init__(
self,
Expand Down
5 changes: 2 additions & 3 deletions mypy/typeshed/stdlib/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,11 @@ Any = object()
class TypeVar:
__name__: str
__bound__: Any | None
__default__: Any
__constraints__: tuple[Any, ...]
__covariant__: bool
__contravariant__: bool
def __init__(
self, name: str, *constraints: Any, bound: Any | None = None, default: Any | None = None, covariant: bool = False, contravariant: bool = False
self, name: str, *constraints: Any, bound: Any | None = None, covariant: bool = False, contravariant: bool = False
) -> None: ...
if sys.version_info >= (3, 10):
def __or__(self, right: Any) -> _SpecialForm: ...
Expand Down Expand Up @@ -222,7 +221,7 @@ if sys.version_info >= (3, 10):
__covariant__: bool
__contravariant__: bool
def __init__(
self, name: str, *, bound: Any | None = None, default: Any | None = None, contravariant: bool = False, covariant: bool = False
self, name: str, *, bound: Any | None = None, contravariant: bool = False, covariant: bool = False
) -> None: ...
@property
def args(self) -> ParamSpecArgs: ...
Expand Down

0 comments on commit e43b1f8

Please sign in to comment.