Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,31 @@ class Foo(Protocol):
reveal_type(get_protocol_members(Foo)) # revealed: frozenset[Literal["method_member", "x", "y", "z"]]
```

To see the kinds and types of the protocol members, you can use the debugging aid
`ty_extensions.reveal_protocol_interface`, meanwhile:

```py
from ty_extensions import reveal_protocol_interface
from typing import SupportsIndex, SupportsAbs

# error: [revealed-type] "Revealed protocol interface: `{"method_member": MethodMember(`(self) -> bytes`), "x": AttributeMember(`int`), "y": PropertyMember { getter: `def y(self) -> str` }, "z": PropertyMember { getter: `def z(self) -> int`, setter: `def z(self, z: int) -> None` }}`"
reveal_protocol_interface(Foo)
# error: [revealed-type] "Revealed protocol interface: `{"__index__": MethodMember(`(self) -> int`)}`"
reveal_protocol_interface(SupportsIndex)
# error: [revealed-type] "Revealed protocol interface: `{"__abs__": MethodMember(`(self) -> _T_co`)}`"
reveal_protocol_interface(SupportsAbs)

# error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`"
reveal_protocol_interface(int)
# error: [invalid-argument-type] "Argument to function `reveal_protocol_interface` is incorrect: Expected `type`, found `Literal["foo"]`"
reveal_protocol_interface("foo")

# TODO: this should be a `revealed-type` diagnostic rather than `invalid-argument-type`, and it should reveal `{"__abs__": MethodMember(`(self) -> int`)}` for the protocol interface
#
# error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`"
reveal_protocol_interface(SupportsAbs[int])
```

Certain special attributes and methods are not considered protocol members at runtime, and should
not be considered protocol members by type checkers either:

Expand Down
35 changes: 35 additions & 0 deletions crates/ty_python_semantic/src/types/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2251,6 +2251,41 @@ pub(crate) fn report_bad_argument_to_get_protocol_members(
diagnostic.info("See https://typing.python.org/en/latest/spec/protocol.html#");
}

pub(crate) fn report_bad_argument_to_protocol_interface(
context: &InferContext,
call: &ast::ExprCall,
param_type: Type,
) {
let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, call) else {
return;
};
let db = context.db();
let mut diagnostic = builder.into_diagnostic("Invalid argument to `reveal_protocol_interface`");
diagnostic
.set_primary_message("Only protocol classes can be passed to `reveal_protocol_interface`");

if let Some(class) = param_type.to_class_type(context.db()) {
let mut class_def_diagnostic = SubDiagnostic::new(
SubDiagnosticSeverity::Info,
format_args!(
"`{}` is declared here, but it is not a protocol class:",
class.name(db)
),
);
class_def_diagnostic.annotate(Annotation::primary(
class.class_literal(db).0.header_span(db),
));
diagnostic.sub(class_def_diagnostic);
}

diagnostic.info(
"A class is only a protocol class if it directly inherits \
from `typing.Protocol` or `typing_extensions.Protocol`",
);
// See TODO in `report_bad_argument_to_get_protocol_members` above
diagnostic.info("See https://typing.python.org/en/latest/spec/protocol.html");
}

pub(crate) fn report_invalid_arguments_to_callable(
context: &InferContext,
subscript: &ast::ExprSubscript,
Expand Down
33 changes: 32 additions & 1 deletion crates/ty_python_semantic/src/types/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ use crate::types::call::{Binding, CallArguments};
use crate::types::context::InferContext;
use crate::types::diagnostic::{
REDUNDANT_CAST, STATIC_ASSERT_ERROR, TYPE_ASSERTION_FAILURE,
report_bad_argument_to_get_protocol_members,
report_bad_argument_to_get_protocol_members, report_bad_argument_to_protocol_interface,
report_runtime_check_against_non_runtime_checkable_protocol,
};
use crate::types::generics::{GenericContext, walk_generic_context};
Expand Down Expand Up @@ -1088,6 +1088,8 @@ pub enum KnownFunction {
TopMaterialization,
/// `ty_extensions.bottom_materialization`
BottomMaterialization,
/// `ty_extensions.reveal_protocol_interface`
RevealProtocolInterface,
}

impl KnownFunction {
Expand Down Expand Up @@ -1153,6 +1155,7 @@ impl KnownFunction {
| Self::EnumMembers
| Self::StaticAssert
| Self::HasMember
| Self::RevealProtocolInterface
| Self::AllMembers => module.is_ty_extensions(),
Self::ImportModule => module.is_importlib(),
}
Expand Down Expand Up @@ -1345,6 +1348,33 @@ impl KnownFunction {
report_bad_argument_to_get_protocol_members(context, call_expression, *class);
}

KnownFunction::RevealProtocolInterface => {
let [Some(param_type)] = parameter_types else {
return;
};
let Some(protocol_class) = param_type
.into_class_literal()
.and_then(|class| class.into_protocol_class(db))
else {
report_bad_argument_to_protocol_interface(
context,
call_expression,
*param_type,
);
return;
};
if let Some(builder) =
context.report_diagnostic(DiagnosticId::RevealedType, Severity::Info)
{
let mut diag = builder.into_diagnostic("Revealed protocol interface");
let span = context.span(&call_expression.arguments.args[0]);
diag.annotate(Annotation::primary(span).message(format_args!(
"`{}`",
protocol_class.interface(db).display(db)
)));
}
}

KnownFunction::IsInstance | KnownFunction::IsSubclass => {
let [Some(first_arg), Some(Type::ClassLiteral(class))] = parameter_types else {
return;
Expand Down Expand Up @@ -1458,6 +1488,7 @@ pub(crate) mod tests {
| KnownFunction::TopMaterialization
| KnownFunction::BottomMaterialization
| KnownFunction::HasMember
| KnownFunction::RevealProtocolInterface
| KnownFunction::AllMembers => KnownModule::TyExtensions,

KnownFunction::ImportModule => KnownModule::ImportLib,
Expand Down
61 changes: 61 additions & 0 deletions crates/ty_python_semantic/src/types/protocol_class.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::fmt::Write;
use std::{collections::BTreeMap, ops::Deref};

use itertools::Itertools;
Expand Down Expand Up @@ -215,6 +216,31 @@ impl<'db> ProtocolInterface<'db> {
data.find_legacy_typevars(db, typevars);
}
}

pub(super) fn display(self, db: &'db dyn Db) -> impl std::fmt::Display {
struct ProtocolInterfaceDisplay<'db> {
db: &'db dyn Db,
interface: ProtocolInterface<'db>,
}

impl std::fmt::Display for ProtocolInterfaceDisplay<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_char('{')?;
for (i, (name, data)) in self.interface.inner(self.db).iter().enumerate() {
write!(f, "\"{name}\": {data}", data = data.display(self.db))?;
if i < self.interface.inner(self.db).len() - 1 {
f.write_str(", ")?;
}
}
f.write_char('}')
}
}

ProtocolInterfaceDisplay {
db,
interface: self,
}
}
}

#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
Expand Down Expand Up @@ -256,6 +282,41 @@ impl<'db> ProtocolMemberData<'db> {
qualifiers: self.qualifiers,
}
}

fn display(&self, db: &'db dyn Db) -> impl std::fmt::Display {
struct ProtocolMemberDataDisplay<'db> {
db: &'db dyn Db,
data: ProtocolMemberKind<'db>,
}

impl std::fmt::Display for ProtocolMemberDataDisplay<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.data {
ProtocolMemberKind::Method(callable) => {
write!(f, "MethodMember(`{}`)", callable.display(self.db))
}
ProtocolMemberKind::Property(property) => {
let mut d = f.debug_struct("PropertyMember");
if let Some(getter) = property.getter(self.db) {
d.field("getter", &format_args!("`{}`", &getter.display(self.db)));
}
if let Some(setter) = property.setter(self.db) {
d.field("setter", &format_args!("`{}`", &setter.display(self.db)));
}
d.finish()
}
ProtocolMemberKind::Other(ty) => {
write!(f, "AttributeMember(`{}`)", ty.display(self.db))
}
}
}
}

ProtocolMemberDataDisplay {
db,
data: self.kind,
}
}
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, salsa::Update, Hash)]
Expand Down
5 changes: 5 additions & 0 deletions crates/ty_vendored/ty_extensions/ty_extensions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,8 @@ def all_members(obj: Any) -> tuple[str, ...]: ...

# Returns `True` if the given object has a member with the given name.
def has_member(obj: Any, name: str) -> bool: ...

# Passing a protocol type to this function will cause ty to emit an info-level
# diagnostic describing the protocol's interface. Passing a non-protocol type
# will cause ty to emit an error diagnostic.
def reveal_protocol_interface(protocol: type) -> None: ...
Loading