2
2
3
3
from contextlib import contextmanager
4
4
from typing import Any , Callable , Iterator , List , TypeVar , cast
5
- from typing_extensions import Final , TypeAlias as _TypeAlias
5
+ from typing_extensions import Final , Protocol , TypeAlias as _TypeAlias
6
6
7
7
import mypy .applytype
8
8
import mypy .constraints
57
57
UninhabitedType ,
58
58
UnionType ,
59
59
UnpackType ,
60
+ flatten_nested_unions ,
60
61
get_proper_type ,
61
62
is_named_instance ,
62
63
)
@@ -269,6 +270,11 @@ def is_same_type(
269
270
)
270
271
271
272
273
+ class _SubtypeCheck (Protocol ):
274
+ def __call__ (self , left : Type , right : Type , * , subtype_context : SubtypeContext ) -> bool :
275
+ ...
276
+
277
+
272
278
# This is a common entry point for subtyping checks (both proper and non-proper).
273
279
# Never call this private function directly, use the public versions.
274
280
def _is_subtype (
@@ -289,17 +295,14 @@ def _is_subtype(
289
295
# ErasedType as we do for non-proper subtyping.
290
296
return True
291
297
292
- def check_item (left : Type , right : Type , subtype_context : SubtypeContext ) -> bool :
293
- if proper_subtype :
294
- return is_proper_subtype (left , right , subtype_context = subtype_context )
295
- return is_subtype (left , right , subtype_context = subtype_context )
298
+ check_item = cast (_SubtypeCheck , is_proper_subtype if proper_subtype else is_subtype )
296
299
297
300
if isinstance (right , UnionType ) and not isinstance (left , UnionType ):
298
301
# Normally, when 'left' is not itself a union, the only way
299
302
# 'left' can be a subtype of the union 'right' is if it is a
300
303
# subtype of one of the items making up the union.
301
304
is_subtype_of_item = any (
302
- check_item (orig_left , item , subtype_context ) for item in right .items
305
+ check_item (orig_left , item , subtype_context = subtype_context ) for item in right .items
303
306
)
304
307
# Recombine rhs literal types, to make an enum type a subtype
305
308
# of a union of all enum items as literal types. Only do it if
@@ -313,7 +316,8 @@ def check_item(left: Type, right: Type, subtype_context: SubtypeContext) -> bool
313
316
):
314
317
right = UnionType (mypy .typeops .try_contracting_literals_in_union (right .items ))
315
318
is_subtype_of_item = any (
316
- check_item (orig_left , item , subtype_context ) for item in right .items
319
+ check_item (orig_left , item , subtype_context = subtype_context )
320
+ for item in right .items
317
321
)
318
322
# However, if 'left' is a type variable T, T might also have
319
323
# an upper bound which is itself a union. This case will be
@@ -872,19 +876,50 @@ def visit_overloaded(self, left: Overloaded) -> bool:
872
876
return False
873
877
874
878
def visit_union_type (self , left : UnionType ) -> bool :
875
- if isinstance (self .right , Instance ):
879
+ if isinstance (self .right , (UnionType , Instance )):
880
+ # prune literals early to avoid nasty quadratic behavior which would otherwise arise when checking
881
+ # subtype relationships between slightly different narrowings of an Enum
882
+ # we achieve O(N+M) instead of O(N*M)
883
+
884
+ right_lit_types : set [Instance ] = set ()
885
+ right_lit_values : set [LiteralType ] = set ()
886
+
887
+ if isinstance (self .right , UnionType ):
888
+ for item in flatten_nested_unions (
889
+ self .right .relevant_items (), handle_type_alias_type = True
890
+ ):
891
+ p_item = get_proper_type (item )
892
+ if isinstance (p_item , LiteralType ):
893
+ right_lit_values .add (p_item )
894
+ elif isinstance (p_item , Instance ):
895
+ if p_item .last_known_value is None :
896
+ right_lit_types .add (p_item )
897
+ else :
898
+ right_lit_values .add (p_item .last_known_value )
899
+ elif isinstance (self .right , Instance ):
900
+ if self .right .last_known_value is None :
901
+ right_lit_types .add (self .right )
902
+ else :
903
+ right_lit_values .add (self .right .last_known_value )
904
+
876
905
literal_types : set [Instance ] = set ()
877
- # avoid redundant check for union of literals
878
906
for item in left .relevant_items ():
879
907
p_item = get_proper_type (item )
908
+ if p_item in right_lit_types or p_item in right_lit_values :
909
+ continue
880
910
lit_type = mypy .typeops .simple_literal_type (p_item )
881
911
if lit_type is not None :
882
- if lit_type in literal_types :
912
+ if lit_type in right_lit_types :
883
913
continue
884
- literal_types .add (lit_type )
885
- item = lit_type
914
+ if isinstance (self .right , Instance ):
915
+ if lit_type in literal_types :
916
+ continue
917
+ literal_types .add (lit_type )
918
+ item = lit_type
919
+
886
920
if not self ._is_subtype (item , self .orig_right ):
887
921
return False
922
+
888
923
return True
889
924
return all (self ._is_subtype (item , self .orig_right ) for item in left .items )
890
925
0 commit comments