Commit 2a00eca

[ty] Exhaustiveness checking & reachability for match statements (#19508)
## Summary

Implements proper reachability analysis and, in effect, exhaustiveness checking for `match` statements. This allows us to check the following code without any errors (leads to *"can implicitly return `None`"* on `main`):

```py
from enum import Enum, auto

class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()

def hex(color: Color) -> str:
    match color:
        case Color.RED:
            return "#ff0000"
        case Color.GREEN:
            return "#00ff00"
        case Color.BLUE:
            return "#0000ff"
```

Note that code like this already worked fine if there was an `assert_never(color)` statement in a catch-all case, because we would then consider that `assert_never` call terminal. But now this also works without the wildcard case. Adding a member to the enum would still lead to an error here if that member were not handled in `hex`.

What needed to happen to support this is a new way of evaluating match pattern constraints. Previously, we would simply compare the type of the subject expression against the patterns. For the last case here, the subject type would still be `Color` and the value type would be `Literal[Color.BLUE]`, so we would infer an ambiguous truthiness.

Now, before we compare the subject type against the pattern, we first generate a union type that corresponds to the set of all values that would have *definitely been matched* by previous patterns. Then, we build a "narrowed" subject type by computing `subject_type & ~already_matched_type`, and compare *that* against the pattern type. For the example here, `already_matched_type = Literal[Color.RED] | Literal[Color.GREEN]`, and so we have a narrowed subject type of `Color & ~(Literal[Color.RED] | Literal[Color.GREEN]) = Literal[Color.BLUE]`, which allows us to infer a reachability of `AlwaysTrue`.

<details>
<summary>A note on negated reachability constraints</summary>

It might seem that we now perform duplicate work, because we also record *negated* reachability constraints. But that is still important for cases like the following (and possibly also for more realistic scenarios):

```py
from typing import Literal

def _(x: int | str):
    match x:
        case None:
            pass  # never reachable
        case _:
            y = 1

    y
```

</details>

closes astral-sh/ty#99

## Test Plan

* I verified that this solves all examples from the linked ticket (the first example needs a PEP 695 type alias, because we don't support legacy type aliases yet)
* Verified that the ecosystem changes are all because of removed false positives
* Updated tests
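The narrowing step described in the summary can be illustrated with a small toy model. This is only a sketch under loud assumptions: a "type" is modelled as a `frozenset` of enum members, `subject_type & ~already_matched_type` becomes a set difference, and `case_reachability` with its string verdicts is a hypothetical helper, not part of ty.

```python
from enum import Enum, auto


class Color(Enum):
    RED = auto()
    GREEN = auto()
    BLUE = auto()


def case_reachability(subject, patterns):
    """Toy model (not ty's code): a 'type' is a frozenset of values.

    Returns one verdict per case, in order.
    """
    already_matched = frozenset()
    verdicts = []
    for pattern in patterns:
        # subject_type & ~already_matched_type, expressed as a set difference
        narrowed = subject - already_matched
        if narrowed and narrowed <= pattern:
            verdicts.append("AlwaysTrue")   # every remaining value matches
        elif narrowed & pattern:
            verdicts.append("Ambiguous")    # some remaining values match
        else:
            verdicts.append("AlwaysFalse")  # no remaining value matches
        # This toy has no guards, so every pattern definitely matches its values
        already_matched |= pattern
    return verdicts


subject = frozenset(Color)
patterns = [frozenset({Color.RED}), frozenset({Color.GREEN}), frozenset({Color.BLUE})]
print(case_reachability(subject, patterns))
```

In this model the last case sees a narrowed subject of `{Color.BLUE}` and is therefore reported as `AlwaysTrue`, which mirrors the `Literal[Color.BLUE]` result described above. Dropping the last pattern leaves only `Ambiguous` verdicts, so the match would not be considered exhaustive.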
1 parent 3d17897 commit 2a00eca

7 files changed: +109 -49 lines changed

crates/ty_python_semantic/resources/mdtest/conditional/match.md

Lines changed: 5 additions & 9 deletions
@@ -201,8 +201,7 @@ def _(target: Literal[True, False]):
         case None:
             y = 4

-    # TODO: with exhaustiveness checking, this should be Literal[2, 3]
-    reveal_type(y) # revealed: Literal[1, 2, 3]
+    reveal_type(y) # revealed: Literal[2, 3]

 def _(target: bool):
     y = 1
@@ -215,8 +214,7 @@ def _(target: bool):
         case None:
             y = 4

-    # TODO: with exhaustiveness checking, this should be Literal[2, 3]
-    reveal_type(y) # revealed: Literal[1, 2, 3]
+    reveal_type(y) # revealed: Literal[2, 3]

 def _(target: None):
     y = 1
@@ -242,8 +240,7 @@ def _(target: None | Literal[True]):
         case None:
             y = 4

-    # TODO: with exhaustiveness checking, this should be Literal[2, 4]
-    reveal_type(y) # revealed: Literal[1, 2, 4]
+    reveal_type(y) # revealed: Literal[2, 4]

 # bool is an int subclass
 def _(target: int):
@@ -292,7 +289,7 @@ def _(answer: Answer):
             reveal_type(answer) # revealed: Literal[Answer.NO]
             y = 2

-    reveal_type(y) # revealed: Literal[0, 1, 2]
+    reveal_type(y) # revealed: Literal[1, 2]
 ```

 ## Or match
@@ -311,8 +308,7 @@ def _(target: Literal["foo", "baz"]):
         case "baz":
             y = 3

-    # TODO: with exhaustiveness, this should be Literal[2, 3]
-    reveal_type(y) # revealed: Literal[1, 2, 3]
+    reveal_type(y) # revealed: Literal[2, 3]

 def _(target: None):
     y = 1

crates/ty_python_semantic/resources/mdtest/directives/assert_never.md

Lines changed: 0 additions & 2 deletions
@@ -119,8 +119,6 @@ def match_singletons_success(obj: Literal[1, "a"] | None):
         case None:
             pass
         case _ as obj:
-            # TODO: Ideally, we would not emit an error here
-            # error: [type-assertion-failure] "Argument does not have asserted type `Never`"
             assert_never(obj)

 def match_singletons_error(obj: Literal[1, "a"] | None):

crates/ty_python_semantic/resources/mdtest/enums.md

Lines changed: 0 additions & 2 deletions
@@ -720,8 +720,6 @@ def color_name(color: Color) -> str:
         case _:
             assert_never(color)

-# TODO: this should not be an error, see https://github.com/astral-sh/ty/issues/99#issuecomment-2983054488
-# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `str`"
 def color_name_without_assertion(color: Color) -> str:
     match color:
         case Color.RED:

crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md

Lines changed: 14 additions & 12 deletions
@@ -50,13 +50,10 @@ def match_exhaustive(x: Literal[0, 1, "a"]):
         case "a":
             pass
         case _:
-            # TODO: this should not be an error
-            no_diagnostic_here # error: [unresolved-reference]
+            no_diagnostic_here

             assert_never(x)

-# TODO: there should be no error here
-# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`"
 def match_exhaustive_no_assertion(x: Literal[0, 1, "a"]) -> int:
     match x:
         case 0:
@@ -130,13 +127,21 @@ def match_exhaustive(x: Color):
         case Color.BLUE:
             pass
         case _:
-            # TODO: this should not be an error
-            no_diagnostic_here # error: [unresolved-reference]
+            no_diagnostic_here
+
+            assert_never(x)
+
+def match_exhaustive_2(x: Color):
+    match x:
+        case Color.RED:
+            pass
+        case Color.GREEN | Color.BLUE:
+            pass
+        case _:
+            no_diagnostic_here

             assert_never(x)

-# TODO: there should be no error here
-# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`"
 def match_exhaustive_no_assertion(x: Color) -> int:
     match x:
         case Color.RED:
@@ -208,13 +213,10 @@ def match_exhaustive(x: A | B | C):
         case C():
             pass
         case _:
-            # TODO: this should not be an error
-            no_diagnostic_here # error: [unresolved-reference]
+            no_diagnostic_here

             assert_never(x)

-# TODO: there should be no error here
-# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `int`"
 def match_exhaustive_no_assertion(x: A | B | C) -> int:
     match x:
         case A():

crates/ty_python_semantic/src/semantic_index/builder.rs

Lines changed: 13 additions & 8 deletions
@@ -734,7 +734,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
         subject: Expression<'db>,
         pattern: &ast::Pattern,
         guard: Option<&ast::Expr>,
-    ) -> PredicateOrLiteral<'db> {
+        previous_pattern: Option<PatternPredicate<'db>>,
+    ) -> (PredicateOrLiteral<'db>, PatternPredicate<'db>) {
         // This is called for the top-level pattern of each match arm. We need to create a
         // standalone expression for each arm of a match statement, since they can introduce
         // constraints on the match subject. (Or more accurately, for the match arm's pattern,
@@ -756,13 +757,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
             subject,
             kind,
             guard,
+            previous_pattern.map(Box::new),
         );
         let predicate = PredicateOrLiteral::Predicate(Predicate {
             node: PredicateNode::Pattern(pattern_predicate),
             is_positive: true,
         });
         self.record_narrowing_constraint(predicate);
-        predicate
+        (predicate, pattern_predicate)
     }

     /// Record an expression that needs to be a Salsa ingredient, because we need to infer its type
@@ -1747,7 +1749,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
                     .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard());

                 let mut post_case_snapshots = vec![];
-                let mut match_predicate;
+                let mut previous_pattern: Option<PatternPredicate<'_>> = None;

                 for (i, case) in cases.iter().enumerate() {
                     self.current_match_case = Some(CurrentMatchCase::new(&case.pattern));
@@ -1757,11 +1759,14 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
                     // here because the effects of visiting a pattern is binding
                     // symbols, and this doesn't occur unless the pattern
                     // actually matches
-                    match_predicate = self.add_pattern_narrowing_constraint(
-                        subject_expr,
-                        &case.pattern,
-                        case.guard.as_deref(),
-                    );
+                    let (match_predicate, match_pattern_predicate) = self
+                        .add_pattern_narrowing_constraint(
+                            subject_expr,
+                            &case.pattern,
+                            case.guard.as_deref(),
+                            previous_pattern,
+                        );
+                    previous_pattern = Some(match_pattern_predicate);
                     let reachability_constraint =
                         self.record_reachability_constraint(match_predicate);

crates/ty_python_semantic/src/semantic_index/predicate.rs

Lines changed: 3 additions & 0 deletions
@@ -150,6 +150,9 @@ pub(crate) struct PatternPredicate<'db> {
     pub(crate) kind: PatternPredicateKind<'db>,

     pub(crate) guard: Option<Expression<'db>>,
+
+    /// A reference to the pattern of the previous match case
+    pub(crate) previous_predicate: Option<Box<PatternPredicate<'db>>>,
 }

 // The Salsa heap is tracked separately.
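
The new `previous_predicate` field effectively turns the cases of one `match` statement into a backwards-linked list. A rough sketch of how such a chain can be walked to collect the already-matched values, skipping guarded cases because a guard may reject a value the pattern would otherwise match, might look as follows. `PatternNode` and `excluded_by_previous` are hypothetical names used for illustration, and sets of strings stand in for types.

```python
from dataclasses import dataclass
from typing import FrozenSet, Optional


@dataclass(frozen=True)
class PatternNode:
    """Hypothetical stand-in for a pattern predicate; not ty's actual type."""

    matched: FrozenSet[str]                    # values this pattern definitely matches
    has_guard: bool = False                    # True if the case has an `if ...` guard
    previous: Optional["PatternNode"] = None   # the preceding case, if any


def excluded_by_previous(node: PatternNode) -> FrozenSet[str]:
    """Union of everything definitely matched by the cases before `node`."""
    excluded: FrozenSet[str] = frozenset()
    current = node.previous
    while current is not None:
        # A guarded case does not definitely consume its values, so skip it
        if not current.has_guard:
            excluded |= current.matched
        current = current.previous
    return excluded


red = PatternNode(frozenset({"RED"}))
green_guarded = PatternNode(frozenset({"GREEN"}), has_guard=True, previous=red)
blue = PatternNode(frozenset({"BLUE"}), previous=green_guarded)
```

Here `excluded_by_previous(blue)` contains only `"RED"`: the guarded `GREEN` case contributes nothing, so a later case could still see `GREEN` values.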

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 74 additions & 16 deletions
@@ -202,13 +202,14 @@ use crate::Db;
 use crate::dunder_all::dunder_all_names;
 use crate::place::{RequiresExplicitReExport, imported_symbol};
 use crate::rank::RankBitBox;
-use crate::semantic_index::expression::Expression;
 use crate::semantic_index::place_table;
 use crate::semantic_index::predicate::{
     CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
     Predicates, ScopedPredicateId,
 };
-use crate::types::{Truthiness, Type, infer_expression_type};
+use crate::types::{
+    IntersectionBuilder, Truthiness, Type, UnionBuilder, UnionType, infer_expression_type,
+};

 /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
 /// is just like a boolean formula, but with `Ambiguous` as a third potential result. See the
@@ -311,6 +312,55 @@ const AMBIGUOUS: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId
 const ALWAYS_FALSE: ScopedReachabilityConstraintId = ScopedReachabilityConstraintId::ALWAYS_FALSE;
 const SMALLEST_TERMINAL: ScopedReachabilityConstraintId = ALWAYS_FALSE;

+fn singleton_to_type(db: &dyn Db, singleton: ruff_python_ast::Singleton) -> Type<'_> {
+    let ty = match singleton {
+        ruff_python_ast::Singleton::None => Type::none(db),
+        ruff_python_ast::Singleton::True => Type::BooleanLiteral(true),
+        ruff_python_ast::Singleton::False => Type::BooleanLiteral(false),
+    };
+    debug_assert!(ty.is_singleton(db));
+    ty
+}
+
+/// Turn a `match` pattern kind into a type that represents the set of all values that would definitely
+/// match that pattern.
+fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> {
+    match kind {
+        PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton),
+        PatternPredicateKind::Value(value) => infer_expression_type(db, *value),
+        PatternPredicateKind::Class(class_expr, kind) => {
+            if kind.is_irrefutable() {
+                infer_expression_type(db, *class_expr)
+                    .to_instance(db)
+                    .unwrap_or(Type::Never)
+            } else {
+                Type::Never
+            }
+        }
+        PatternPredicateKind::Or(predicates) => {
+            UnionType::from_elements(db, predicates.iter().map(|p| pattern_kind_to_type(db, p)))
+        }
+        PatternPredicateKind::Unsupported => Type::Never,
+    }
+}
+
+/// Go through the list of previous match cases, and accumulate a union of all types that were already
+/// matched by these patterns.
+fn type_excluded_by_previous_patterns<'db>(
+    db: &'db dyn Db,
+    mut predicate: PatternPredicate<'db>,
+) -> Type<'db> {
+    let mut builder = UnionBuilder::new(db);
+    while let Some(previous) = predicate.previous_predicate(db) {
+        predicate = *previous;
+
+        if predicate.guard(db).is_none() {
+            builder = builder.add(pattern_kind_to_type(db, predicate.kind(db)));
+        }
+    }
+    builder.build()
+}
+
 /// A collection of reachability constraints for a given scope.
 #[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
 pub(crate) struct ReachabilityConstraints {
@@ -637,11 +687,10 @@ impl ReachabilityConstraints {
     fn analyze_single_pattern_predicate_kind<'db>(
         db: &'db dyn Db,
         predicate_kind: &PatternPredicateKind<'db>,
-        subject: Expression<'db>,
+        subject_ty: Type<'db>,
     ) -> Truthiness {
         match predicate_kind {
             PatternPredicateKind::Value(value) => {
-                let subject_ty = infer_expression_type(db, subject);
                 let value_ty = infer_expression_type(db, *value);

                 if subject_ty.is_single_valued(db) {
@@ -651,15 +700,7 @@ impl ReachabilityConstraints {
                 }
             }
             PatternPredicateKind::Singleton(singleton) => {
-                let subject_ty = infer_expression_type(db, subject);
-
-                let singleton_ty = match singleton {
-                    ruff_python_ast::Singleton::None => Type::none(db),
-                    ruff_python_ast::Singleton::True => Type::BooleanLiteral(true),
-                    ruff_python_ast::Singleton::False => Type::BooleanLiteral(false),
-                };
-
-                debug_assert!(singleton_ty.is_singleton(db));
+                let singleton_ty = singleton_to_type(db, *singleton);

                 if subject_ty.is_equivalent_to(db, singleton_ty) {
                     Truthiness::AlwaysTrue
@@ -671,10 +712,21 @@ impl ReachabilityConstraints {
             }
             PatternPredicateKind::Or(predicates) => {
                 use std::ops::ControlFlow;
+
+                let mut excluded_types = vec![];
                 let (ControlFlow::Break(truthiness) | ControlFlow::Continue(truthiness)) =
                     predicates
                         .iter()
-                        .map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject))
+                        .map(|p| {
+                            let narrowed_subject_ty = IntersectionBuilder::new(db)
+                                .add_positive(subject_ty)
+                                .add_negative(UnionType::from_elements(db, excluded_types.iter()))
+                                .build();
+
+                            excluded_types.push(pattern_kind_to_type(db, p));
+
+                            Self::analyze_single_pattern_predicate_kind(db, p, narrowed_subject_ty)
+                        })
                         // this is just a "max", but with a slight optimization: `AlwaysTrue` is the "greatest" possible element, so we short-circuit if we get there
                         .try_fold(Truthiness::AlwaysFalse, |acc, next| match (acc, next) {
                             (Truthiness::AlwaysTrue, _) | (_, Truthiness::AlwaysTrue) => {
@@ -690,7 +742,6 @@ impl ReachabilityConstraints {
                 truthiness
             }
             PatternPredicateKind::Class(class_expr, kind) => {
-                let subject_ty = infer_expression_type(db, subject);
                 let class_ty = infer_expression_type(db, *class_expr).to_instance(db);

                 class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
@@ -715,10 +766,17 @@ impl ReachabilityConstraints {
     }

     fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
+        let subject_ty = infer_expression_type(db, predicate.subject(db));
+
+        let narrowed_subject_ty = IntersectionBuilder::new(db)
+            .add_positive(subject_ty)
+            .add_negative(type_excluded_by_previous_patterns(db, predicate))
+            .build();
+
         let truthiness = Self::analyze_single_pattern_predicate_kind(
             db,
             predicate.kind(db),
-            predicate.subject(db),
+            narrowed_subject_ty,
         );

         if truthiness == Truthiness::AlwaysTrue && predicate.guard(db).is_some() {
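
The `Or` branch above applies the same narrowing idea within a single case: each alternative is checked against the subject minus the alternatives to its left, and the results are combined with a short-circuiting "max". The following toy model sketches that behaviour; sets of strings stand in for types, `analyze_single` is a deliberately simplified subset check rather than ty's per-kind logic, and all names here are assumptions made for illustration.

```python
from enum import IntEnum


class Truthiness(IntEnum):
    # Ordered so that max() picks the "greatest" verdict
    ALWAYS_FALSE = 0
    AMBIGUOUS = 1
    ALWAYS_TRUE = 2


def analyze_single(subject, pattern):
    """Simplified check of one alternative against a (narrowed) subject set."""
    if subject and subject <= pattern:
        return Truthiness.ALWAYS_TRUE
    if subject & pattern:
        return Truthiness.AMBIGUOUS
    return Truthiness.ALWAYS_FALSE


def analyze_or(subject, alternatives):
    """Verdict for `case A | B | ...`, narrowing by earlier alternatives."""
    excluded = frozenset()
    result = Truthiness.ALWAYS_FALSE
    for alternative in alternatives:
        narrowed = subject - excluded       # subject & ~(earlier alternatives)
        excluded |= alternative
        result = max(result, analyze_single(narrowed, alternative))
        if result is Truthiness.ALWAYS_TRUE:
            break  # nothing can exceed AlwaysTrue, so stop early
    return result


# Subject already narrowed by an earlier `case Color.RED`
remaining = frozenset({"GREEN", "BLUE"})
alternatives = [frozenset({"GREEN"}), frozenset({"BLUE"})]
```

With `remaining` as the subject, the first alternative alone is only `AMBIGUOUS`, but the second sees just `{"BLUE"}` and yields `ALWAYS_TRUE`. This is the situation tested by `match_exhaustive_2` with `case Color.GREEN | Color.BLUE`.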
