Skip to content

Commit 48eea79

Browse files
committed
Better match narrowing for type objects
This is the more general fix I alluded to in #20367 (comment)
1 parent 837052e commit 48eea79

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

mypy/checkpattern.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
569569
return self.early_non_match()
570570
elif isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
571571
typ = fill_typevars_with_any(p_typ.type_object())
572+
type_range = TypeRange(typ, is_upper_bound=False)
572573
elif (
573574
isinstance(type_info, Var)
574575
and type_info.type is not None
@@ -578,8 +579,10 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
578579
fallback = self.chk.named_type("builtins.function")
579580
any_type = AnyType(TypeOfAny.unannotated)
580581
typ = callable_with_ellipsis(any_type, ret_type=any_type, fallback=fallback)
581-
elif isinstance(p_typ, TypeType) and isinstance(p_typ.item, NoneType):
582+
type_range = TypeRange(typ, is_upper_bound=False)
583+
elif isinstance(p_typ, TypeType):
582584
typ = p_typ.item
585+
type_range = TypeRange(p_typ.item, is_upper_bound=True)
583586
elif not isinstance(p_typ, AnyType):
584587
self.msg.fail(
585588
message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(
@@ -588,9 +591,11 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
588591
o,
589592
)
590593
return self.early_non_match()
594+
else:
595+
type_range = get_type_range(typ)
591596

592597
new_type, rest_type = self.chk.conditional_types_with_intersection(
593-
current_type, [get_type_range(typ)], o, default=current_type
598+
current_type, [type_range], o, default=current_type
594599
)
595600
if is_uninhabited(new_type):
596601
return self.early_non_match()

test-data/unit/check-python310.test

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -344,19 +344,67 @@ match x:
344344
case [str()]:
345345
pass
346346

347-
[case testMatchSequencePatternWithInvalidClassPattern]
347+
[case testMatchSequencePatternWithTypeObjectClassPattern]
348+
# flags: --strict-equality --warn-unreachable
348349
class Example:
349350
__match_args__ = ("value",)
350351
def __init__(self, value: str) -> None:
351352
self.value = value
352353

353-
SubClass: type[Example]
354+
def f1(subclass: type[Example]) -> None:
355+
match subclass("a"):
356+
case Example(value):
357+
reveal_type(value) # N: Revealed type is "builtins.str"
358+
case anything:
359+
reveal_type(anything) # E: Statement is unreachable
360+
361+
def f2(subclass: type[Example]) -> None:
362+
match Example("a"):
363+
case subclass(value):
364+
reveal_type(value) # N: Revealed type is "builtins.str"
365+
case anything:
366+
reveal_type(anything) # N: Revealed type is "__main__.Example"
367+
368+
def f3(subclass: type[Example]) -> None:
369+
match subclass("a"):
370+
case subclass(value):
371+
reveal_type(value) # N: Revealed type is "builtins.str"
372+
case anything:
373+
reveal_type(anything) # N: Revealed type is "__main__.Example"
374+
375+
def f4(subclass: type[Example]) -> None:
376+
match [subclass("a"), subclass("b")]:
377+
case [subclass(value), *rest]:
378+
reveal_type(value) # N: Revealed type is "builtins.str"
379+
reveal_type(rest) # N: Revealed type is "builtins.list[__main__.Example]"
380+
case anything:
381+
reveal_type(anything) # N: Revealed type is "builtins.list[__main__.Example]"
382+
383+
class Example2:
384+
__match_args__ = ("value",)
385+
def __init__(self, value: str) -> None:
386+
self.value = value
354387

355-
match [SubClass("a"), SubClass("b")]:
356-
case [SubClass(value), *rest]: # E: Expected type in class pattern; found "type[__main__.Example]"
357-
reveal_type(value) # E: Cannot determine type of "value" \
358-
# N: Revealed type is "Any"
359-
reveal_type(rest) # N: Revealed type is "builtins.list[__main__.Example]"
388+
def f5(T: type[Example | Example2]) -> None:
389+
match T("a"):
390+
case Example(value):
391+
reveal_type(value) # N: Revealed type is "builtins.str"
392+
case anything:
393+
reveal_type(anything) # N: Revealed type is "__main__.Example2"
394+
395+
def f6(T: type[Example | Example2]) -> None:
396+
match Example("a"):
397+
case T(value): # E: Expected type in class pattern; found "type[__main__.Example] | type[__main__.Example2]"
398+
reveal_type(value) # E: Statement is unreachable
399+
case anything:
400+
reveal_type(anything) # N: Revealed type is "__main__.Example"
401+
402+
def f7(T: type[Example | Example2]) -> None:
403+
match T("a"):
404+
case T(value): # E: Expected type in class pattern; found "type[__main__.Example] | type[__main__.Example2]"
405+
reveal_type(value) # E: Statement is unreachable
406+
case anything:
407+
reveal_type(anything) # N: Revealed type is "__main__.Example | __main__.Example2"
360408
[builtins fixtures/tuple.pyi]
361409

362410
# Narrowing union-based values via a literal pattern on an indexed/attribute subject

0 commit comments

Comments
 (0)