Skip to content

Commit c9dfb51

Browse files
[ty] Fix match pattern value narrowing to use equality semantics (#20882)
## Summary Resolves astral-sh/ty#1349. Fix match statement value patterns to use equality comparison semantics instead of incorrectly narrowing to literal types directly. Value patterns use equality for matching, and equality can be overridden, so we can't always narrow to the matched literal. ## Test Plan Updated match.md with corrected expected types and an additional example with explanation --------- Co-authored-by: David Peter <mail@david-peter.de>
1 parent fe4e3e2 commit c9dfb51

File tree

2 files changed

+154
-98
lines changed

2 files changed

+154
-98
lines changed

crates/ty_python_semantic/resources/mdtest/narrow/match.md

Lines changed: 109 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -71,98 +71,140 @@ reveal_type(x) # revealed: object
7171

7272
## Value patterns
7373

74-
```py
75-
def get_object() -> object:
76-
return object()
74+
Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on
75+
one can only give us information where we know how the subject type implements equality.
7776

78-
x = get_object()
77+
Consider the following example.
7978

80-
reveal_type(x) # revealed: object
79+
```py
80+
from typing import Literal
8181

82-
match x:
83-
case "foo":
84-
reveal_type(x) # revealed: Literal["foo"]
85-
case 42:
86-
reveal_type(x) # revealed: Literal[42]
87-
case 6.0:
88-
reveal_type(x) # revealed: float
89-
case 1j:
90-
reveal_type(x) # revealed: complex
91-
case b"foo":
92-
reveal_type(x) # revealed: Literal[b"foo"]
82+
def _(x: Literal["foo"] | int):
83+
match x:
84+
case "foo":
85+
reveal_type(x) # revealed: Literal["foo"] | int
9386

94-
reveal_type(x) # revealed: object
87+
match x:
88+
case "bar":
89+
reveal_type(x) # revealed: int
9590
```
9691

97-
## Value patterns with guard
92+
In the first `match`'s `case "foo"` all we know is `x == "foo"`. `x` could be an instance of an
93+
arbitrary `int` subclass with an arbitrary `__eq__`, so we can't actually narrow to
94+
`Literal["foo"]`.
95+
96+
In the second `match`'s `case "bar"` we know `x == "bar"`. As discussed above, this isn't enough to
97+
rule out `int`, but we know that `"foo" == "bar"` is false so we can eliminate `Literal["foo"]`.
98+
99+
More examples follow.
98100

99101
```py
100-
def get_object() -> object:
101-
return object()
102+
from typing import Literal
102103

103-
x = get_object()
104+
class C:
105+
pass
104106

105-
reveal_type(x) # revealed: object
107+
def _(x: Literal["foo", "bar", 42, b"foo"] | bool | complex):
108+
match x:
109+
case "foo":
110+
reveal_type(x) # revealed: Literal["foo"] | int | float | complex
111+
case 42:
112+
reveal_type(x) # revealed: int | float | complex
113+
case 6.0:
114+
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
115+
case 1j:
116+
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
117+
case b"foo":
118+
reveal_type(x) # revealed: (int & ~Literal[42]) | Literal[b"foo"] | float | complex
119+
case _:
120+
reveal_type(x) # revealed: Literal["bar"] | (int & ~Literal[42]) | float | complex
121+
```
106122

107-
match x:
108-
case "foo" if reveal_type(x): # revealed: Literal["foo"]
109-
pass
110-
case 42 if reveal_type(x): # revealed: Literal[42]
111-
pass
112-
case 6.0 if reveal_type(x): # revealed: float
113-
pass
114-
case 1j if reveal_type(x): # revealed: complex
115-
pass
116-
case b"foo" if reveal_type(x): # revealed: Literal[b"foo"]
117-
pass
123+
## Value patterns with guard
118124

119-
reveal_type(x) # revealed: object
125+
```py
126+
from typing import Literal
127+
128+
class C:
129+
pass
130+
131+
def _(x: Literal["foo", b"bar"] | int):
132+
match x:
133+
case "foo" if reveal_type(x): # revealed: Literal["foo"] | int
134+
pass
135+
case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
136+
pass
137+
case 42 if reveal_type(x): # revealed: int
138+
pass
120139
```
121140

122141
## Or patterns
123142

124143
```py
125-
def get_object() -> object:
126-
return object()
144+
from typing import Literal
145+
from enum import Enum
146+
147+
class Color(Enum):
148+
RED = 1
149+
GREEN = 2
150+
BLUE = 3
151+
152+
def _(color: Color):
153+
match color:
154+
case Color.RED | Color.GREEN:
155+
reveal_type(color) # revealed: Literal[Color.RED, Color.GREEN]
156+
case Color.BLUE:
157+
reveal_type(color) # revealed: Literal[Color.BLUE]
158+
159+
match color:
160+
case Color.RED | Color.GREEN | Color.BLUE:
161+
reveal_type(color) # revealed: Color
162+
163+
match color:
164+
case Color.RED:
165+
reveal_type(color) # revealed: Literal[Color.RED]
166+
case _:
167+
reveal_type(color) # revealed: Literal[Color.GREEN, Color.BLUE]
127168

128-
x = get_object()
169+
class A: ...
170+
class B: ...
171+
class C: ...
129172

130-
reveal_type(x) # revealed: object
173+
def _(x: A | B | C):
174+
match x:
175+
case A() | B():
176+
reveal_type(x) # revealed: A | B
177+
case C():
178+
reveal_type(x) # revealed: C & ~A & ~B
179+
case _:
180+
reveal_type(x) # revealed: Never
131181

132-
match x:
133-
case "foo" | 42 | None:
134-
reveal_type(x) # revealed: Literal["foo", 42] | None
135-
case "foo" | tuple():
136-
reveal_type(x) # revealed: tuple[Unknown, ...]
137-
case True | False:
138-
reveal_type(x) # revealed: bool
139-
case 3.14 | 2.718 | 1.414:
140-
reveal_type(x) # revealed: float
182+
match x:
183+
case A() | B() | C():
184+
reveal_type(x) # revealed: A | B | C
185+
case _:
186+
reveal_type(x) # revealed: Never
141187

142-
reveal_type(x) # revealed: object
188+
match x:
189+
case A():
190+
reveal_type(x) # revealed: A
191+
case _:
192+
reveal_type(x) # revealed: (B & ~A) | (C & ~A)
143193
```
144194

145195
## Or patterns with guard
146196

147197
```py
148-
def get_object() -> object:
149-
return object()
150-
151-
x = get_object()
152-
153-
reveal_type(x) # revealed: object
154-
155-
match x:
156-
case "foo" | 42 | None if reveal_type(x): # revealed: Literal["foo", 42] | None
157-
pass
158-
case "foo" | tuple() if reveal_type(x): # revealed: Literal["foo"] | tuple[Unknown, ...]
159-
pass
160-
case True | False if reveal_type(x): # revealed: bool
161-
pass
162-
case 3.14 | 2.718 | 1.414 if reveal_type(x): # revealed: float
163-
pass
198+
from typing import Literal
164199

165-
reveal_type(x) # revealed: object
200+
def _(x: Literal["foo", b"bar"] | int):
201+
match x:
202+
case "foo" | 42 if reveal_type(x): # revealed: Literal["foo"] | int
203+
pass
204+
case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
205+
pass
206+
case _ if reveal_type(x): # revealed: Literal["foo", b"bar"] | int
207+
pass
166208
```
167209

168210
## Narrowing due to guard
@@ -179,7 +221,7 @@ match x:
179221
case str() | float() if type(x) is str:
180222
reveal_type(x) # revealed: str
181223
case "foo" | 42 | None if isinstance(x, int):
182-
reveal_type(x) # revealed: Literal[42]
224+
reveal_type(x) # revealed: int
183225
case False if x:
184226
reveal_type(x) # revealed: Never
185227
case "foo" if x := "bar":
@@ -201,7 +243,7 @@ reveal_type(x) # revealed: object
201243
match x:
202244
case str() | float() if type(x) is str and reveal_type(x): # revealed: str
203245
pass
204-
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42]
246+
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: int
205247
pass
206248
case False if x and reveal_type(x): # revealed: Never
207249
pass

0 commit comments

Comments
 (0)