Skip to content

Commit 2afdca5

Browse files
committed
[ty] Various fixes for generic protocols
1 parent 9ac39ce commit 2afdca5

File tree

8 files changed

+85
-33
lines changed

8 files changed

+85
-33
lines changed

crates/ty_python_semantic/resources/mdtest/protocols.md

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,20 @@ class NotAProtocol: ...
9595
reveal_type(is_protocol(NotAProtocol)) # revealed: Literal[False]
9696
```
9797

98+
Note, however, that `is_protocol` returns `False` at runtime for specializations of generic
99+
protocols. We still consider these to be "protocol classes" internally, regardless:
100+
101+
```py
102+
class MyGenericProtocol[T](Protocol):
103+
x: T
104+
105+
reveal_type(is_protocol(MyGenericProtocol)) # revealed: Literal[True]
106+
107+
# We still consider this a protocol class internally,
108+
# but the inferred type of the call here reflects the result at runtime:
109+
reveal_type(is_protocol(MyGenericProtocol[int])) # revealed: Literal[False]
110+
```
111+
98112
A type checker should follow the typeshed stubs if a non-class is passed in, and typeshed's stubs
99113
indicate that the argument passed in must be an instance of `type`.
100114

@@ -395,24 +409,38 @@ To see the kinds and types of the protocol members, you can use the debugging ai
395409

396410
```py
397411
from ty_extensions import reveal_protocol_interface
398-
from typing import SupportsIndex, SupportsAbs
412+
from typing import SupportsIndex, SupportsAbs, Iterator
399413

400414
# error: [revealed-type] "Revealed protocol interface: `{"method_member": MethodMember(`(self) -> bytes`), "x": AttributeMember(`int`), "y": PropertyMember { getter: `def y(self) -> str` }, "z": PropertyMember { getter: `def z(self) -> int`, setter: `def z(self, z: int) -> None` }}`"
401415
reveal_protocol_interface(Foo)
402416
# error: [revealed-type] "Revealed protocol interface: `{"__index__": MethodMember(`(self) -> int`)}`"
403417
reveal_protocol_interface(SupportsIndex)
404-
# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> _T_co@SupportsAbs`)}`"
418+
# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> Unknown`)}`"
405419
reveal_protocol_interface(SupportsAbs)
420+
# error: [revealed-type] "Revealed protocol interface: `{"__iter__": MethodMember(`(self) -> Iterator[Unknown]`), "__next__": MethodMember(`(self) -> Unknown`)}`"
421+
reveal_protocol_interface(Iterator)
406422

407423
# error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`"
408424
reveal_protocol_interface(int)
409425
# error: [invalid-argument-type] "Argument to function `reveal_protocol_interface` is incorrect: Expected `type`, found `Literal["foo"]`"
410426
reveal_protocol_interface("foo")
427+
```
411428

412-
# TODO: this should be a `revealed-type` diagnostic rather than `invalid-argument-type`, and it should reveal `{"__abs__": MethodMember(`(self) -> int`)}` for the protocol interface
413-
#
414-
# error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`"
429+
Similar to the way that `typing.is_protocol` returns `False` at runtime for all generic aliases,
430+
`typing.get_protocol_members` raises an exception at runtime if you pass it a generic alias, so we
431+
do not implement any special handling for generic aliases passed to the function.
432+
`ty_extensions.reveal_protocol_interface` can be used on both, however:
433+
434+
```py
435+
# TODO: these fail at runtime, but we don't emit `[invalid-argument-type]` diagnostics
436+
# currently due to https://github.com/astral-sh/ty/issues/116
437+
reveal_type(get_protocol_members(SupportsAbs[int])) # revealed: frozenset[str]
438+
reveal_type(get_protocol_members(Iterator[int])) # revealed: frozenset[str]
439+
440+
# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> int`)}`"
415441
reveal_protocol_interface(SupportsAbs[int])
442+
# error: [revealed-type] "Revealed protocol interface: `{"__iter__": MethodMember(`(self) -> Iterator[int]`), "__next__": MethodMember(`(self) -> int`)}`"
443+
reveal_protocol_interface(Iterator[int])
416444
```
417445

418446
Certain special attributes and methods are not considered protocol members at runtime, and should

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,10 @@ impl<'db> Bindings<'db> {
759759

760760
Some(KnownFunction::IsProtocol) => {
761761
if let [Some(ty)] = overload.parameter_types() {
762+
// We evaluate this to `Literal[True]` only if the runtime function `typing.is_protocol`
763+
// would return `True` for the given type. Internally we consider `SupportsAbs[int]` to
764+
// be a "(specialised) protocol class", but `typing.is_protocol(SupportsAbs[int])` returns
765+
// `False` at runtime, so we do not set the return type to `Literal[True]` in this case.
762766
overload.set_return_type(Type::BooleanLiteral(
763767
ty.into_class_literal()
764768
.is_some_and(|class| class.is_protocol(db)),
@@ -767,6 +771,9 @@ impl<'db> Bindings<'db> {
767771
}
768772

769773
Some(KnownFunction::GetProtocolMembers) => {
774+
// Similarly to `is_protocol`, we only evaluate to this a frozenset of literal strings if a
775+
// class-literal is passed in, not if a generic alias is passed in, to emulate the behaviour
776+
// of `typing.get_protocol_members` at runtime.
770777
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() {
771778
if let Some(protocol_class) = class.into_protocol_class(db) {
772779
let member_names = protocol_class

crates/ty_python_semantic/src/types/class.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,14 @@ impl<'db> ClassType<'db> {
11131113
}
11141114
}
11151115
}
1116+
1117+
pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool {
1118+
self.class_literal(db).0.is_protocol(db)
1119+
}
1120+
1121+
pub(super) fn header_span(self, db: &'db dyn Db) -> Span {
1122+
self.class_literal(db).0.header_span(db)
1123+
}
11161124
}
11171125

11181126
impl<'db> From<GenericAlias<'db>> for ClassType<'db> {

crates/ty_python_semantic/src/types/diagnostic.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::types::string_annotation::{
1515
IMPLICIT_CONCATENATED_STRING_TYPE_ANNOTATION, INVALID_SYNTAX_IN_FORWARD_ANNOTATION,
1616
RAW_STRING_TYPE_ANNOTATION,
1717
};
18-
use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClassLiteral};
18+
use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClass};
1919
use crate::util::diagnostics::format_enumeration;
2020
use crate::{Db, FxIndexMap, FxOrderMap, Module, ModuleName, Program, declare_lint};
2121
use itertools::Itertools;
@@ -2369,7 +2369,7 @@ pub(crate) fn add_type_expression_reference_link<'db, 'ctx>(
23692369
pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol(
23702370
context: &InferContext,
23712371
call: &ast::ExprCall,
2372-
protocol: ProtocolClassLiteral,
2372+
protocol: ProtocolClass,
23732373
function: KnownFunction,
23742374
) {
23752375
let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, call) else {
@@ -2406,7 +2406,7 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol(
24062406
pub(crate) fn report_attempted_protocol_instantiation(
24072407
context: &InferContext,
24082408
call: &ast::ExprCall,
2409-
protocol: ProtocolClassLiteral,
2409+
protocol: ProtocolClass,
24102410
) {
24112411
let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, call) else {
24122412
return;

crates/ty_python_semantic/src/types/function.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1427,7 +1427,7 @@ impl KnownFunction {
14271427
return;
14281428
};
14291429
let Some(protocol_class) = param_type
1430-
.into_class_literal()
1430+
.to_class_type(db)
14311431
.and_then(|class| class.into_protocol_class(db))
14321432
else {
14331433
report_bad_argument_to_protocol_interface(

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,8 +1158,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
11581158
}
11591159

11601160
if is_protocol
1161-
&& !(base_class.class_literal(self.db()).0.is_protocol(self.db())
1162-
|| base_class.is_known(self.db(), KnownClass::Object))
1161+
&& !(base_class.is_protocol(self.db()) || base_class.is_object(self.db()))
11631162
{
11641163
if let Some(builder) = self
11651164
.context
@@ -6184,11 +6183,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
61846183
// subclasses of the protocol to be passed to parameters that accept `type[SomeProtocol]`.
61856184
// <https://typing.python.org/en/latest/spec/protocol.html#type-and-class-objects-vs-protocols>.
61866185
if !callable_type.is_subclass_of() {
6187-
if let Some(protocol) = class
6188-
.class_literal(self.db())
6189-
.0
6190-
.into_protocol_class(self.db())
6191-
{
6186+
if let Some(protocol) = class.into_protocol_class(self.db()) {
61926187
report_attempted_protocol_instantiation(
61936188
&self.context,
61946189
call_expression,

crates/ty_python_semantic/src/types/instance.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -608,10 +608,8 @@ impl<'db> Protocol<'db> {
608608
fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
609609
match self {
610610
Self::FromClass(class) => class
611-
.class_literal(db)
612-
.0
613611
.into_protocol_class(db)
614-
.expect("Protocol class literal should be a protocol class")
612+
.expect("Class wrapped by `Protocol` should be a protocol class")
615613
.interface(db),
616614
Self::Synthesized(synthesized) => synthesized.interface(),
617615
}

crates/ty_python_semantic/src/types/protocol_class.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use ruff_python_ast::name::Name;
77

88
use super::TypeVarVariance;
99
use crate::semantic_index::place_table;
10+
use crate::types::ClassType;
1011
use crate::{
1112
Db, FxOrderSet,
1213
place::{Boundness, Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
@@ -21,16 +22,24 @@ use crate::{
2122

2223
impl<'db> ClassLiteral<'db> {
2324
/// Returns `Some` if this is a protocol class, `None` otherwise.
24-
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
25-
self.is_protocol(db).then_some(ProtocolClassLiteral(self))
25+
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> {
26+
self.is_protocol(db)
27+
.then_some(ProtocolClass(ClassType::NonGeneric(self)))
28+
}
29+
}
30+
31+
impl<'db> ClassType<'db> {
32+
/// Returns `Some` if this is a protocol class, `None` otherwise.
33+
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> {
34+
self.is_protocol(db).then_some(ProtocolClass(self))
2635
}
2736
}
2837

2938
/// Representation of a single `Protocol` class definition.
3039
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
31-
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteral<'db>);
40+
pub(super) struct ProtocolClass<'db>(ClassType<'db>);
3241

33-
impl<'db> ProtocolClassLiteral<'db> {
42+
impl<'db> ProtocolClass<'db> {
3443
/// Returns the protocol members of this class.
3544
///
3645
/// A protocol's members define the interface declared by the protocol.
@@ -51,13 +60,15 @@ impl<'db> ProtocolClassLiteral<'db> {
5160
}
5261

5362
pub(super) fn is_runtime_checkable(self, db: &'db dyn Db) -> bool {
54-
self.known_function_decorators(db)
63+
self.class_literal(db)
64+
.0
65+
.known_function_decorators(db)
5566
.contains(&KnownFunction::RuntimeCheckable)
5667
}
5768
}
5869

59-
impl<'db> Deref for ProtocolClassLiteral<'db> {
60-
type Target = ClassLiteral<'db>;
70+
impl<'db> Deref for ProtocolClass<'db> {
71+
type Target = ClassType<'db>;
6172

6273
fn deref(&self) -> &Self::Target {
6374
&self.0
@@ -521,16 +532,19 @@ enum BoundOnClass {
521532
#[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
522533
fn cached_protocol_interface<'db>(
523534
db: &'db dyn Db,
524-
class: ClassLiteral<'db>,
535+
class: ClassType<'db>,
525536
) -> ProtocolInterface<'db> {
526537
let mut members = BTreeMap::default();
527538

528-
for parent_protocol in class
529-
.iter_mro(db, None)
539+
for (parent_protocol, specialization) in class
540+
.iter_mro(db)
530541
.filter_map(ClassBase::into_class)
531-
.filter_map(|class| class.class_literal(db).0.into_protocol_class(db))
542+
.filter_map(|class| {
543+
let (class, specialization) = class.class_literal(db);
544+
Some((class.into_protocol_class(db)?, specialization))
545+
})
532546
{
533-
let parent_scope = parent_protocol.body_scope(db);
547+
let parent_scope = parent_protocol.class_literal(db).0.body_scope(db);
534548
let use_def_map = use_def_map(db, parent_scope);
535549
let place_table = place_table(db, parent_scope);
536550

@@ -574,6 +588,8 @@ fn cached_protocol_interface<'db>(
574588
})
575589
.filter(|(name, _, _, _)| !excluded_from_proto_members(name))
576590
.map(|(name, ty, qualifiers, bound_on_class)| {
591+
let ty = ty.apply_optional_specialization(db, specialization);
592+
577593
let kind = match (ty, bound_on_class) {
578594
// TODO: if the getter or setter is a function literal, we should
579595
// upcast it to a `CallableType` so that two protocols with identical property
@@ -606,14 +622,14 @@ fn proto_interface_cycle_recover<'db>(
606622
_db: &dyn Db,
607623
_value: &ProtocolInterface<'db>,
608624
_count: u32,
609-
_class: ClassLiteral<'db>,
625+
_class: ClassType<'db>,
610626
) -> salsa::CycleRecoveryAction<ProtocolInterface<'db>> {
611627
salsa::CycleRecoveryAction::Iterate
612628
}
613629

614630
fn proto_interface_cycle_initial<'db>(
615631
db: &'db dyn Db,
616-
_class: ClassLiteral<'db>,
632+
_class: ClassType<'db>,
617633
) -> ProtocolInterface<'db> {
618634
ProtocolInterface::empty(db)
619635
}

0 commit comments

Comments
 (0)