Commit a663796

[ty] Implement equivalence for protocols with method members (#18659)
## Summary

This PR implements the following pieces of `Protocol` semantics:

1. A protocol with a method member that does not have a fully static signature should not be considered fully static. For example, this protocol is not fully static because `Foo.f` has no return annotation; we previously (and incorrectly) considered it fully static:

   ```py
   class Foo(Protocol):
       def f(self): ...
   ```

2. Two protocols `P1` and `P2`, both with method members `x`, should be considered equivalent if the signature of `P1.x` is equivalent to the signature of `P2.x`. Currently we do not recognize this.

Implementing these semantics requires distinguishing between method members and non-method members. The stored type of a method member must be eagerly upcast to a `Callable` type when collecting the protocol's interface. Otherwise it would be hard to make protocol equivalence hold in the face of differently ordered unions, since two equivalent protocols would have different Salsa IDs even when normalized.

The semantics implemented by this PR are that we consider something a method member if:

1. It is accessible on the class itself; and
2. It is a function-like callable: a callable type that also has a `__get__` method, meaning it can be used as a method when accessed on instances.

Note that the spec has complicated things to say about classmethod members and staticmethod members. These semantics are not implemented by this PR; they are all deferred for now.

The infrastructure added in this PR fixes bugs in its own right, but also lays the groundwork for implementing subtyping and assignability rules for method members of protocols. A (currently failing) test is added to verify this.

## Test Plan

mdtests
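As a runtime analogue of the "function-like callable" rule in the summary, the sketch below shows why the presence of `__get__` matters: a plain function binds `self` when accessed on an instance, while a callable object without `__get__` does not. The names `Greeter`, `plain_function`, and `Host` are hypothetical, and this only illustrates Python's descriptor behavior; it is not how ty itself classifies members.

```python
class Greeter:
    """A callable object that does not implement the descriptor protocol."""

    def __call__(self, *args):
        return args


def plain_function(*args):
    return args


class Host:
    method = plain_function   # a function: has __get__, so it binds on instances
    not_a_method = Greeter()  # a callable instance: no __get__, so it never binds


host = Host()

# Functions implement __get__, so instance access passes `host` as the first argument.
assert hasattr(plain_function, "__get__")
assert host.method() == (host,)

# The callable instance has no __get__, so it is called with no implicit `self`.
assert not hasattr(Host.not_a_method, "__get__")
assert host.not_a_method() == ()
```

Only the first kind of attribute behaves like a method on instances, which is the distinction the PR draws between method members and other members.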
1 parent c15aa57 commit a663796

8 files changed: +146 -21 lines changed
crates/ty_python_semantic/resources/mdtest/protocols.md

Lines changed: 82 additions & 4 deletions
@@ -1476,8 +1476,7 @@ class P1(Protocol):
 class P2(Protocol):
     def x(self, y: int) -> None: ...

-# TODO: this should pass
-static_assert(is_equivalent_to(P1, P2)) # error: [static-assert-error]
+static_assert(is_equivalent_to(P1, P2))
 ```

 As with protocols that only have non-method members, this also holds true when they appear in
@@ -1487,8 +1486,7 @@ differently ordered unions:
 class A: ...
 class B: ...

-# TODO: this should pass
-static_assert(is_equivalent_to(A | B | P1, P2 | B | A)) # error: [static-assert-error]
+static_assert(is_equivalent_to(A | B | P1, P2 | B | A))
 ```

 ## Narrowing of protocols
@@ -1896,6 +1894,86 @@ if isinstance(obj, (B, A)):
     reveal_type(obj) # revealed: (Unknown & B) | (Unknown & A)
 ```

+### Protocols that use `Self`
+
+`Self` is a `TypeVar` with an upper bound of the class in which it is defined. This means that
+`Self` annotations in protocols can also be tricky to handle without infinite recursion and stack
+overflows.
+
+```toml
+[environment]
+python-version = "3.12"
+```
+
+```py
+from typing_extensions import Protocol, Self
+from ty_extensions import static_assert
+
+class _HashObject(Protocol):
+    def copy(self) -> Self: ...
+
+class Foo: ...
+
+# Attempting to build this union caused us to overflow on an early version of
+# <https://github.com/astral-sh/ruff/pull/18659>
+x: Foo | _HashObject
+```
+
+Some other similar cases that caused issues in our early `Protocol` implementation:
+
+`a.py`:
+
+```py
+from typing_extensions import Protocol, Self
+
+class PGconn(Protocol):
+    def connect(self) -> Self: ...
+
+class Connection:
+    pgconn: PGconn
+
+def is_crdb(conn: PGconn) -> bool:
+    return isinstance(conn, Connection)
+```
+
+and:
+
+`b.py`:
+
+```py
+from typing_extensions import Protocol
+
+class PGconn(Protocol):
+    def connect[T: PGconn](self: T) -> T: ...
+
+class Connection:
+    pgconn: PGconn
+
+def f(x: PGconn):
+    isinstance(x, Connection)
+```
+
+### Recursive protocols used as the first argument to `cast()`
+
+These caused issues in an early version of our `Protocol` implementation due to the fact that we use
+a recursive function in our `cast()` implementation to check whether a type contains `Unknown` or
+`Todo`. Recklessly recursing into a type causes stack overflows if the type is recursive:
+
+```toml
+[environment]
+python-version = "3.12"
+```
+
+```py
+from typing import cast, Protocol
+
+class Iterator[T](Protocol):
+    def __iter__(self) -> Iterator[T]: ...
+
+def f(value: Iterator):
+    cast(Iterator, value) # error: [redundant-cast]
+```
+
 ## TODO

 Add tests for:

crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md

Lines changed: 14 additions & 0 deletions
@@ -300,6 +300,20 @@ static_assert(not is_equivalent_to(CallableTypeOf[f12], CallableTypeOf[f13]))
 static_assert(not is_equivalent_to(CallableTypeOf[f13], CallableTypeOf[f12]))
 ```

+### Unions containing `Callable`s
+
+Two unions containing different `Callable` types are equivalent even if the unions are differently
+ordered:
+
+```py
+from ty_extensions import CallableTypeOf, Unknown, is_equivalent_to, static_assert
+
+def f(x): ...
+def g(x: Unknown): ...
+
+static_assert(is_equivalent_to(CallableTypeOf[f] | int | str, str | int | CallableTypeOf[g]))
+```
+
 ### Unions containing `Callable`s containing unions

 Differently ordered unions inside `Callable`s inside unions can still be equivalent:

crates/ty_python_semantic/src/types.rs

Lines changed: 5 additions & 1 deletion
@@ -1102,7 +1102,7 @@ impl<'db> Type<'db> {
             Type::Dynamic(_) => Some(CallableType::single(db, Signature::dynamic(self))),

             Type::FunctionLiteral(function_literal) => {
-                Some(function_literal.into_callable_type(db))
+                Some(Type::Callable(function_literal.into_callable_type(db)))
             }
             Type::BoundMethod(bound_method) => Some(bound_method.into_callable_type(db)),

@@ -7336,6 +7336,10 @@ impl<'db> CallableType<'db> {
     ///
     /// See [`Type::is_equivalent_to`] for more details.
     fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
+        if self == other {
+            return true;
+        }
+
         self.is_function_like(db) == other.is_function_like(db)
             && self
                 .signatures(db)

crates/ty_python_semantic/src/types/cyclic.rs

Lines changed: 21 additions & 1 deletion
@@ -1,3 +1,5 @@
+use rustc_hash::FxHashMap;
+
 use crate::FxIndexSet;
 use crate::types::Type;
 use std::cmp::Eq;
@@ -19,14 +21,27 @@ pub(crate) type PairVisitor<'db> = CycleDetector<(Type<'db>, Type<'db>), bool>;

 #[derive(Debug)]
 pub(crate) struct CycleDetector<T, R> {
+    /// If the type we're visiting is present in `seen`,
+    /// it indicates that we've hit a cycle (due to a recursive type);
+    /// we need to immediately short circuit the whole operation and return the fallback value.
+    /// That's why we pop items off the end of `seen` after we've visited them.
     seen: FxIndexSet<T>,
+
+    /// Unlike `seen`, this field is a pure performance optimisation (and an essential one).
+    /// If the type we're trying to normalize is present in `cache`, it doesn't necessarily mean we've hit a cycle:
+    /// it just means that we've already visited this inner type as part of a bigger call chain we're currently in.
+    /// Since this cache is just a performance optimisation, it doesn't make sense to pop items off the end of the
+    /// cache after they've been visited (it would sort-of defeat the point of a cache if we did!)
+    cache: FxHashMap<T, R>,
+
     fallback: R,
 }

-impl<T: Hash + Eq, R: Copy> CycleDetector<T, R> {
+impl<T: Hash + Eq + Copy, R: Copy> CycleDetector<T, R> {
     pub(crate) fn new(fallback: R) -> Self {
         CycleDetector {
             seen: FxIndexSet::default(),
+            cache: FxHashMap::default(),
             fallback,
         }
     }
@@ -35,7 +50,12 @@ impl<T: Hash + Eq, R: Copy> CycleDetector<T, R> {
         if !self.seen.insert(item) {
             return self.fallback;
         }
+        if let Some(ty) = self.cache.get(&item) {
+            self.seen.pop();
+            return *ty;
+        }
         let ret = func(self);
+        self.cache.insert(item, ret);
         self.seen.pop();
         ret
     }
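
The interplay between `seen` and `cache` in this hunk can be modelled roughly in Python. This is a minimal sketch under assumptions, not a port of the Rust code: the `graph` dictionary, the `contains_marker` helper, and the use of a plain list for `seen` are all hypothetical, and the sketch checks the cache before pushing onto `seen` rather than pushing and then popping on a cache hit, which should behave the same way.

```python
from typing import Callable, Generic, TypeVar

T = TypeVar("T")
R = TypeVar("R")


class CycleDetector(Generic[T, R]):
    def __init__(self, fallback: R) -> None:
        self.seen: list[T] = []      # items on the current call chain only
        self.cache: dict[T, R] = {}  # finished results; entries are never removed
        self.fallback = fallback

    def visit(self, item: T, func: Callable[["CycleDetector[T, R]"], R]) -> R:
        if item in self.seen:
            # The item is already being visited further up the chain: a cycle.
            return self.fallback
        if item in self.cache:
            # Visited earlier and finished: reuse the result, this is not a cycle.
            return self.cache[item]
        self.seen.append(item)
        ret = func(self)
        self.cache[item] = ret
        self.seen.pop()
        return ret


# Hypothetical recursive "type graph": Iterator refers back to itself.
graph = {"Iterator": ["Iterator", "T"], "T": []}
calls = []


def contains_marker(node: str, detector: CycleDetector) -> bool:
    def compute(d: CycleDetector) -> bool:
        calls.append(node)
        return node == "Unknown" or any(contains_marker(c, d) for c in graph[node])

    return detector.visit(node, compute)


detector = CycleDetector(fallback=False)
result = contains_marker("Iterator", detector)
assert result is False
assert calls == ["Iterator", "T"]

# A later query for "T" is answered from the cache without recomputation.
assert contains_marker("T", detector) is False
assert calls == ["Iterator", "T"]
```

The self-reference in `Iterator` terminates through `seen`, while the repeated query for `T` is served by `cache`, which matches the division of responsibilities described in the doc comments above.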

crates/ty_python_semantic/src/types/function.rs

Lines changed: 2 additions & 2 deletions
@@ -767,8 +767,8 @@ impl<'db> FunctionType<'db> {
     }

     /// Convert the `FunctionType` into a [`Type::Callable`].
-    pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
-        Type::Callable(CallableType::new(db, self.signature(db), false))
+    pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
+        CallableType::new(db, self.signature(db), false)
     }

     /// Convert the `FunctionType` into a [`Type::BoundMethod`].

crates/ty_python_semantic/src/types/instance.rs

Lines changed: 8 additions & 1 deletion
@@ -270,7 +270,14 @@ impl<'db> ProtocolInstanceType<'db> {
     ///
     /// TODO: consider the types of the members as well as their existence
     pub(super) fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
-        self.normalized(db) == other.normalized(db)
+        if self == other {
+            return true;
+        }
+        let self_normalized = self.normalized(db);
+        if self_normalized == Type::ProtocolInstance(other) {
+            return true;
+        }
+        self_normalized == other.normalized(db)
     }

     /// Return `true` if this protocol type is disjoint from the protocol `other`.
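
The short-circuit structure of the new `is_equivalent_to` can be sketched in Python with a deliberately simplified model. Everything here is an assumption made for illustration: a "type" is just a tuple of member names, `normalized` sorts and deduplicates it, and the `calls` counter exists only to show how often normalization runs.

```python
calls = {"normalized": 0}


def normalized(ty: tuple[str, ...]) -> tuple[str, ...]:
    """Hypothetical normalization: order-insensitive, duplicate-free form."""
    calls["normalized"] += 1
    return tuple(sorted(set(ty)))


def is_equivalent(left: tuple[str, ...], right: tuple[str, ...]) -> bool:
    # Fast path 1: identical inputs need no normalization at all.
    if left == right:
        return True
    # Fast path 2: `right` may already be in normalized form.
    left_normalized = normalized(left)
    if left_normalized == right:
        return True
    # Slow path: normalize both sides.
    return left_normalized == normalized(right)


assert is_equivalent(("A", "B"), ("A", "B"))
assert calls["normalized"] == 0

assert is_equivalent(("B", "A"), ("A", "B"))
assert calls["normalized"] == 1

assert not is_equivalent(("B", "A"), ("C",))
assert calls["normalized"] == 3
```

In this toy model, each fast path saves one or two normalization calls; the real cost being avoided in the Rust code is presumably the normalization of the protocol types themselves.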

crates/ty_python_semantic/src/types/protocol_class.rs

Lines changed: 6 additions & 9 deletions
@@ -260,7 +260,7 @@ impl<'db> ProtocolMemberData<'db> {

 #[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)]
 enum ProtocolMemberKind<'db> {
-    Method(Type<'db>), // TODO: use CallableType
+    Method(CallableType<'db>),
     Property(PropertyInstanceType<'db>),
     Other(Type<'db>),
 }
@@ -335,7 +335,7 @@ fn walk_protocol_member<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
     visitor: &mut V,
 ) {
     match member.kind {
-        ProtocolMemberKind::Method(method) => visitor.visit_type(db, method),
+        ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method),
         ProtocolMemberKind::Property(property) => {
             visitor.visit_property_instance_type(db, property);
         }
@@ -354,7 +354,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> {

     fn ty(&self) -> Type<'db> {
         match &self.kind {
-            ProtocolMemberKind::Method(callable) => *callable,
+            ProtocolMemberKind::Method(callable) => Type::Callable(*callable),
             ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property),
             ProtocolMemberKind::Other(ty) => *ty,
         }
@@ -508,13 +508,10 @@ fn cached_protocol_interface<'db>(
                     (Type::Callable(callable), BoundOnClass::Yes)
                         if callable.is_function_like(db) =>
                     {
-                        ProtocolMemberKind::Method(ty)
+                        ProtocolMemberKind::Method(callable)
                     }
-                    // TODO: method members that have `FunctionLiteral` types should be upcast
-                    // to `CallableType` so that two protocols with identical method members
-                    // are recognized as equivalent.
-                    (Type::FunctionLiteral(_function), BoundOnClass::Yes) => {
-                        ProtocolMemberKind::Method(ty)
+                    (Type::FunctionLiteral(function), BoundOnClass::Yes) => {
+                        ProtocolMemberKind::Method(function.into_callable_type(db))
                     }
                     _ => ProtocolMemberKind::Other(ty),
                 };

crates/ty_python_semantic/src/types/signatures.rs

Lines changed: 8 additions & 3 deletions
@@ -1318,8 +1318,13 @@ impl<'db> Parameter<'db> {
             form,
         } = self;

-        // Ensure unions and intersections are ordered in the annotated type (if there is one)
-        let annotated_type = annotated_type.map(|ty| ty.normalized_impl(db, visitor));
+        // Ensure unions and intersections are ordered in the annotated type (if there is one).
+        // Ensure that a parameter without an annotation is treated equivalently to a parameter
+        // with a dynamic type as its annotation. (We must use `Any` here as all dynamic types
+        // normalize to `Any`.)
+        let annotated_type = annotated_type
+            .map(|ty| ty.normalized_impl(db, visitor))
+            .unwrap_or_else(Type::any);

         // Ensure that parameter names are stripped from positional-only, variadic and keyword-variadic parameters.
         // Ensure that we only record whether a parameter *has* a default
@@ -1351,7 +1356,7 @@ impl<'db> Parameter<'db> {
         };

         Self {
-            annotated_type,
+            annotated_type: Some(annotated_type),
             kind,
             form: *form,
         }
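
The effect of treating a missing annotation like a dynamic one can be approximated at runtime with `inspect.signature`. This is only a sketch under assumptions: `normalized_annotations` is a hypothetical helper, and `typing.Any` stands in for ty's `Unknown`, which has no runtime counterpart here.

```python
import inspect
from typing import Any


def normalized_annotations(func):
    """Return the parameter annotations, replacing missing ones with `Any`."""
    signature = inspect.signature(func)
    return tuple(
        Any if param.annotation is inspect.Parameter.empty else param.annotation
        for param in signature.parameters.values()
    )


def f(x): ...
def g(x: Any): ...
def h(x: int): ...


# After normalization, the unannotated and the `Any`-annotated parameter agree.
assert normalized_annotations(f) == normalized_annotations(g)

# A concretely annotated parameter still differs.
assert normalized_annotations(f) != normalized_annotations(h)
```

This mirrors the `f(x)` versus `g(x: Unknown)` case added to `is_equivalent_to.md`, where both callables are expected to normalize to the same signature.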
