Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,21 @@ def match_non_exhaustive(x: Color):

## `isinstance` checks

```toml
[environment]
python-version = "3.12"
```

```py
from typing import assert_never

class A: ...
class B: ...
class C: ...

class GenericClass[T]:
x: T

def if_else_exhaustive(x: A | B | C):
if isinstance(x, A):
pass
Expand Down Expand Up @@ -253,6 +261,17 @@ def match_non_exhaustive(x: A | B | C):

# this diagnostic is correct: the inferred type of `x` is `B & ~A & ~C`
assert_never(x) # error: [type-assertion-failure]

# Note: no invalid-return-type diagnostic; the `match` is exhaustive
def match_exhaustive_generic[T](obj: GenericClass[T]) -> GenericClass[T]:
match obj:
case GenericClass(x=42):
reveal_type(obj) # revealed: GenericClass[T@match_exhaustive_generic]
return obj
case GenericClass(x=x):
reveal_type(x) # revealed: @Todo(`match` pattern definition types)
reveal_type(obj) # revealed: GenericClass[T@match_exhaustive_generic]
return obj
```

## `isinstance` checks with generics
Expand Down
75 changes: 75 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/narrow/match.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,81 @@ match x:
reveal_type(x) # revealed: object
```

## Class patterns with generic classes

```toml
[environment]
python-version = "3.12"
```

```py
from typing import assert_never

class Covariant[T]:
def get(self) -> T:
raise NotImplementedError

def f(x: Covariant[int]):
match x:
case Covariant():
reveal_type(x) # revealed: Covariant[int]
case _:
reveal_type(x) # revealed: Never
assert_never(x)
```

## Class patterns with generic `@final` classes

These work the same as non-`@final` classes.

```toml
[environment]
python-version = "3.12"
```

```py
from typing import assert_never, final

@final
class Covariant[T]:
def get(self) -> T:
raise NotImplementedError

def f(x: Covariant[int]):
match x:
case Covariant():
reveal_type(x) # revealed: Covariant[int]
case _:
reveal_type(x) # revealed: Never
assert_never(x)
```

## Class patterns where the class pattern does not resolve to a class

In general this does not allow for narrowing, but we make an exception for `Any`. This is to support
[real ecosystem code](https://github.com/jax-ml/jax/blob/d2ce04b6c3d03ae18b145965b8b8b92e09e8009c/jax/_src/pallas/mosaic_gpu/lowering.py#L3372-L3387)
found in `jax`.

```py
from typing import Any

X = Any

def f(obj: object):
match obj:
case int():
reveal_type(obj) # revealed: int
case X():
reveal_type(obj) # revealed: Any & ~int

def g(obj: object, Y: Any):
match obj:
case int():
reveal_type(obj) # revealed: int
case Y():
reveal_type(obj) # revealed: Any & ~int
```

## Value patterns

Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -771,8 +771,9 @@ impl ReachabilityConstraints {
truthiness
}
PatternPredicateKind::Class(class_expr, kind) => {
let class_ty =
infer_expression_type(db, *class_expr, TypeContext::default()).to_instance(db);
let class_ty = infer_expression_type(db, *class_expr, TypeContext::default())
.as_class_literal()
.map(|class| Type::instance(db, class.top_materialization(db)));

class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
if subject_ty.is_subtype_of(db, class_ty) {
Expand Down
22 changes: 16 additions & 6 deletions crates/ty_python_semantic/src/types/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::function::KnownFunction;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::{
ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType,
Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types,
ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SpecialFormType, SubclassOfInner,
SubclassOfType, Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder,
infer_expression_types,
};

use ruff_db::parsed::{ParsedModuleRef, parsed_module};
Expand Down Expand Up @@ -962,11 +963,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);

let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module)
.to_instance(self.db)?;
let class_type =
infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module);

let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)]))
let narrowed_type = match class_type {
Type::ClassLiteral(class) => {
Type::instance(self.db, class.top_materialization(self.db))
.negate_if(self.db, !is_positive)
}
dynamic @ Type::Dynamic(_) => dynamic,
Type::SpecialForm(SpecialFormType::Any) => Type::any(),
_ => return None,
};

Some(NarrowingConstraints::from_iter([(place, narrowed_type)]))
}

fn evaluate_match_pattern_value(
Expand Down
Loading