Skip to content

Commit be9760a

Browse files
committed
subtypes: fast path for Union/Union subtype check
Enums are exploded into a Union of Literal types when narrowed. Conditional branches on enum values can result in multiple distinct narrowings of the same enum, which are later subject to subtype checks (most notably via `is_same_type`, when exiting a frame context in the binder). Such checks would have quadratic complexity, `O(N*M)`, where `N` and `M` are the number of entries in each narrowed enum variable, and would lead to a drastic slowdown if any of the enums involved had a large number of values. Implement a linear-time fast path where literals are quickly filtered, with a fallback to the slow path for more complex values. Fixes #13821
1 parent 695ea30 commit be9760a

File tree

1 file changed

+47
-12
lines changed

1 file changed

+47
-12
lines changed

mypy/subtypes.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from contextlib import contextmanager
44
from typing import Any, Callable, Iterator, List, TypeVar, cast
5-
from typing_extensions import Final, TypeAlias as _TypeAlias
5+
from typing_extensions import Final, Protocol, TypeAlias as _TypeAlias
66

77
import mypy.applytype
88
import mypy.constraints
@@ -57,6 +57,7 @@
5757
UninhabitedType,
5858
UnionType,
5959
UnpackType,
60+
flatten_nested_unions,
6061
get_proper_type,
6162
is_named_instance,
6263
)
@@ -269,6 +270,11 @@ def is_same_type(
269270
)
270271

271272

273+
class _SubtypeCheck(Protocol):
274+
def __call__(self, left: Type, right: Type, *, subtype_context: SubtypeContext) -> bool:
275+
...
276+
277+
272278
# This is a common entry point for subtyping checks (both proper and non-proper).
273279
# Never call this private function directly, use the public versions.
274280
def _is_subtype(
@@ -289,17 +295,14 @@ def _is_subtype(
289295
# ErasedType as we do for non-proper subtyping.
290296
return True
291297

292-
def check_item(left: Type, right: Type, subtype_context: SubtypeContext) -> bool:
293-
if proper_subtype:
294-
return is_proper_subtype(left, right, subtype_context=subtype_context)
295-
return is_subtype(left, right, subtype_context=subtype_context)
298+
check_item = cast(_SubtypeCheck, is_proper_subtype if proper_subtype else is_subtype)
296299

297300
if isinstance(right, UnionType) and not isinstance(left, UnionType):
298301
# Normally, when 'left' is not itself a union, the only way
299302
# 'left' can be a subtype of the union 'right' is if it is a
300303
# subtype of one of the items making up the union.
301304
is_subtype_of_item = any(
302-
check_item(orig_left, item, subtype_context) for item in right.items
305+
check_item(orig_left, item, subtype_context=subtype_context) for item in right.items
303306
)
304307
# Recombine rhs literal types, to make an enum type a subtype
305308
# of a union of all enum items as literal types. Only do it if
@@ -313,7 +316,8 @@ def check_item(left: Type, right: Type, subtype_context: SubtypeContext) -> bool
313316
):
314317
right = UnionType(mypy.typeops.try_contracting_literals_in_union(right.items))
315318
is_subtype_of_item = any(
316-
check_item(orig_left, item, subtype_context) for item in right.items
319+
check_item(orig_left, item, subtype_context=subtype_context)
320+
for item in right.items
317321
)
318322
# However, if 'left' is a type variable T, T might also have
319323
# an upper bound which is itself a union. This case will be
@@ -872,19 +876,50 @@ def visit_overloaded(self, left: Overloaded) -> bool:
872876
return False
873877

874878
def visit_union_type(self, left: UnionType) -> bool:
875-
if isinstance(self.right, Instance):
879+
if isinstance(self.right, (UnionType, Instance)):
880+
# prune literals early to avoid nasty quadratic behavior which would otherwise arise when checking
881+
# subtype relationships between slightly different narrowings of an Enum
882+
# we achieve O(N+M) instead of O(N*M)
883+
884+
right_lit_types: set[Instance] = set()
885+
right_lit_values: set[LiteralType] = set()
886+
887+
if isinstance(self.right, UnionType):
888+
for item in flatten_nested_unions(
889+
self.right.relevant_items(), handle_type_alias_type=True
890+
):
891+
p_item = get_proper_type(item)
892+
if isinstance(p_item, LiteralType):
893+
right_lit_values.add(p_item)
894+
elif isinstance(p_item, Instance):
895+
if p_item.last_known_value is None:
896+
right_lit_types.add(p_item)
897+
else:
898+
right_lit_values.add(p_item.last_known_value)
899+
elif isinstance(self.right, Instance):
900+
if self.right.last_known_value is None:
901+
right_lit_types.add(self.right)
902+
else:
903+
right_lit_values.add(self.right.last_known_value)
904+
876905
literal_types: set[Instance] = set()
877-
# avoid redundant check for union of literals
878906
for item in left.relevant_items():
879907
p_item = get_proper_type(item)
908+
if p_item in right_lit_types or p_item in right_lit_values:
909+
continue
880910
lit_type = mypy.typeops.simple_literal_type(p_item)
881911
if lit_type is not None:
882-
if lit_type in literal_types:
912+
if lit_type in right_lit_types:
883913
continue
884-
literal_types.add(lit_type)
885-
item = lit_type
914+
if isinstance(self.right, Instance):
915+
if lit_type in literal_types:
916+
continue
917+
literal_types.add(lit_type)
918+
item = lit_type
919+
886920
if not self._is_subtype(item, self.orig_right):
887921
return False
922+
888923
return True
889924
return all(self._is_subtype(item, self.orig_right) for item in left.items)
890925

0 commit comments

Comments
 (0)