From 80190101f68b52e960c22572ed6cc814de078b9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Simon?= Date: Thu, 4 Apr 2024 12:34:00 +0200 Subject: [PATCH] Narrow individual items when matching a tuple to a sequence pattern (#16905) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #12364 When matching a tuple to a sequence pattern, this change narrows the type of tuple items inside the matched case: ```py def test(a: bool, b: bool) -> None: match a, b: case True, True: reveal_type(a) # before: "builtins.bool", after: "Literal[True]" ``` This also works with nested tuples, recursively: ```py def test(a: bool, b: bool, c: bool) -> None: match a, (b, c): case _, [True, False]: reveal_type(c) # before: "builtins.bool", after: "Literal[False]" ``` This only partially fixes issue #12364; see [my comment there](https://github.com/python/mypy/issues/12364#issuecomment-1937375271) for more context. --- This is my first contribution to mypy, so I may miss some context or conventions; I'm eager for any feedback! --------- Co-authored-by: Loïc Simon --- mypy/checker.py | 17 ++++++++ test-data/unit/check-python310.test | 66 +++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 5d243195d50f..af7535581091 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5119,6 +5119,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None: ) self.remove_capture_conflicts(pattern_type.captures, inferred_types) self.push_type_map(pattern_map) + if pattern_map: + for expr, typ in pattern_map.items(): + self.push_type_map(self._get_recursive_sub_patterns_map(expr, typ)) self.push_type_map(pattern_type.captures) if g is not None: with self.binder.frame_context(can_skip=False, fall_through=3): @@ -5156,6 +5159,20 @@ def visit_match_stmt(self, s: MatchStmt) -> None: with self.binder.frame_context(can_skip=False, fall_through=2): pass + def _get_recursive_sub_patterns_map( + self, expr: Expression, typ: Type + ) -> dict[Expression, Type]: + sub_patterns_map: dict[Expression, Type] = {} + typ_ = get_proper_type(typ) + if isinstance(expr, TupleExpr) and isinstance(typ_, TupleType): + # When matching a tuple expression with a sequence pattern, narrow individual tuple items + assert len(expr.items) == len(typ_.items) + for item_expr, item_typ in zip(expr.items, typ_.items): + sub_patterns_map[item_expr] = item_typ + sub_patterns_map.update(self._get_recursive_sub_patterns_map(item_expr, item_typ)) + + return sub_patterns_map + def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[Var, Type]: all_captures: dict[Var, list[tuple[NameExpr, Type]]] = defaultdict(list) for tm in type_maps: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 3a040d94d7ba..2b56d2db07a9 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -341,6 +341,72 @@ match m: reveal_type(m) # N: Revealed type is "builtins.list[builtins.list[builtins.str]]" [builtins fixtures/list.pyi] +[case testMatchSequencePatternNarrowSubjectItems] +m: int +n: str +o: bool + +match m, n, o: + case [3, "foo", True]: + reveal_type(m) # N: Revealed type is "Literal[3]" + reveal_type(n) # N: Revealed type is "Literal['foo']" + reveal_type(o) # N: Revealed type is "Literal[True]" + case [a, b, c]: + reveal_type(m) # N: Revealed type is "builtins.int" + reveal_type(n) # N: Revealed type is "builtins.str" + reveal_type(o) # N: Revealed type is "builtins.bool" + +reveal_type(m) # N: Revealed type is "builtins.int" +reveal_type(n) # N: Revealed type is "builtins.str" +reveal_type(o) # N: Revealed type is "builtins.bool" +[builtins fixtures/tuple.pyi] + +[case testMatchSequencePatternNarrowSubjectItemsRecursive] +m: int +n: int +o: int +p: int +q: int +r: int + +match m, (n, o), (p, (q, r)): + case [0, [1, 2], [3, [4, 5]]]: + reveal_type(m) # N: Revealed type is "Literal[0]" + reveal_type(n) # N: Revealed type is "Literal[1]" + reveal_type(o) # N: Revealed type is "Literal[2]" + reveal_type(p) # N: Revealed type is "Literal[3]" + reveal_type(q) # N: Revealed type is "Literal[4]" + reveal_type(r) # N: Revealed type is "Literal[5]" +[builtins fixtures/tuple.pyi] + +[case testMatchSequencePatternSequencesLengthMismatchNoNarrowing] +m: int +n: str +o: bool + +match m, n, o: + case [3, "foo"]: + pass + case [3, "foo", True, True]: + pass +[builtins fixtures/tuple.pyi] + +[case testMatchSequencePatternSequencesLengthMismatchNoNarrowingRecursive] +m: int +n: int +o: int + +match m, (n, o): + case [0]: + pass + case [0, 1, [2]]: + pass + case [0, [1]]: + pass + case [0, [1, 2, 3]]: + pass +[builtins fixtures/tuple.pyi] + -- Mapping Pattern -- [case testMatchMappingPatternCaptures]