diff --git a/mypy/subtypes.py b/mypy/subtypes.py index e4667c45fbc5f..1e6efdd6169db 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from typing import Any, Callable, Iterator, List, TypeVar, cast -from typing_extensions import Final, TypeAlias as _TypeAlias +from typing_extensions import Final, Protocol, TypeAlias as _TypeAlias import mypy.applytype import mypy.constraints @@ -57,6 +57,7 @@ UninhabitedType, UnionType, UnpackType, + flatten_nested_unions, get_proper_type, is_named_instance, ) @@ -269,6 +270,11 @@ def is_same_type( ) +class _SubtypeCheck(Protocol): + def __call__(self, left: Type, right: Type, *, subtype_context: SubtypeContext) -> bool: + ... + + # This is a common entry point for subtyping checks (both proper and non-proper). # Never call this private function directly, use the public versions. def _is_subtype( @@ -289,17 +295,14 @@ def _is_subtype( # ErasedType as we do for non-proper subtyping. return True - def check_item(left: Type, right: Type, subtype_context: SubtypeContext) -> bool: - if proper_subtype: - return is_proper_subtype(left, right, subtype_context=subtype_context) - return is_subtype(left, right, subtype_context=subtype_context) + check_item = cast(_SubtypeCheck, is_proper_subtype if proper_subtype else is_subtype) if isinstance(right, UnionType) and not isinstance(left, UnionType): # Normally, when 'left' is not itself a union, the only way # 'left' can be a subtype of the union 'right' is if it is a # subtype of one of the items making up the union. is_subtype_of_item = any( - check_item(orig_left, item, subtype_context) for item in right.items + check_item(orig_left, item, subtype_context=subtype_context) for item in right.items ) # Recombine rhs literal types, to make an enum type a subtype # of a union of all enum items as literal types. Only do it if @@ -313,7 +316,8 @@ def check_item(left: Type, right: Type, subtype_context: SubtypeContext) -> bool ): right = UnionType(mypy.typeops.try_contracting_literals_in_union(right.items)) is_subtype_of_item = any( - check_item(orig_left, item, subtype_context) for item in right.items + check_item(orig_left, item, subtype_context=subtype_context) + for item in right.items ) # However, if 'left' is a type variable T, T might also have # an upper bound which is itself a union. This case will be @@ -872,19 +876,50 @@ def visit_overloaded(self, left: Overloaded) -> bool: return False def visit_union_type(self, left: UnionType) -> bool: - if isinstance(self.right, Instance): + if isinstance(self.right, (UnionType, Instance)): + # prune literals early to avoid nasty quadratic behavior which would otherwise arise when checking + # subtype relationships between slightly different narrowings of an Enum + # we achieve O(N+M) instead of O(N*M) + + right_lit_types: set[Instance] = set() + right_lit_values: set[LiteralType] = set() + + if isinstance(self.right, UnionType): + for item in flatten_nested_unions( + self.right.relevant_items(), handle_type_alias_type=True + ): + p_item = get_proper_type(item) + if isinstance(p_item, LiteralType): + right_lit_values.add(p_item) + elif isinstance(p_item, Instance): + if p_item.last_known_value is None: + right_lit_types.add(p_item) + else: + right_lit_values.add(p_item.last_known_value) + elif isinstance(self.right, Instance): + if self.right.last_known_value is None: + right_lit_types.add(self.right) + else: + right_lit_values.add(self.right.last_known_value) + literal_types: set[Instance] = set() - # avoid redundant check for union of literals for item in left.relevant_items(): p_item = get_proper_type(item) + if p_item in right_lit_types or p_item in right_lit_values: + continue lit_type = mypy.typeops.simple_literal_type(p_item) if lit_type is not None: - if lit_type in literal_types: + if lit_type in right_lit_types: continue - literal_types.add(lit_type) - item = lit_type + if isinstance(self.right, Instance): + if lit_type in literal_types: + continue + literal_types.add(lit_type) + item = lit_type + if not self._is_subtype(item, self.orig_right): return False + return True return all(self._is_subtype(item, self.orig_right) for item in left.items)