Skip to content

Commit c09b174

Browse files
authored
Better handling of generics when narrowing (#20863)
Notably we preserve behaviour on the `testNarrowingCollections` test I added in a previous PR Closes #20673
1 parent 6958e77 commit c09b174

3 files changed

Lines changed: 134 additions & 29 deletions

File tree

mypy/checker.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
)
3333
from mypy.checkpattern import PatternChecker
3434
from mypy.constraints import SUPERTYPE_OF
35-
from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values
35+
from mypy.erasetype import (
36+
erase_type,
37+
erase_typevars,
38+
remove_instance_last_known_values,
39+
shallow_erase_type_for_equality,
40+
)
3641
from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode
3742
from mypy.errors import (
3843
ErrorInfo,
@@ -6628,6 +6633,9 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
66286633
narrowable_indices={0},
66296634
)
66306635

6636+
# TODO: This remove_optional code should no longer be needed. The only
6637+
# thing it does is paper over a pre-existing deficiency in equality
6638+
# narrowing w.r.t to enums.
66316639
# We only try and narrow away 'None' for now
66326640
if (
66336641
not is_unreachable_map(if_map)
@@ -6775,8 +6783,7 @@ def narrow_type_by_identity_equality(
67756783
target = TypeRange(target_type, is_upper_bound=False)
67766784

67776785
if_map, else_map = conditional_types_to_typemaps(
6778-
operands[i],
6779-
*conditional_types(expr_type, [target], consider_promotion_overlap=True),
6786+
operands[i], *conditional_types(expr_type, [target], from_equality=True)
67806787
)
67816788
if is_target_for_value_narrowing(get_proper_type(target_type)):
67826789
all_if_maps.append(if_map)
@@ -6814,9 +6821,7 @@ def narrow_type_by_identity_equality(
68146821
if is_target_for_value_narrowing(get_proper_type(target_type)):
68156822
if_map, else_map = conditional_types_to_typemaps(
68166823
operands[i],
6817-
*conditional_types(
6818-
expr_type, [target], consider_promotion_overlap=True
6819-
),
6824+
*conditional_types(expr_type, [target], from_equality=True),
68206825
)
68216826
all_else_maps.append(else_map)
68226827
continue
@@ -6855,7 +6860,7 @@ def narrow_type_by_identity_equality(
68556860
if_map, else_map = conditional_types_to_typemaps(
68566861
operands[i],
68576862
*conditional_types(
6858-
expr_type, [target], default=expr_type, consider_promotion_overlap=True
6863+
expr_type, [target], default=expr_type, from_equality=True
68596864
),
68606865
)
68616866
or_if_maps.append(if_map)
@@ -8359,7 +8364,7 @@ def conditional_types(
83598364
default: None = None,
83608365
*,
83618366
consider_runtime_isinstance: bool = True,
8362-
consider_promotion_overlap: bool = False,
8367+
from_equality: bool = False,
83638368
) -> tuple[Type | None, Type | None]: ...
83648369

83658370

@@ -8370,7 +8375,7 @@ def conditional_types(
83708375
default: Type,
83718376
*,
83728377
consider_runtime_isinstance: bool = True,
8373-
consider_promotion_overlap: bool = False,
8378+
from_equality: bool = False,
83748379
) -> tuple[Type, Type]: ...
83758380

83768381

@@ -8380,7 +8385,7 @@ def conditional_types(
83808385
default: Type | None = None,
83818386
*,
83828387
consider_runtime_isinstance: bool = True,
8383-
consider_promotion_overlap: bool = False,
8388+
from_equality: bool = False,
83848389
) -> tuple[Type | None, Type | None]:
83858390
"""Takes in the current type and a proposed type of an expression.
83868391
@@ -8425,7 +8430,7 @@ def conditional_types(
84258430
proposed_type_ranges,
84268431
default=union_item,
84278432
consider_runtime_isinstance=consider_runtime_isinstance,
8428-
consider_promotion_overlap=consider_promotion_overlap,
8433+
from_equality=from_equality,
84298434
)
84308435
yes_items.append(yes_type)
84318436
no_items.append(no_type)
@@ -8470,17 +8475,23 @@ def conditional_types(
84708475
consider_runtime_isinstance=consider_runtime_isinstance,
84718476
)
84728477
return default, remainder
8473-
if not is_overlapping_types(
8474-
current_type, proposed_type, ignore_promotions=not consider_promotion_overlap
8475-
):
8476-
# Expression is never of any type in proposed_type_ranges
8477-
return UninhabitedType(), default
8478-
if consider_promotion_overlap and not is_overlapping_types(
8479-
current_type, proposed_type, ignore_promotions=True
8480-
):
8481-
# We set consider_promotion_overlap when comparing equality. This is one of the places
8482-
# at runtime where subtyping with promotion does happen to match runtime semantics
8483-
return default, default
8478+
8479+
if from_equality:
8480+
# We erase generic args because values with different generic types can compare equal
8481+
# For instance, cast(list[str], []) and cast(list[int], [])
8482+
proposed_type = shallow_erase_type_for_equality(proposed_type)
8483+
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=False):
8484+
# Equality narrowing is one of the places at runtime where subtyping with promotion
8485+
# does happen to match runtime semantics
8486+
# Expression is never of any type in proposed_type_ranges
8487+
return UninhabitedType(), default
8488+
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
8489+
return default, default
8490+
else:
8491+
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
8492+
# Expression is never of any type in proposed_type_ranges
8493+
return UninhabitedType(), default
8494+
84848495
# we can only restrict when the type is precise, not bounded
84858496
proposed_precise_type = UnionType.make_union(
84868497
[type_range.item for type_range in proposed_type_ranges if not type_range.is_upper_bound]
@@ -8726,13 +8737,7 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
87268737
return result
87278738

87288739

8729-
BUILTINS_CUSTOM_EQ_CHECKS: Final = {
8730-
"builtins.bytearray",
8731-
"builtins.memoryview",
8732-
"builtins.list",
8733-
"builtins.dict",
8734-
"builtins.set",
8735-
}
8740+
BUILTINS_CUSTOM_EQ_CHECKS: Final = {"builtins.bytearray", "builtins.memoryview"}
87368741

87378742

87388743
def has_custom_eq_checks(t: Type) -> bool:

mypy/erasetype.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,17 @@ def visit_union_type(self, t: UnionType) -> Type:
285285
merged.append(orig_item)
286286
return UnionType.make_union(merged)
287287
return new
288+
289+
290+
def shallow_erase_type_for_equality(typ: Type) -> ProperType:
291+
"""Erase type variables from Instance's"""
292+
p_typ = get_proper_type(typ)
293+
if isinstance(p_typ, Instance):
294+
if not p_typ.args:
295+
return p_typ
296+
args = erased_vars(p_typ.type.defn.type_vars, TypeOfAny.special_form)
297+
return Instance(p_typ.type, args, p_typ.line)
298+
if isinstance(p_typ, UnionType):
299+
items = [shallow_erase_type_for_equality(item) for item in p_typ.items]
300+
return UnionType.make_union(items)
301+
return p_typ

test-data/unit/check-narrowing.test

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,92 @@ def f(x: Custom, y: CustomSub):
10651065
reveal_type(y) # N: Revealed type is "__main__.CustomSub"
10661066
[builtins fixtures/tuple.pyi]
10671067

1068+
[case testNarrowingCustomEqualityGeneric]
1069+
# flags: --strict-equality --warn-unreachable
1070+
from __future__ import annotations
1071+
from typing import Union
1072+
1073+
class Custom:
1074+
def __eq__(self, other: object) -> bool:
1075+
raise
1076+
1077+
class Default: ...
1078+
1079+
def f1(x: list[Custom] | Default, y: list[int]):
1080+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int]")
1081+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1082+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
1083+
else:
1084+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1085+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
1086+
1087+
f1([], [])
1088+
1089+
def f2(x: list[Custom] | Default, y: list[int] | list[Default]):
1090+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
1091+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1092+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1093+
else:
1094+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1095+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1096+
1097+
listcustom_or_default = Union[list[Custom], Default]
1098+
listint_or_default = Union[list[int], list[Default]]
1099+
1100+
def f2_with_alias(x: listcustom_or_default, y: listint_or_default):
1101+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
1102+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1103+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1104+
else:
1105+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1106+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1107+
1108+
def f3(x: Custom | dict[str, str], y: dict[int, int]):
1109+
if x == y:
1110+
reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]"
1111+
reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]"
1112+
else:
1113+
reveal_type(x) # N: Revealed type is "__main__.Custom | builtins.dict[builtins.str, builtins.str]"
1114+
reveal_type(y) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]"
1115+
[builtins fixtures/primitives.pyi]
1116+
1117+
[case testNarrowingRecursiveCallable]
1118+
# flags: --strict-equality --warn-unreachable
1119+
from __future__ import annotations
1120+
from typing import Callable
1121+
1122+
class A: ...
1123+
class B: ...
1124+
1125+
T = Callable[[A], "S"]
1126+
S = Callable[[B], "T"]
1127+
1128+
def f(x: S, y: T):
1129+
if x == y: # E: Unsupported left operand type for == ("Callable[[B], T]")
1130+
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
1131+
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
1132+
else:
1133+
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
1134+
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
1135+
[builtins fixtures/tuple.pyi]
1136+
1137+
[case testNarrowingRecursiveUnion]
1138+
# flags: --strict-equality --warn-unreachable
1139+
from __future__ import annotations
1140+
from typing import Union
1141+
1142+
class A: ...
1143+
class B: ...
1144+
1145+
T = Union[A, "S"]
1146+
S = Union[B, "T"] # E: Invalid recursive alias: a union item of itself
1147+
1148+
def f(x: S, y: T):
1149+
if x == y:
1150+
reveal_type(x) # N: Revealed type is "Any"
1151+
reveal_type(y) # N: Revealed type is "__main__.A | Any"
1152+
[builtins fixtures/tuple.pyi]
1153+
10681154
[case testNarrowingUnreachableCases]
10691155
# flags: --strict-equality --warn-unreachable
10701156
from typing import Literal, Union

0 commit comments

Comments
 (0)