3232)
3333from mypy .checkpattern import PatternChecker
3434from mypy .constraints import SUPERTYPE_OF
35- from mypy .erasetype import erase_type , erase_typevars , remove_instance_last_known_values
35+ from mypy .erasetype import (
36+ erase_type ,
37+ erase_typevars ,
38+ remove_instance_last_known_values ,
39+ shallow_erase_type_for_equality ,
40+ )
3641from mypy .errorcodes import TYPE_VAR , UNUSED_AWAITABLE , UNUSED_COROUTINE , ErrorCode
3742from mypy .errors import (
3843 ErrorInfo ,
@@ -6628,6 +6633,9 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
66286633 narrowable_indices = {0 },
66296634 )
66306635
6636+ # TODO: This remove_optional code should no longer be needed. The only
6637+ # thing it does is paper over a pre-existing deficiency in equality
6638+ # narrowing w.r.t to enums.
66316639 # We only try and narrow away 'None' for now
66326640 if (
66336641 not is_unreachable_map (if_map )
@@ -6775,8 +6783,7 @@ def narrow_type_by_identity_equality(
67756783 target = TypeRange (target_type , is_upper_bound = False )
67766784
67776785 if_map , else_map = conditional_types_to_typemaps (
6778- operands [i ],
6779- * conditional_types (expr_type , [target ], consider_promotion_overlap = True ),
6786+ operands [i ], * conditional_types (expr_type , [target ], from_equality = True )
67806787 )
67816788 if is_target_for_value_narrowing (get_proper_type (target_type )):
67826789 all_if_maps .append (if_map )
@@ -6814,9 +6821,7 @@ def narrow_type_by_identity_equality(
68146821 if is_target_for_value_narrowing (get_proper_type (target_type )):
68156822 if_map , else_map = conditional_types_to_typemaps (
68166823 operands [i ],
6817- * conditional_types (
6818- expr_type , [target ], consider_promotion_overlap = True
6819- ),
6824+ * conditional_types (expr_type , [target ], from_equality = True ),
68206825 )
68216826 all_else_maps .append (else_map )
68226827 continue
@@ -6855,7 +6860,7 @@ def narrow_type_by_identity_equality(
68556860 if_map , else_map = conditional_types_to_typemaps (
68566861 operands [i ],
68576862 * conditional_types (
6858- expr_type , [target ], default = expr_type , consider_promotion_overlap = True
6863+ expr_type , [target ], default = expr_type , from_equality = True
68596864 ),
68606865 )
68616866 or_if_maps .append (if_map )
@@ -8359,7 +8364,7 @@ def conditional_types(
83598364 default : None = None ,
83608365 * ,
83618366 consider_runtime_isinstance : bool = True ,
8362- consider_promotion_overlap : bool = False ,
8367+ from_equality : bool = False ,
83638368) -> tuple [Type | None , Type | None ]: ...
83648369
83658370
@@ -8370,7 +8375,7 @@ def conditional_types(
83708375 default : Type ,
83718376 * ,
83728377 consider_runtime_isinstance : bool = True ,
8373- consider_promotion_overlap : bool = False ,
8378+ from_equality : bool = False ,
83748379) -> tuple [Type , Type ]: ...
83758380
83768381
@@ -8380,7 +8385,7 @@ def conditional_types(
83808385 default : Type | None = None ,
83818386 * ,
83828387 consider_runtime_isinstance : bool = True ,
8383- consider_promotion_overlap : bool = False ,
8388+ from_equality : bool = False ,
83848389) -> tuple [Type | None , Type | None ]:
83858390 """Takes in the current type and a proposed type of an expression.
83868391
@@ -8425,7 +8430,7 @@ def conditional_types(
84258430 proposed_type_ranges ,
84268431 default = union_item ,
84278432 consider_runtime_isinstance = consider_runtime_isinstance ,
8428- consider_promotion_overlap = consider_promotion_overlap ,
8433+ from_equality = from_equality ,
84298434 )
84308435 yes_items .append (yes_type )
84318436 no_items .append (no_type )
@@ -8470,17 +8475,23 @@ def conditional_types(
84708475 consider_runtime_isinstance = consider_runtime_isinstance ,
84718476 )
84728477 return default , remainder
8473- if not is_overlapping_types (
8474- current_type , proposed_type , ignore_promotions = not consider_promotion_overlap
8475- ):
8476- # Expression is never of any type in proposed_type_ranges
8477- return UninhabitedType (), default
8478- if consider_promotion_overlap and not is_overlapping_types (
8479- current_type , proposed_type , ignore_promotions = True
8480- ):
8481- # We set consider_promotion_overlap when comparing equality. This is one of the places
8482- # at runtime where subtyping with promotion does happen to match runtime semantics
8483- return default , default
8478+
8479+ if from_equality :
8480+ # We erase generic args because values with different generic types can compare equal
8481+ # For instance, cast(list[str], []) and cast(list[int], [])
8482+ proposed_type = shallow_erase_type_for_equality (proposed_type )
8483+ if not is_overlapping_types (current_type , proposed_type , ignore_promotions = False ):
8484+ # Equality narrowing is one of the places at runtime where subtyping with promotion
8485+ # does happen to match runtime semantics
8486+ # Expression is never of any type in proposed_type_ranges
8487+ return UninhabitedType (), default
8488+ if not is_overlapping_types (current_type , proposed_type , ignore_promotions = True ):
8489+ return default , default
8490+ else :
8491+ if not is_overlapping_types (current_type , proposed_type , ignore_promotions = True ):
8492+ # Expression is never of any type in proposed_type_ranges
8493+ return UninhabitedType (), default
8494+
84848495 # we can only restrict when the type is precise, not bounded
84858496 proposed_precise_type = UnionType .make_union (
84868497 [type_range .item for type_range in proposed_type_ranges if not type_range .is_upper_bound ]
@@ -8726,13 +8737,7 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
87268737 return result
87278738
87288739
8729- BUILTINS_CUSTOM_EQ_CHECKS : Final = {
8730- "builtins.bytearray" ,
8731- "builtins.memoryview" ,
8732- "builtins.list" ,
8733- "builtins.dict" ,
8734- "builtins.set" ,
8735- }
8740+ BUILTINS_CUSTOM_EQ_CHECKS : Final = {"builtins.bytearray" , "builtins.memoryview" }
87368741
87378742
87388743def has_custom_eq_checks (t : Type ) -> bool :
0 commit comments