Skip to content

Commit

Permalink
Fix narrowing on match with function subject (#16503)
Browse files Browse the repository at this point in the history
Fixes #12998

mypy can't narrow match statements with functions subjects because the
callexpr node is not a literal node. This adds a 'dummy' literal node
that the match statement visitor can use to do the type narrowing.

The python grammar describes the the match subject as a named expression
so this uses that nameexpr node as it's literal.

---------

Co-authored-by: hauntsaninja <hauntsaninja@gmail.com>
  • Loading branch information
edpaget and hauntsaninja authored Feb 17, 2024
1 parent bfbac5e commit 17271e5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
19 changes: 16 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5053,6 +5053,19 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
return

def visit_match_stmt(self, s: MatchStmt) -> None:
named_subject: Expression
if isinstance(s.subject, CallExpr):
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else ""
name = "dummy-match-" + id
v = Var(name)
named_subject = NameExpr(name)
named_subject.node = v
else:
named_subject = s.subject

with self.binder.frame_context(can_skip=False, fall_through=0):
subject_type = get_proper_type(self.expr_checker.accept(s.subject))

Expand All @@ -5071,7 +5084,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
# The second pass narrows down the types and type checks bodies.
for p, g, b in zip(s.patterns, s.guards, s.bodies):
current_subject_type = self.expr_checker.narrow_type_from_binder(
s.subject, subject_type
named_subject, subject_type
)
pattern_type = self.pattern_checker.accept(p, current_subject_type)
with self.binder.frame_context(can_skip=True, fall_through=2):
Expand All @@ -5082,7 +5095,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
else_map: TypeMap = {}
else:
pattern_map, else_map = conditional_types_to_typemaps(
s.subject, pattern_type.type, pattern_type.rest_type
named_subject, pattern_type.type, pattern_type.rest_type
)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
self.push_type_map(pattern_map)
Expand Down Expand Up @@ -5110,7 +5123,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
and expr.fullname == case_target.fullname
):
continue
type_map[s.subject] = type_map[expr]
type_map[named_subject] = type_map[expr]

self.push_type_map(guard_map)
self.accept(b)
Expand Down
15 changes: 15 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,21 @@ match m:

reveal_type(a) # N: Revealed type is "builtins.str"

[case testMatchCapturePatternFromFunctionReturningUnion]
def func1(arg: bool) -> str | int: ...
def func2(arg: bool) -> bytes | int: ...

def main() -> None:
match func1(True):
case str(a):
match func2(True):
case c:
reveal_type(a) # N: Revealed type is "builtins.str"
reveal_type(c) # N: Revealed type is "Union[builtins.bytes, builtins.int]"
reveal_type(a) # N: Revealed type is "builtins.str"
case a:
reveal_type(a) # N: Revealed type is "builtins.int"

-- Guards --

[case testMatchSimplePatternGuard]
Expand Down

0 comments on commit 17271e5

Please sign in to comment.