diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 0e79eef..1815197 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -9,6 +9,8 @@ This library adheres to - Dropped Python 3.8 support - Changed the signature of ``typeguard_ignore()`` to be compatible with ``typing.no_type_check()`` (PR by @jolaf) +- Fixed checking of variable assignments involving tuple unpacking + (`#486 `_) **4.4.0** (2024-10-27) diff --git a/src/typeguard/_functions.py b/src/typeguard/_functions.py index 2849785..ca21c14 100644 --- a/src/typeguard/_functions.py +++ b/src/typeguard/_functions.py @@ -2,6 +2,7 @@ import sys import warnings +from collections.abc import Sequence from typing import Any, Callable, NoReturn, TypeVar, Union, overload from . import _suppression @@ -242,59 +243,53 @@ def check_yield_type( def check_variable_assignment( - value: object, varname: str, annotation: Any, memo: TypeCheckMemo + value: Any, targets: Sequence[list[tuple[str, Any]]], memo: TypeCheckMemo ) -> Any: if _suppression.type_checks_suppressed: return value - try: - check_type_internal(value, annotation, memo) - except TypeCheckError as exc: - qualname = qualified_name(value, add_class_prefix=True) - exc.append_path_element(f"value assigned to {varname} ({qualname})") - if memo.config.typecheck_fail_callback: - memo.config.typecheck_fail_callback(exc, memo) - else: - raise - - return value - + value_to_return = value + for target in targets: + star_variable_index = next( + (i for i, (varname, _) in enumerate(target) if varname.startswith("*")), + None, + ) + if star_variable_index is not None: + value_to_return = list(value) + remaining_vars = len(target) - 1 - star_variable_index + end_index = len(value_to_return) - remaining_vars + values_to_check = ( + value_to_return[:star_variable_index] + + [value_to_return[star_variable_index:end_index]] + + value_to_return[end_index:] + ) + elif len(target) > 1: + values_to_check = value_to_return = [] + iterator = iter(value) + for _ in target: + try: + values_to_check.append(next(iterator)) + except StopIteration: + raise ValueError( + f"not enough values to unpack (expected {len(target)}, got " + f"{len(values_to_check)})" + ) from None -def check_multi_variable_assignment( - value: Any, targets: list[dict[str, Any]], memo: TypeCheckMemo -) -> Any: - if max(len(target) for target in targets) == 1: - iterated_values = [value] - else: - iterated_values = list(value) - - if not _suppression.type_checks_suppressed: - for expected_types in targets: - value_index = 0 - for ann_index, (varname, expected_type) in enumerate( - expected_types.items() - ): - if varname.startswith("*"): - varname = varname[1:] - keys_left = len(expected_types) - 1 - ann_index - next_value_index = len(iterated_values) - keys_left - obj: object = iterated_values[value_index:next_value_index] - value_index = next_value_index + else: + values_to_check = [value] + + for val, (varname, annotation) in zip(values_to_check, target): + try: + check_type_internal(val, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(val, add_class_prefix=True) + exc.append_path_element(f"value assigned to {varname} ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) else: - obj = iterated_values[value_index] - value_index += 1 + raise - try: - check_type_internal(obj, expected_type, memo) - except TypeCheckError as exc: - qualname = qualified_name(obj, add_class_prefix=True) - exc.append_path_element(f"value assigned to {varname} ({qualname})") - if memo.config.typecheck_fail_callback: - memo.config.typecheck_fail_callback(exc, memo) - else: - raise - - return iterated_values[0] if len(iterated_values) == 1 else iterated_values + return value_to_return def warn_on_error(exc: TypeCheckError, memo: TypeCheckMemo) -> None: diff --git a/src/typeguard/_transformer.py b/src/typeguard/_transformer.py index 13d2cf0..937b6b5 100644 --- a/src/typeguard/_transformer.py +++ b/src/typeguard/_transformer.py @@ -28,7 +28,6 @@ If, Import, ImportFrom, - Index, List, Load, LShift, @@ -389,9 +388,7 @@ def visit_BinOp(self, node: BinOp) -> Any: union_name = self.transformer._get_import("typing", "Union") return Subscript( value=union_name, - slice=Index( - Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() - ), + slice=Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load(), ) @@ -410,24 +407,18 @@ def visit_Subscript(self, node: Subscript) -> Any: # The subscript of typing(_extensions).Literal can be any arbitrary string, so # don't try to evaluate it as code if node.slice: - if isinstance(node.slice, Index): - # Python 3.8 - slice_value = node.slice.value # type: ignore[attr-defined] - else: - slice_value = node.slice - - if isinstance(slice_value, Tuple): + if isinstance(node.slice, Tuple): if self._memo.name_matches(node.value, *annotated_names): # Only treat the first argument to typing.Annotated as a potential # forward reference items = cast( typing.List[expr], - [self.visit(slice_value.elts[0])] + slice_value.elts[1:], + [self.visit(node.slice.elts[0])] + node.slice.elts[1:], ) else: items = cast( typing.List[expr], - [self.visit(item) for item in slice_value.elts], + [self.visit(item) for item in node.slice.elts], ) # If this is a Union and any of the items is Any, erase the entire @@ -450,7 +441,7 @@ def visit_Subscript(self, node: Subscript) -> Any: if item is None: items[index] = self.transformer._get_import("typing", "Any") - slice_value.elts = items + node.slice.elts = items else: self.generic_visit(node) @@ -542,18 +533,10 @@ def _use_memo( return_annotation, *generator_names ): if isinstance(return_annotation, Subscript): - annotation_slice = return_annotation.slice - - # Python < 3.9 - if isinstance(annotation_slice, Index): - annotation_slice = ( - annotation_slice.value # type: ignore[attr-defined] - ) - - if isinstance(annotation_slice, Tuple): - items = annotation_slice.elts + if isinstance(return_annotation.slice, Tuple): + items = return_annotation.slice.elts else: - items = [annotation_slice] + items = [return_annotation.slice] if len(items) > 0: new_memo.yield_annotation = self._convert_annotation( @@ -743,7 +726,7 @@ def visit_FunctionDef( annotation_ = self._convert_annotation(node.args.vararg.annotation) if annotation_: container = Name("tuple", ctx=Load()) - subscript_slice: Tuple | Index = Tuple( + subscript_slice = Tuple( [ annotation_, Constant(Ellipsis), @@ -1024,12 +1007,25 @@ def visit_AnnAssign(self, node: AnnAssign) -> Any: func_name = self._get_import( "typeguard._functions", "check_variable_assignment" ) + targets_arg = List( + [ + List( + [ + Tuple( + [Constant(node.target.id), annotation], + ctx=Load(), + ) + ], + ctx=Load(), + ) + ], + ctx=Load(), + ) node.value = Call( func_name, [ node.value, - Constant(node.target.id), - annotation, + targets_arg, self._memo.get_memo_name(), ], [], @@ -1047,7 +1043,7 @@ def visit_Assign(self, node: Assign) -> Any: # Only instrument function-local assignments if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)): - targets: list[dict[Constant, expr | None]] = [] + preliminary_targets: list[list[tuple[Constant, expr | None]]] = [] check_required = False for target in node.targets: elts: Sequence[expr] @@ -1058,63 +1054,63 @@ def visit_Assign(self, node: Assign) -> Any: else: continue - annotations_: dict[Constant, expr | None] = {} + annotations_: list[tuple[Constant, expr | None]] = [] for exp in elts: prefix = "" if isinstance(exp, Starred): exp = exp.value prefix = "*" + path: list[str] = [] + while isinstance(exp, Attribute): + path.insert(0, exp.attr) + exp = exp.value + if isinstance(exp, Name): - self._memo.ignored_names.add(exp.id) - name = prefix + exp.id + if not path: + self._memo.ignored_names.add(exp.id) + + path.insert(0, exp.id) + name = prefix + ".".join(path) annotation = self._memo.variable_annotations.get(exp.id) if annotation: - annotations_[Constant(name)] = annotation + annotations_.append((Constant(name), annotation)) check_required = True else: - annotations_[Constant(name)] = None + annotations_.append((Constant(name), None)) - targets.append(annotations_) + preliminary_targets.append(annotations_) if check_required: # Replace missing annotations with typing.Any - for item in targets: - for key, expression in item.items(): + targets: list[list[tuple[Constant, expr]]] = [] + for items in preliminary_targets: + target_list: list[tuple[Constant, expr]] = [] + targets.append(target_list) + for key, expression in items: if expression is None: - item[key] = self._get_import("typing", "Any") + target_list.append((key, self._get_import("typing", "Any"))) + else: + target_list.append((key, expression)) - if len(targets) == 1 and len(targets[0]) == 1: - func_name = self._get_import( - "typeguard._functions", "check_variable_assignment" - ) - target_varname = next(iter(targets[0])) - node.value = Call( - func_name, - [ - node.value, - target_varname, - targets[0][target_varname], - self._memo.get_memo_name(), - ], - [], - ) - elif targets: - func_name = self._get_import( - "typeguard._functions", "check_multi_variable_assignment" - ) - targets_arg = List( - [ - Dict(keys=list(target), values=list(target.values())) - for target in targets - ], - ctx=Load(), - ) - node.value = Call( - func_name, - [node.value, targets_arg, self._memo.get_memo_name()], - [], - ) + func_name = self._get_import( + "typeguard._functions", "check_variable_assignment" + ) + targets_arg = List( + [ + List( + [Tuple([name, ann], ctx=Load()) for name, ann in target], + ctx=Load(), + ) + for target in targets + ], + ctx=Load(), + ) + node.value = Call( + func_name, + [node.value, targets_arg, self._memo.get_memo_name()], + [], + ) return node @@ -1175,12 +1171,20 @@ def visit_AugAssign(self, node: AugAssign) -> Any: operator_call = Call( operator_func, [Name(node.target.id, ctx=Load()), node.value], [] ) + targets_arg = List( + [ + List( + [Tuple([Constant(node.target.id), annotation], ctx=Load())], + ctx=Load(), + ) + ], + ctx=Load(), + ) check_call = Call( self._get_import("typeguard._functions", "check_variable_assignment"), [ operator_call, - Constant(node.target.id), - annotation, + targets_arg, self._memo.get_memo_name(), ], [], diff --git a/src/typeguard/_union_transformer.py b/src/typeguard/_union_transformer.py index d0a3ddf..1c296d3 100644 --- a/src/typeguard/_union_transformer.py +++ b/src/typeguard/_union_transformer.py @@ -8,15 +8,14 @@ from ast import ( BinOp, BitOr, - Index, Load, Name, NodeTransformer, Subscript, + Tuple, fix_missing_locations, parse, ) -from ast import Tuple as ASTTuple from types import CodeType from typing import Any @@ -30,9 +29,7 @@ def visit_BinOp(self, node: BinOp) -> Any: if isinstance(node.op, BitOr): return Subscript( value=self.union_name, - slice=Index( - ASTTuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() - ), + slice=Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load(), ) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 9248d50..3cf735d 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -967,7 +967,7 @@ def foo(x: Any) -> None: def foo(x: Any) -> None: memo = TypeCheckMemo(globals(), locals()) y: FooBar = x - z: list[FooBar] = check_variable_assignment([y], 'z', list, \ + z: list[FooBar] = check_variable_assignment([y], [[('z', list)]], \ memo) """ ).strip() @@ -1145,7 +1145,8 @@ def foo() -> None: def foo() -> None: memo = TypeCheckMemo(globals(), locals()) - x: int = check_variable_assignment(otherfunc(), 'x', int, memo) + x: int = check_variable_assignment(otherfunc(), [[('x', int)]], \ +memo) """ ).strip() ) @@ -1173,8 +1174,8 @@ def foo(*args: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'args': (args, \ tuple[int, ...])}, memo) - args = check_variable_assignment((5,), 'args', \ -tuple[int, ...], memo) + args = check_variable_assignment((5,), \ +[[('args', tuple[int, ...])]], memo) """ ).strip() ) @@ -1202,8 +1203,8 @@ def foo(**kwargs: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'kwargs': (kwargs, \ dict[str, int])}, memo) - kwargs = check_variable_assignment({'a': 5}, 'kwargs', \ -dict[str, int], memo) + kwargs = check_variable_assignment({'a': 5}, \ +[[('kwargs', dict[str, int])]], memo) """ ).strip() ) @@ -1232,8 +1233,8 @@ def foo() -> None: def foo() -> None: memo = TypeCheckMemo(globals(), locals()) - x: int | str = check_variable_assignment(otherfunc(), 'x', \ -Union_[int, str], memo) + x: int | str = check_variable_assignment(otherfunc(), \ +[[('x', Union_[int, str])]], memo) """ ).strip() ) @@ -1256,15 +1257,15 @@ def foo() -> None: == dedent( f""" from typeguard import TypeCheckMemo - from typeguard._functions import check_multi_variable_assignment + from typeguard._functions import check_variable_assignment from typing import Any def foo() -> None: memo = TypeCheckMemo(globals(), locals()) x: int z: bytes - {target} = check_multi_variable_assignment(otherfunc(), \ -[{{'x': int, 'y': Any, 'z': bytes}}], memo) + {target} = check_variable_assignment(otherfunc(), \ +[[('x', int), ('y', Any), ('z', bytes)]], memo) """ ).strip() ) @@ -1287,15 +1288,80 @@ def foo() -> None: == dedent( f""" from typeguard import TypeCheckMemo - from typeguard._functions import check_multi_variable_assignment + from typeguard._functions import check_variable_assignment from typing import Any def foo() -> None: memo = TypeCheckMemo(globals(), locals()) x: int z: bytes - {target} = check_multi_variable_assignment(otherfunc(), \ -[{{'x': int, '*y': Any, 'z': bytes}}], memo) + {target} = check_variable_assignment(otherfunc(), \ +[[('x', int), ('*y', Any), ('z', bytes)]], memo) + """ + ).strip() + ) + + def test_complex_multi_assign(self) -> None: + node = parse( + dedent( + """ + def foo() -> None: + x: int + z: bytes + all = x, *y, z = otherfunc() + """ + ) + ) + TypeguardTransformer().visit(node) + target = "x, *y, z" if sys.version_info >= (3, 11) else "(x, *y, z)" + assert ( + unparse(node) + == dedent( + f""" + from typeguard import TypeCheckMemo + from typeguard._functions import check_variable_assignment + from typing import Any + + def foo() -> None: + memo = TypeCheckMemo(globals(), locals()) + x: int + z: bytes + all = {target} = check_variable_assignment(otherfunc(), \ +[[('all', Any)], [('x', int), ('*y', Any), ('z', bytes)]], memo) + """ + ).strip() + ) + + def test_unpacking_assign_to_self(self) -> None: + node = parse( + dedent( + """ + class Foo: + + def foo(self) -> None: + x: int + (x, self.y) = 1, 'test' + """ + ) + ) + TypeguardTransformer().visit(node) + target = "x, self.y" if sys.version_info >= (3, 11) else "(x, self.y)" + assert ( + unparse(node) + == dedent( + f""" + from typeguard import TypeCheckMemo + from typeguard._functions import check_variable_assignment + from typing import Any + + class Foo: + + def foo(self) -> None: + memo = TypeCheckMemo(globals(), locals(), \ +self_type=self.__class__) + x: int + {target} = check_variable_assignment((1, 'test'), \ +[[('x', int), ('self.y', Any)]], memo) """ ).strip() ) @@ -1321,7 +1387,7 @@ def foo(x: int) -> None: def foo(x: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'x': (x, int)}, memo) - x = check_variable_assignment(6, 'x', int, memo) + x = check_variable_assignment(6, [[('x', int)]], memo) """ ).strip() ) @@ -1422,7 +1488,8 @@ def foo() -> None: def foo() -> None: memo = TypeCheckMemo(globals(), locals()) x: int - x = check_variable_assignment({function}(x, 6), 'x', int, memo) + x = check_variable_assignment({function}(x, 6), [[('x', int)]], \ +memo) """ ).strip() ) @@ -1471,7 +1538,7 @@ def foo(x: int) -> None: def foo(x: int) -> None: memo = TypeCheckMemo(globals(), locals()) check_argument_types('foo', {'x': (x, int)}, memo) - x = check_variable_assignment(iadd(x, 6), 'x', int, memo) + x = check_variable_assignment(iadd(x, 6), [[('x', int)]], memo) """ ).strip() )