Skip to content

Commit

Permalink
[red-knot] MDTest: Use custom class names instead of builtins (#16269)
Browse files Browse the repository at this point in the history
## Summary

Follow up on the discussion
[here](#16121 (comment)).
Replace builtin classes with custom placeholder names, which should
hopefully make the tests a bit easier to understand.

I carefully renamed things one after the other, to make sure that there
is no functional change in the tests.
  • Loading branch information
sharkdp authored Feb 20, 2025
1 parent fc6b03c commit 8198668
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 282 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,38 @@ most common case involves implementing these methods for the same type:
```py
from __future__ import annotations

class EqReturnType: ...
class NeReturnType: ...
class LtReturnType: ...
class LeReturnType: ...
class GtReturnType: ...
class GeReturnType: ...

class A:
def __eq__(self, other: A) -> int:
return 42
def __eq__(self, other: A) -> EqReturnType:
return EqReturnType()

def __ne__(self, other: A) -> bytearray:
return bytearray()
def __ne__(self, other: A) -> NeReturnType:
return NeReturnType()

def __lt__(self, other: A) -> str:
return "42"
def __lt__(self, other: A) -> LtReturnType:
return LtReturnType()

def __le__(self, other: A) -> bytes:
return b"42"
def __le__(self, other: A) -> LeReturnType:
return LeReturnType()

def __gt__(self, other: A) -> list:
return [42]
def __gt__(self, other: A) -> GtReturnType:
return GtReturnType()

def __ge__(self, other: A) -> set:
return {42}
def __ge__(self, other: A) -> GeReturnType:
return GeReturnType()

reveal_type(A() == A()) # revealed: int
reveal_type(A() != A()) # revealed: bytearray
reveal_type(A() < A()) # revealed: str
reveal_type(A() <= A()) # revealed: bytes
reveal_type(A() > A()) # revealed: list
reveal_type(A() >= A()) # revealed: set
reveal_type(A() == A()) # revealed: EqReturnType
reveal_type(A() != A()) # revealed: NeReturnType
reveal_type(A() < A()) # revealed: LtReturnType
reveal_type(A() <= A()) # revealed: LeReturnType
reveal_type(A() > A()) # revealed: GtReturnType
reveal_type(A() >= A()) # revealed: GeReturnType
```

## Rich Comparison Dunder Implementations for Other Class
Expand All @@ -51,33 +58,40 @@ type:
```py
from __future__ import annotations

class EqReturnType: ...
class NeReturnType: ...
class LtReturnType: ...
class LeReturnType: ...
class GtReturnType: ...
class GeReturnType: ...

class A:
def __eq__(self, other: B) -> int:
return 42
def __eq__(self, other: B) -> EqReturnType:
return EqReturnType()

def __ne__(self, other: B) -> bytearray:
return bytearray()
def __ne__(self, other: B) -> NeReturnType:
return NeReturnType()

def __lt__(self, other: B) -> str:
return "42"
def __lt__(self, other: B) -> LtReturnType:
return LtReturnType()

def __le__(self, other: B) -> bytes:
return b"42"
def __le__(self, other: B) -> LeReturnType:
return LeReturnType()

def __gt__(self, other: B) -> list:
return [42]
def __gt__(self, other: B) -> GtReturnType:
return GtReturnType()

def __ge__(self, other: B) -> set:
return {42}
def __ge__(self, other: B) -> GeReturnType:
return GeReturnType()

class B: ...

reveal_type(A() == B()) # revealed: int
reveal_type(A() != B()) # revealed: bytearray
reveal_type(A() < B()) # revealed: str
reveal_type(A() <= B()) # revealed: bytes
reveal_type(A() > B()) # revealed: list
reveal_type(A() >= B()) # revealed: set
reveal_type(A() == B()) # revealed: EqReturnType
reveal_type(A() != B()) # revealed: NeReturnType
reveal_type(A() < B()) # revealed: LtReturnType
reveal_type(A() <= B()) # revealed: LeReturnType
reveal_type(A() > B()) # revealed: GtReturnType
reveal_type(A() >= B()) # revealed: GeReturnType
```

## Reflected Comparisons
Expand All @@ -89,55 +103,64 @@ these methods will be ignored here because they require a mismatched operand typ
```py
from __future__ import annotations

class EqReturnType: ...
class NeReturnType: ...
class LtReturnType: ...
class LeReturnType: ...
class GtReturnType: ...
class GeReturnType: ...

class A:
def __eq__(self, other: B) -> int:
return 42
def __eq__(self, other: B) -> EqReturnType:
return EqReturnType()

def __ne__(self, other: B) -> NeReturnType:
return NeReturnType()

def __ne__(self, other: B) -> bytearray:
return bytearray()
def __lt__(self, other: B) -> LtReturnType:
return LtReturnType()

def __lt__(self, other: B) -> str:
return "42"
def __le__(self, other: B) -> LeReturnType:
return LeReturnType()

def __le__(self, other: B) -> bytes:
return b"42"
def __gt__(self, other: B) -> GtReturnType:
return GtReturnType()

def __gt__(self, other: B) -> list:
return [42]
def __ge__(self, other: B) -> GeReturnType:
return GeReturnType()

def __ge__(self, other: B) -> set:
return {42}
class Unrelated: ...

class B:
# To override builtins.object.__eq__ and builtins.object.__ne__
# TODO these should emit an invalid override diagnostic
def __eq__(self, other: str) -> B:
def __eq__(self, other: Unrelated) -> B:
return B()

def __ne__(self, other: str) -> B:
def __ne__(self, other: Unrelated) -> B:
return B()

# Because `object.__eq__` and `object.__ne__` accept `object` in typeshed,
# this can only happen with an invalid override of these methods,
# but we still support it.
reveal_type(B() == A()) # revealed: int
reveal_type(B() != A()) # revealed: bytearray
reveal_type(B() == A()) # revealed: EqReturnType
reveal_type(B() != A()) # revealed: NeReturnType

reveal_type(B() < A()) # revealed: list
reveal_type(B() <= A()) # revealed: set
reveal_type(B() < A()) # revealed: GtReturnType
reveal_type(B() <= A()) # revealed: GeReturnType

reveal_type(B() > A()) # revealed: str
reveal_type(B() >= A()) # revealed: bytes
reveal_type(B() > A()) # revealed: LtReturnType
reveal_type(B() >= A()) # revealed: LeReturnType

class C:
def __gt__(self, other: C) -> int:
def __gt__(self, other: C) -> EqReturnType:
return 42

def __ge__(self, other: C) -> bytearray:
return bytearray()
def __ge__(self, other: C) -> NeReturnType:
return NeReturnType()

reveal_type(C() < C()) # revealed: int
reveal_type(C() <= C()) # revealed: bytearray
reveal_type(C() < C()) # revealed: EqReturnType
reveal_type(C() <= C()) # revealed: NeReturnType
```

## Reflected Comparisons with Subclasses
Expand All @@ -149,6 +172,13 @@ than `A`.
```py
from __future__ import annotations

class EqReturnType: ...
class NeReturnType: ...
class LtReturnType: ...
class LeReturnType: ...
class GtReturnType: ...
class GeReturnType: ...

class A:
def __eq__(self, other: A) -> A:
return A()
Expand All @@ -169,32 +199,32 @@ class A:
return A()

class B(A):
def __eq__(self, other: A) -> int:
return 42
def __eq__(self, other: A) -> EqReturnType:
return EqReturnType()

def __ne__(self, other: A) -> bytearray:
return bytearray()
def __ne__(self, other: A) -> NeReturnType:
return NeReturnType()

def __lt__(self, other: A) -> str:
return "42"
def __lt__(self, other: A) -> LtReturnType:
return LtReturnType()

def __le__(self, other: A) -> bytes:
return b"42"
def __le__(self, other: A) -> LeReturnType:
return LeReturnType()

def __gt__(self, other: A) -> list:
return [42]
def __gt__(self, other: A) -> GtReturnType:
return GtReturnType()

def __ge__(self, other: A) -> set:
return {42}
def __ge__(self, other: A) -> GeReturnType:
return GeReturnType()

reveal_type(A() == B()) # revealed: int
reveal_type(A() != B()) # revealed: bytearray
reveal_type(A() == B()) # revealed: EqReturnType
reveal_type(A() != B()) # revealed: NeReturnType

reveal_type(A() < B()) # revealed: list
reveal_type(A() <= B()) # revealed: set
reveal_type(A() < B()) # revealed: GtReturnType
reveal_type(A() <= B()) # revealed: GeReturnType

reveal_type(A() > B()) # revealed: str
reveal_type(A() >= B()) # revealed: bytes
reveal_type(A() > B()) # revealed: LtReturnType
reveal_type(A() >= B()) # revealed: LeReturnType
```

## Reflected Comparisons with Subclass But Falls Back to LHS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,33 +147,40 @@ of the dunder methods.)
```py
from __future__ import annotations

class EqReturnType: ...
class NeReturnType: ...
class LtReturnType: ...
class LeReturnType: ...
class GtReturnType: ...
class GeReturnType: ...

class A:
def __eq__(self, o: object) -> str:
return "hello"
def __eq__(self, o: object) -> EqReturnType:
return EqReturnType()

def __ne__(self, o: object) -> bytes:
return b"world"
def __ne__(self, o: object) -> NeReturnType:
return NeReturnType()

def __lt__(self, o: A) -> bytearray:
return bytearray()
def __lt__(self, o: A) -> LtReturnType:
return LtReturnType()

def __le__(self, o: A) -> memoryview:
return memoryview(b"")
def __le__(self, o: A) -> LeReturnType:
return LeReturnType()

def __gt__(self, o: A) -> tuple:
return (1, 2, 3)
def __gt__(self, o: A) -> GtReturnType:
return GtReturnType()

def __ge__(self, o: A) -> list:
return [1, 2, 3]
def __ge__(self, o: A) -> GeReturnType:
return GeReturnType()

a = (A(), A())

reveal_type(a == a) # revealed: bool
reveal_type(a != a) # revealed: bool
reveal_type(a < a) # revealed: bytearray | Literal[False]
reveal_type(a <= a) # revealed: memoryview | Literal[True]
reveal_type(a > a) # revealed: tuple | Literal[False]
reveal_type(a >= a) # revealed: list | Literal[True]
reveal_type(a < a) # revealed: LtReturnType | Literal[False]
reveal_type(a <= a) # revealed: LeReturnType | Literal[True]
reveal_type(a > a) # revealed: GtReturnType | Literal[False]
reveal_type(a >= a) # revealed: GeReturnType | Literal[True]

# If lexicographic comparison is finished before comparing A()
b = ("1_foo", A())
Expand All @@ -186,11 +193,13 @@ reveal_type(b <= c) # revealed: Literal[True]
reveal_type(b > c) # revealed: Literal[False]
reveal_type(b >= c) # revealed: Literal[False]

class LtReturnTypeOnB: ...

class B:
def __lt__(self, o: B) -> set:
def __lt__(self, o: B) -> LtReturnTypeOnB:
return set()

reveal_type((A(), B()) < (A(), B())) # revealed: bytearray | set | Literal[False]
reveal_type((A(), B()) < (A(), B())) # revealed: LtReturnType | LtReturnTypeOnB | Literal[False]
```

#### Special Handling of Eq and NotEq in Lexicographic Comparisons
Expand Down
Loading

0 comments on commit 8198668

Please sign in to comment.