Skip to content

Commit

Permalink
Fix strict equality with enum type with custom __eq__ (#14518)
Browse files Browse the repository at this point in the history
Fixes regression introduced in #14513.
  • Loading branch information
JukkaL authored Jan 24, 2023
1 parent 757e0d4 commit 0665ce9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 10 deletions.
31 changes: 21 additions & 10 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2970,7 +2970,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
not local_errors.has_new_errors()
and cont_type
and self.dangerous_comparison(
left_type, cont_type, original_container=right_type
left_type, cont_type, original_container=right_type, prefer_literal=False
)
):
self.msg.dangerous_comparison(left_type, cont_type, "container", e)
Expand All @@ -2988,21 +2988,19 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
# testCustomEqCheckStrictEquality for an example.
if not w.has_new_errors() and operator in ("==", "!="):
right_type = self.accept(right)
# Also flag non-overlapping literals in situations like:
# x: Literal['a', 'b']
# if x == 'c':
# ...
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
if self.dangerous_comparison(left_type, right_type):
# Show the most specific literal types possible
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
self.msg.dangerous_comparison(left_type, right_type, "equality", e)

elif operator == "is" or operator == "is not":
right_type = self.accept(right) # validate the right operand
sub_result = self.bool_type()
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
if self.dangerous_comparison(left_type, right_type):
# Show the most specific literal types possible
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
self.msg.dangerous_comparison(left_type, right_type, "identity", e)
method_type = None
else:
Expand Down Expand Up @@ -3036,7 +3034,12 @@ def find_partial_type_ref_fast_path(self, expr: Expression) -> Type | None:
return None

def dangerous_comparison(
self, left: Type, right: Type, original_container: Type | None = None
self,
left: Type,
right: Type,
original_container: Type | None = None,
*,
prefer_literal: bool = True,
) -> bool:
"""Check for dangerous non-overlapping comparisons like 42 == 'no'.
Expand Down Expand Up @@ -3064,6 +3067,14 @@ def dangerous_comparison(
if custom_special_method(left, "__eq__") or custom_special_method(right, "__eq__"):
return False

if prefer_literal:
# Also flag non-overlapping literals in situations like:
# x: Literal['a', 'b']
# if x == 'c':
# ...
left = try_getting_literal(left)
right = try_getting_literal(right)

if self.chk.binder.is_unreachable_warning_suppressed():
# We are inside a function that contains type variables with value restrictions in
# its signature. In this case we just suppress all strict-equality checks to avoid
Expand Down
26 changes: 26 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2221,6 +2221,32 @@ int == y
y == int
[builtins fixtures/bool.pyi]

[case testStrictEqualityAndEnumWithCustomEq]
# flags: --strict-equality
from enum import Enum

class E1(Enum):
X = 0
Y = 1

class E2(Enum):
X = 0
Y = 1

def __eq__(self, other: object) -> bool:
return bool()

E1.X == E1.Y # E: Non-overlapping equality check (left operand type: "Literal[E1.X]", right operand type: "Literal[E1.Y]")
E2.X == E2.Y
[builtins fixtures/bool.pyi]

[case testStrictEqualityWithBytesContains]
# flags: --strict-equality
data = b"xy"
b"x" in data
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testUnimportedHintAny]
def f(x: Any) -> None: # E: Name "Any" is not defined \
# N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")
Expand Down

0 comments on commit 0665ce9

Please sign in to comment.