Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,40 +553,18 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
# Check class type
#
type_info = o.class_ref.node
typ = self.chk.expr_checker.accept(o.class_ref)
p_typ = get_proper_type(typ)
if isinstance(type_info, TypeAlias) and not type_info.no_args:
self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
return self.early_non_match()
elif isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
typ = fill_typevars_with_any(p_typ.type_object())
type_range = TypeRange(typ, is_upper_bound=False)
elif (
isinstance(type_info, Var)
and type_info.type is not None
and type_info.fullname == "typing.Callable"
):
# Create a `Callable[..., Any]`
fallback = self.chk.named_type("builtins.function")
any_type = AnyType(TypeOfAny.unannotated)
typ = callable_with_ellipsis(any_type, ret_type=any_type, fallback=fallback)
type_range = TypeRange(typ, is_upper_bound=False)
elif isinstance(p_typ, TypeType):
typ = p_typ.item
type_range = TypeRange(p_typ.item, is_upper_bound=True)
elif not isinstance(p_typ, AnyType):
self.msg.fail(
message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(
typ.str_with_options(self.options)
),
o,
)

typ = self.chk.expr_checker.accept(o.class_ref)
type_ranges = self.get_class_pattern_type_ranges(typ, o)
if type_ranges is None:
return self.early_non_match()
else:
type_range = get_type_range(typ)
typ = UnionType.make_union([t.item for t in type_ranges])

new_type, rest_type = self.chk.conditional_types_with_intersection(
current_type, [type_range], o, default=current_type
current_type, type_ranges, o, default=current_type
)
if is_uninhabited(new_type):
return self.early_non_match()
Expand Down Expand Up @@ -717,6 +695,46 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
new_type = UninhabitedType()
return PatternType(new_type, rest_type, captures)

def get_class_pattern_type_ranges(self, typ: Type, o: ClassPattern) -> list[TypeRange] | None:
p_typ = get_proper_type(typ)

if isinstance(p_typ, UnionType):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this branch is the real change

type_ranges = []
for item in p_typ.items:
type_range = self.get_class_pattern_type_ranges(item, o)
if type_range is not None:
type_ranges.extend(type_range)
if not type_ranges:
return None
return type_ranges

if isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
typ = fill_typevars_with_any(p_typ.type_object())
return [TypeRange(typ, is_upper_bound=False)]
if (
isinstance(o.class_ref.node, Var)
and o.class_ref.node.type is not None
and o.class_ref.node.fullname == "typing.Callable"
):
# Create a `Callable[..., Any]`
fallback = self.chk.named_type("builtins.function")
any_type = AnyType(TypeOfAny.unannotated)
typ = callable_with_ellipsis(any_type, ret_type=any_type, fallback=fallback)
return [TypeRange(typ, is_upper_bound=False)]
if isinstance(p_typ, TypeType):
typ = p_typ.item
return [TypeRange(p_typ.item, is_upper_bound=True)]
if isinstance(p_typ, AnyType):
return [TypeRange(p_typ, is_upper_bound=False)]

self.msg.fail(
message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(
typ.str_with_options(self.options)
),
o,
)
return None

def should_self_match(self, typ: Type) -> bool:
typ = get_proper_type(typ)
if isinstance(typ, TupleType):
Expand Down
34 changes: 24 additions & 10 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1091,13 +1091,27 @@ match m:
[builtins fixtures/tuple.pyi]

[case testMatchClassPatternIsNotType]
a = 1
m: object
# flags: --strict-equality --warn-unreachable
from typing import Any

match m:
case a(i, j): # E: Expected type in class pattern; found "builtins.int"
reveal_type(i)
reveal_type(j)
def match_int(m: object, a: int):
match m:
case a(i, j): # E: Expected type in class pattern; found "builtins.int"
reveal_type(i) # E: Statement is unreachable
reveal_type(j)

def match_int_str(m: object, a: int | str):
match m:
case a(i, j): # E: Expected type in class pattern; found "builtins.int" \
# E: Expected type in class pattern; found "builtins.str"
reveal_type(i) # E: Statement is unreachable
reveal_type(j)

def match_int_any(m: object, a: int | Any):
match m:
case a(i, j): # E: Expected type in class pattern; found "builtins.int"
reveal_type(i) # N: Revealed type is "Any"
reveal_type(j) # N: Revealed type is "Any"

[case testMatchClassPatternAny]
from typing import Any
Expand Down Expand Up @@ -1300,15 +1314,15 @@ def f4(T: type[Example | Example2]) -> None:

def f5(T: type[Example | Example2]) -> None:
match Example("a"):
case T(value): # E: Expected type in class pattern; found "type[__main__.Example] | type[__main__.Example2]"
reveal_type(value) # E: Statement is unreachable
case T(value):
reveal_type(value) # N: Revealed type is "builtins.str"
case anything:
reveal_type(anything) # N: Revealed type is "__main__.Example"

def f6(T: type[Example | Example2]) -> None:
match T("a"):
case T(value): # E: Expected type in class pattern; found "type[__main__.Example] | type[__main__.Example2]"
reveal_type(value) # E: Statement is unreachable
case T(value):
reveal_type(value) # N: Revealed type is "builtins.str"
case anything:
reveal_type(anything) # N: Revealed type is "__main__.Example | __main__.Example2"

Expand Down