Skip to content

Commit e55bc94

Browse files
authored
[ty] Reachability and narrowing for enum methods (#21130)
## Summary Adds proper type narrowing and reachability analysis for matching on non-inferable type variables bound to enums. For example: ```py from enum import Enum class Answer(Enum): NO = 0 YES = 1 def is_yes(self) -> bool: # no error here! match self: case Answer.YES: return True case Answer.NO: return False ``` closes astral-sh/ty#1404 ## Test Plan Added regression tests
1 parent 1b0ee46 commit e55bc94

File tree

4 files changed

+121
-5
lines changed

4 files changed

+121
-5
lines changed

crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,3 +379,22 @@ def as_pattern_non_exhaustive(subject: int | str):
379379
# this diagnostic is correct: the inferred type of `subject` is `str`
380380
assert_never(subject) # error: [type-assertion-failure]
381381
```
382+
383+
## Exhaustiveness checking for methods of enums
384+
385+
```py
386+
from enum import Enum
387+
388+
class Answer(Enum):
389+
YES = "yes"
390+
NO = "no"
391+
392+
def is_yes(self) -> bool:
393+
reveal_type(self) # revealed: Self@is_yes
394+
395+
match self:
396+
case Answer.YES:
397+
return True
398+
case Answer.NO:
399+
return False
400+
```

crates/ty_python_semantic/resources/mdtest/narrow/match.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,51 @@ match x:
252252

253253
reveal_type(x) # revealed: object
254254
```
255+
256+
## Narrowing on `Self` in `match` statements
257+
258+
When performing narrowing on `self` inside methods on enums, we take into account that `Self` might
259+
refer to a subtype of the enum class, like `Literal[Answer.YES]`. This is why we do not simplify
260+
`Self & ~Literal[Answer.YES]` to `Literal[Answer.NO, Answer.MAYBE]`. Otherwise, we wouldn't be able
261+
to return `self` in the `assert_yes` method below:
262+
263+
```py
264+
from enum import Enum
265+
from typing_extensions import Self, assert_never
266+
267+
class Answer(Enum):
268+
NO = 0
269+
YES = 1
270+
MAYBE = 2
271+
272+
def is_yes(self) -> bool:
273+
reveal_type(self) # revealed: Self@is_yes
274+
275+
match self:
276+
case Answer.YES:
277+
reveal_type(self) # revealed: Self@is_yes
278+
return True
279+
case Answer.NO | Answer.MAYBE:
280+
reveal_type(self) # revealed: Self@is_yes & ~Literal[Answer.YES]
281+
return False
282+
case _:
283+
assert_never(self) # no error
284+
285+
def assert_yes(self) -> Self:
286+
reveal_type(self) # revealed: Self@assert_yes
287+
288+
match self:
289+
case Answer.YES:
290+
reveal_type(self) # revealed: Self@assert_yes
291+
return self
292+
case _:
293+
reveal_type(self) # revealed: Self@assert_yes & ~Literal[Answer.YES]
294+
raise ValueError("Answer is not YES")
295+
296+
Answer.YES.is_yes()
297+
298+
try:
299+
reveal_type(Answer.MAYBE.assert_yes()) # revealed: Literal[Answer.MAYBE]
300+
except ValueError:
301+
pass
302+
```

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -802,10 +802,27 @@ impl ReachabilityConstraints {
802802
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
803803
let subject_ty = infer_expression_type(db, predicate.subject(db), TypeContext::default());
804804

805-
let narrowed_subject_ty = IntersectionBuilder::new(db)
805+
let narrowed_subject = IntersectionBuilder::new(db)
806806
.add_positive(subject_ty)
807-
.add_negative(type_excluded_by_previous_patterns(db, predicate))
807+
.add_negative(type_excluded_by_previous_patterns(db, predicate));
808+
809+
let narrowed_subject_ty = narrowed_subject.clone().build();
810+
811+
// Consider a case where we match on a subject type of `Self` with an upper bound of `Answer`,
812+
// where `Answer` is a {YES, NO} enum. After a previous pattern matching on `NO`, the narrowed
813+
// subject type is `Self & ~Literal[NO]`. This type is *not* equivalent to `Literal[YES]`,
814+
// because `Self` could also specialize to `Literal[NO]` or `Never`, making the intersection
815+
// empty. However, if the current pattern matches on `YES`, the *next* narrowed subject type
816+
// will be `Self & ~Literal[NO] & ~Literal[YES]`, which *is* always equivalent to `Never`. This
817+
// means that subsequent patterns can never match. And we know that if we reach this point,
818+
// the current pattern will have to match. We return `AlwaysTrue` here, since the call to
819+
// `analyze_single_pattern_predicate_kind` below would return `Ambiguous` in this case.
820+
let next_narrowed_subject_ty = narrowed_subject
821+
.add_negative(pattern_kind_to_type(db, predicate.kind(db)))
808822
.build();
823+
if !narrowed_subject_ty.is_never() && next_narrowed_subject_ty.is_never() {
824+
return Truthiness::AlwaysTrue;
825+
}
809826

810827
let truthiness = Self::analyze_single_pattern_predicate_kind(
811828
db,

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use crate::types::{
4444
TypeVarBoundOrConstraints, UnionType,
4545
};
4646
use crate::{Db, FxOrderSet};
47+
use rustc_hash::FxHashSet;
4748
use smallvec::SmallVec;
4849

4950
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -422,9 +423,9 @@ impl<'db> UnionBuilder<'db> {
422423
.iter()
423424
.filter_map(UnionElement::to_type_element)
424425
.filter_map(Type::as_enum_literal)
425-
.map(|literal| literal.name(self.db).clone())
426-
.chain(std::iter::once(enum_member_to_add.name(self.db).clone()))
427-
.collect::<FxOrderSet<_>>();
426+
.map(|literal| literal.name(self.db))
427+
.chain(std::iter::once(enum_member_to_add.name(self.db)))
428+
.collect::<FxHashSet<_>>();
428429

429430
let all_members_are_in_union = metadata
430431
.members
@@ -780,6 +781,37 @@ impl<'db> IntersectionBuilder<'db> {
780781
seen_aliases,
781782
)
782783
}
784+
Type::EnumLiteral(enum_literal) => {
785+
let enum_class = enum_literal.enum_class(self.db);
786+
let metadata =
787+
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
788+
789+
let enum_members_in_negative_part = self
790+
.intersections
791+
.iter()
792+
.flat_map(|intersection| &intersection.negative)
793+
.filter_map(|ty| ty.as_enum_literal())
794+
.filter(|lit| lit.enum_class(self.db) == enum_class)
795+
.map(|lit| lit.name(self.db))
796+
.chain(std::iter::once(enum_literal.name(self.db)))
797+
.collect::<FxHashSet<_>>();
798+
799+
let all_members_are_in_negative_part = metadata
800+
.members
801+
.keys()
802+
.all(|name| enum_members_in_negative_part.contains(name));
803+
804+
if all_members_are_in_negative_part {
805+
for inner in &mut self.intersections {
806+
inner.add_negative(self.db, enum_literal.enum_class_instance(self.db));
807+
}
808+
} else {
809+
for inner in &mut self.intersections {
810+
inner.add_negative(self.db, ty);
811+
}
812+
}
813+
self
814+
}
783815
_ => {
784816
for inner in &mut self.intersections {
785817
inner.add_negative(self.db, ty);

0 commit comments

Comments
 (0)