Skip to content

Commit 85cbc33

Browse files
committed
[ty] Implement equivalence for protocols with method members
1 parent 667dc62 commit 85cbc33

File tree

4 files changed

+82
-28
lines changed

4 files changed

+82
-28
lines changed

crates/ty_python_semantic/resources/mdtest/protocols.md

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,8 +1470,7 @@ class P1(Protocol):
14701470
class P2(Protocol):
14711471
def x(self, y: int) -> None: ...
14721472

1473-
# TODO: this should pass
1474-
static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error]
1473+
static_assert(is_equivalent_to(P1, P2))
14751474
```
14761475

14771476
As with protocols that only have non-method members, this also holds true when they appear in
@@ -1481,8 +1480,7 @@ differently ordered unions:
14811480
class A: ...
14821481
class B: ...
14831482

1484-
# TODO: this should pass
1485-
static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error]
1483+
static_assert(is_equivalent_to(A | B | P1, P2 | B | A))
14861484
```
14871485

14881486
## Narrowing of protocols
@@ -1878,6 +1876,65 @@ if isinstance(obj, (B, A)):
18781876
reveal_type(obj) # revealed: (Unknown & B) | (Unknown & A)
18791877
```
18801878

1879+
### Protocols that use `Self`
1880+
1881+
`Self` is a `TypeVar` with an upper bound of the class in which it is defined. This means that
1882+
`Self` annotations in protocols can also be tricky to handle without infinite recursion and stack
1883+
overflows.
1884+
1885+
```toml
1886+
[environment]
1887+
python-version = "3.12"
1888+
```
1889+
1890+
```py
1891+
from typing_extensions import Protocol, Self
1892+
from ty_extensions import static_assert
1893+
1894+
class _HashObject(Protocol):
1895+
def copy(self) -> Self: ...
1896+
1897+
class Foo: ...
1898+
1899+
# Attempting to build this union caused us to overflow on an early version of
1900+
# <https://github.com/astral-sh/ruff/pull/18659>
1901+
x: Foo | _HashObject
1902+
```
1903+
1904+
Some other similar cases that caused issues in our early `Protocol` implementation:
1905+
1906+
`a.py`:
1907+
1908+
```py
1909+
from typing_extensions import Protocol, Self
1910+
1911+
class PGconn(Protocol):
1912+
def connect(self) -> Self: ...
1913+
1914+
class Connection:
1915+
pgconn: PGconn
1916+
1917+
def is_crdb(conn: PGconn) -> bool:
1918+
return isinstance(conn, Connection)
1919+
```
1920+
1921+
and:
1922+
1923+
`b.py`:
1924+
1925+
```py
1926+
from typing_extensions import Protocol
1927+
1928+
class PGconn(Protocol):
1929+
def connect[T: PGconn](self: T) -> T: ...
1930+
1931+
class Connection:
1932+
pgconn: PGconn
1933+
1934+
def f(x: PGconn):
1935+
isinstance(x, Connection)
1936+
```
1937+
18811938
## TODO
18821939

18831940
Add tests for:

crates/ty_python_semantic/src/types.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -789,18 +789,7 @@ impl<'db> Type<'db> {
789789
.copied()
790790
.any(|ty| ty.any_over_type(db, type_fn)),
791791

792-
Self::Callable(callable) => {
793-
let signatures = callable.signatures(db);
794-
signatures.iter().any(|signature| {
795-
signature.parameters().iter().any(|param| {
796-
param
797-
.annotated_type()
798-
.is_some_and(|ty| ty.any_over_type(db, type_fn))
799-
}) || signature
800-
.return_ty
801-
.is_some_and(|ty| ty.any_over_type(db, type_fn))
802-
})
803-
}
792+
Self::Callable(callable) => callable.any_over_type(db, type_fn),
804793

805794
Self::SubclassOf(subclass_of) => {
806795
Type::from(subclass_of.subclass_of()).any_over_type(db, type_fn)
@@ -1445,8 +1434,7 @@ impl<'db> Type<'db> {
14451434
.is_some_and(|instance| instance.has_relation_to(db, target, relation)),
14461435

14471436
(Type::FunctionLiteral(self_function_literal), Type::Callable(_)) => {
1448-
self_function_literal
1449-
.into_callable_type(db)
1437+
Type::Callable(self_function_literal.into_callable_type(db))
14501438
.has_relation_to(db, target, relation)
14511439
}
14521440

@@ -7309,6 +7297,18 @@ impl<'db> CallableType<'db> {
73097297
.signatures(db)
73107298
.is_equivalent_to(db, other.signatures(db))
73117299
}
7300+
7301+
fn any_over_type(self, db: &'db dyn Db, type_fn: &dyn Fn(Type<'db>) -> bool) -> bool {
7302+
self.signatures(db).iter().any(|signature| {
7303+
signature.parameters().iter().any(|param| {
7304+
param
7305+
.annotated_type()
7306+
.is_some_and(|ty| ty.any_over_type(db, type_fn))
7307+
}) || signature
7308+
.return_ty
7309+
.is_some_and(|ty| ty.any_over_type(db, type_fn))
7310+
})
7311+
}
73127312
}
73137313

73147314
/// Represents a specific instance of `types.MethodWrapperType`

crates/ty_python_semantic/src/types/function.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,8 +745,8 @@ impl<'db> FunctionType<'db> {
745745
}
746746

747747
/// Convert the `FunctionType` into a [`Type::Callable`].
748-
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
749-
Type::Callable(CallableType::new(db, self.signature(db), false))
748+
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
749+
CallableType::new(db, self.signature(db), false)
750750
}
751751

752752
/// Convert the `FunctionType` into a [`Type::BoundMethod`].

crates/ty_python_semantic/src/types/protocol_class.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ impl<'db> ProtocolMemberData<'db> {
255255

256256
#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)]
257257
enum ProtocolMemberKind<'db> {
258-
Method(Type<'db>), // TODO: use CallableType
258+
Method(CallableType<'db>),
259259
Property(PropertyInstanceType<'db>),
260260
Other(Type<'db>),
261261
}
@@ -335,7 +335,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {
335335

336336
fn ty(&self) -> Type<'db> {
337337
match &self.kind {
338-
ProtocolMemberKind::Method(callable) => *callable,
338+
ProtocolMemberKind::Method(callable) => Type::Callable(*callable),
339339
ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property),
340340
ProtocolMemberKind::Other(ty) => *ty,
341341
}
@@ -490,13 +490,10 @@ fn cached_protocol_interface<'db>(
490490
(Type::Callable(callable), BoundOnClass::Yes)
491491
if callable.is_function_like(db) =>
492492
{
493-
ProtocolMemberKind::Method(ty)
493+
ProtocolMemberKind::Method(callable)
494494
}
495-
// TODO: method members that have `FunctionLiteral` types should be upcast
496-
// to `CallableType` so that two protocols with identical method members
497-
// are recognized as equivalent.
498-
(Type::FunctionLiteral(_function), BoundOnClass::Yes) => {
499-
ProtocolMemberKind::Method(ty)
495+
(Type::FunctionLiteral(function), BoundOnClass::Yes) => {
496+
ProtocolMemberKind::Method(function.into_callable_type(db))
500497
}
501498
_ => ProtocolMemberKind::Other(ty),
502499
};

0 commit comments

Comments
 (0)