Skip to content

Commit

Permalink
Support __getitem__ type inference for subscripts
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Oct 1, 2024
1 parent 73e884b commit 0ef8def
Showing 1 changed file with 162 additions and 3 deletions.
165 changes: 162 additions & 3 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ use crate::stdlib::builtins_module_scope;
use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
use crate::types::{
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionKind, FunctionType,
StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
typing_extensions_symbol_ty, BytesLiteralType, CallOutcome, ClassType, FunctionKind,
FunctionType, StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
};
use crate::Db;

Expand Down Expand Up @@ -1322,6 +1322,39 @@ impl<'db> TypeInferenceBuilder<'db> {
);
}

/// Emit a diagnostic declaring that a type does not support subscripting.
pub(super) fn non_subscriptable_diagnostic(
&mut self,
node: AnyNodeRef,
non_subscriptable_ty: Type<'db>,
) {
self.add_diagnostic(
node,
"non-subscriptable",
format_args!(
"Cannot subscript object of type '{}' with no `__getitem__` method.",
non_subscriptable_ty.display(self.db)
),
);
}

/// Emit a diagnostic declaring that a dunder method is not callable.
pub(super) fn dunder_not_callable_diagnostic(
&mut self,
node: AnyNodeRef,
not_callable_ty: Type<'db>,
dunder: &str,
) {
self.add_diagnostic(
node,
"not-callable",
format_args!(
"Method `{dunder}` is not callable on object of type '{}'.",
not_callable_ty.display(self.db)
),
);
}

fn infer_for_statement_definition(
&mut self,
target: &ast::ExprName,
Expand Down Expand Up @@ -2588,7 +2621,51 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::Unknown
})
}
_ => Type::Todo,
(value_ty, slice_ty) => {
// Resolve the value to its class.
let value_meta_ty = value_ty.to_meta_type(self.db);

// If the class defines `__getitem__`, return its return type.
//
// See: https://docs.python.org/3/reference/datamodel.html#class-getitem-versus-getitem
let dunder_getitem_method = value_meta_ty.member(self.db, "__getitem__");
if !dunder_getitem_method.is_unbound() {
let CallOutcome::Callable { return_ty } =
dunder_getitem_method.call(self.db, &[slice_ty])
else {
self.dunder_not_callable_diagnostic(
(&**value).into(),
value_ty,
"__getitem__",
);
return Type::Unknown;
};
return return_ty;
}

// Otherwise, if the value is itself a class and defines `__class_getitem__`,
// return its return type.
if matches!(value_ty, Type::Class(_)) {
let dunder_class_getitem_method = value_ty.member(self.db, "__class_getitem__");
if !dunder_class_getitem_method.is_unbound() {
let CallOutcome::Callable { return_ty } =
dunder_class_getitem_method.call(self.db, &[slice_ty])
else {
self.dunder_not_callable_diagnostic(
(&**value).into(),
value_ty,
"__class_getitem__",
);
return Type::Unknown;
};
return return_ty;
}
}

// Otherwise, emit a diagnostic.
self.non_subscriptable_diagnostic((&**value).into(), value_ty);
Type::Unknown
}
}
}

Expand Down Expand Up @@ -6723,6 +6800,30 @@ mod tests {
Ok(())
}

#[test]
fn subscript_not_callable_getitem() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
class NotSubscriptable:
__getitem__ = None
a = NotSubscriptable()[0]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "Unknown");
assert_file_diagnostics(
&db,
"/src/a.py",
&["Method `__getitem__` is not callable on object of type 'NotSubscriptable'."],
);

Ok(())
}

#[test]
fn dunder_call() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down Expand Up @@ -6750,6 +6851,64 @@ mod tests {
assert_public_ty(&db, "/src/a.py", "b", "Unknown");

assert_file_diagnostics(&db, "src/a.py", &["Object of type 'Unit' is not callable."]);
Ok(())
}

#[test]
fn subscript_str_literal() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
def add(x: int, y: int) -> int:
return x + y
a = 'abcde'[add(0, 1)]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "str");

Ok(())
}

#[test]
fn subscript_getitem() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
class Identity:
def __getitem__(self, index: int) -> int:
return index
a = Identity()[0]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "int");

Ok(())
}

#[test]
fn subscript_class_getitem() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
class Identity:
def __class_getitem__(cls, item: int) -> str:
return item
a = Identity[0]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "str");

Ok(())
}
Expand Down

0 comments on commit 0ef8def

Please sign in to comment.