From 2d04f054c0d7f9a6b8d7dc6c3a8a9b9a082cd3fa Mon Sep 17 00:00:00 2001 From: Edward Paget Date: Tue, 14 Nov 2023 14:15:57 -0800 Subject: [PATCH] Fix narrowing on match with function subject Fixes #12998 mypy can't narrow match statements with functions subjects because the callexpr node is not a literal node. This adds a 'dummy' literal node that the match statement visitor can use to do the type narrowing. The python grammar describes the the match subject as a named expression so this uses that nameexpr node as it's literal. --- mypy/checker.py | 11 ++++++++--- test-data/unit/check-python310.test | 12 ++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index e4eb58d40715d..62610996aa277 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5043,8 +5043,13 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: return None def visit_match_stmt(self, s: MatchStmt) -> None: + # Create a dummy subject expression to handle cases where a match + # statement's subject is not a literal value which prevent us from correctly + # narrowing types and checking exhaustivity + named_subject = NameExpr("match") with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) + self.store_type(named_subject, subject_type) if isinstance(subject_type, DeletedType): self.msg.deleted_as_rvalue(subject_type, s) @@ -5061,7 +5066,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: # The second pass narrows down the types and type checks bodies. for p, g, b in zip(s.patterns, s.guards, s.bodies): current_subject_type = self.expr_checker.narrow_type_from_binder( - s.subject, subject_type + named_subject, subject_type ) pattern_type = self.pattern_checker.accept(p, current_subject_type) with self.binder.frame_context(can_skip=True, fall_through=2): @@ -5072,7 +5077,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: else_map: TypeMap = {} else: pattern_map, else_map = conditional_types_to_typemaps( - s.subject, pattern_type.type, pattern_type.rest_type + named_subject, pattern_type.type, pattern_type.rest_type ) self.remove_capture_conflicts(pattern_type.captures, inferred_types) self.push_type_map(pattern_map) @@ -5100,7 +5105,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None: and expr.fullname == case_target.fullname ): continue - type_map[s.subject] = type_map[expr] + type_map[named_subject] = type_map[expr] self.push_type_map(guard_map) self.accept(b) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index d3cdf3af849d4..3c31911f98f7c 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1139,6 +1139,18 @@ match m: reveal_type(a) # N: Revealed type is "builtins.str" +[case testMatchCapturePatternFromFunctionReturningUnion] +def func(arg: bool) -> str | int: + if arg: + return 1 + return "a" + +match func(True): + case str(a): + reveal_type(a) # N: Revealed type is "builtins.str" + case a: + reveal_type(a) # N: Revealed type is "builtins.int" + -- Guards -- [case testMatchSimplePatternGuard]