Skip to content

Commit

Permalink
Add tests; support unions
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Oct 1, 2024
1 parent 0ef8def commit 11acc07
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 61 deletions.
10 changes: 10 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,16 @@ impl<'db> Type<'db> {
}
}

/// Return true if the type is a class or a union of classes.
pub fn is_class(&self, db: &'db dyn Db) -> bool {
match self {
Type::Union(union) => union.elements(db).iter().all(|ty| ty.is_class(db)),
Type::Class(_) => true,
// / TODO include type[X], once we add that type
_ => false,
}
}

/// Return true if this type is a [subtype of] type `target`.
///
/// [subtype of]: https://typing.readthedocs.io/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence
Expand Down
262 changes: 201 additions & 61 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ use crate::stdlib::builtins_module_scope;
use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
use crate::types::{
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
typing_extensions_symbol_ty, BytesLiteralType, CallOutcome, ClassType, FunctionKind,
FunctionType, StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionKind, FunctionType,
StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
};
use crate::Db;

Expand Down Expand Up @@ -1338,23 +1338,6 @@ impl<'db> TypeInferenceBuilder<'db> {
);
}

/// Emit a diagnostic declaring that a dunder method is not callable.
pub(super) fn dunder_not_callable_diagnostic(
&mut self,
node: AnyNodeRef,
not_callable_ty: Type<'db>,
dunder: &str,
) {
self.add_diagnostic(
node,
"not-callable",
format_args!(
"Method `{dunder}` is not callable on object of type '{}'.",
not_callable_ty.display(self.db)
),
);
}

fn infer_for_statement_definition(
&mut self,
target: &ast::ExprName,
Expand Down Expand Up @@ -2630,35 +2613,19 @@ impl<'db> TypeInferenceBuilder<'db> {
// See: https://docs.python.org/3/reference/datamodel.html#class-getitem-versus-getitem
let dunder_getitem_method = value_meta_ty.member(self.db, "__getitem__");
if !dunder_getitem_method.is_unbound() {
let CallOutcome::Callable { return_ty } =
dunder_getitem_method.call(self.db, &[slice_ty])
else {
self.dunder_not_callable_diagnostic(
(&**value).into(),
value_ty,
"__getitem__",
);
return Type::Unknown;
};
return return_ty;
return dunder_getitem_method
.call(self.db, &[slice_ty])
.unwrap_with_diagnostic(self.db, value.as_ref().into(), self);
}

// Otherwise, if the value is itself a class and defines `__class_getitem__`,
// return its return type.
if matches!(value_ty, Type::Class(_)) {
if value_ty.is_class(self.db) {
let dunder_class_getitem_method = value_ty.member(self.db, "__class_getitem__");
if !dunder_class_getitem_method.is_unbound() {
let CallOutcome::Callable { return_ty } =
dunder_class_getitem_method.call(self.db, &[slice_ty])
else {
self.dunder_not_callable_diagnostic(
(&**value).into(),
value_ty,
"__class_getitem__",
);
return Type::Unknown;
};
return return_ty;
return dunder_class_getitem_method
.call(self.db, &[slice_ty])
.unwrap_with_diagnostic(self.db, value.as_ref().into(), self);
}
}

Expand Down Expand Up @@ -6801,14 +6768,14 @@ mod tests {
}

#[test]
fn subscript_not_callable_getitem() -> anyhow::Result<()> {
fn subscript_getitem_unbound() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
class NotSubscriptable:
__getitem__ = None
pass
a = NotSubscriptable()[0]
",
Expand All @@ -6818,39 +6785,33 @@ mod tests {
assert_file_diagnostics(
&db,
"/src/a.py",
&["Method `__getitem__` is not callable on object of type 'NotSubscriptable'."],
&["Cannot subscript object of type 'NotSubscriptable' with no `__getitem__` method."],
);

Ok(())
}

#[test]
fn dunder_call() -> anyhow::Result<()> {
fn subscript_not_callable_getitem() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
class Multiplier:
def __init__(self, factor: float):
self.factor = factor
def __call__(self, number: float) -> float:
return number * self.factor
a = Multiplier(2.0)(3.0)
class Unit:
...
class NotSubscriptable:
__getitem__ = None
b = Unit()(3.0)
a = NotSubscriptable()[0]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "float");
assert_public_ty(&db, "/src/a.py", "b", "Unknown");
assert_public_ty(&db, "/src/a.py", "a", "Unknown");
assert_file_diagnostics(
&db,
"/src/a.py",
&["Object of type 'None' is not callable."],
);

assert_file_diagnostics(&db, "src/a.py", &["Object of type 'Unit' is not callable."]);
Ok(())
}

Expand Down Expand Up @@ -6913,6 +6874,185 @@ mod tests {
Ok(())
}

#[test]
fn subscript_getitem_union() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
flag = True
class Identity:
if flag:
def __getitem__(self, index: int) -> int:
return index
else:
def __getitem__(self, index: int) -> str:
return str(index)
a = Identity()[0]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "int | str");

Ok(())
}

#[test]
fn subscript_class_getitem_union() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
flag = True
class Identity:
if flag:
def __class_getitem__(cls, item: int) -> str:
return item
else:
def __class_getitem__(cls, item: int) -> int:
return item
a = Identity[0]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "str | int");

Ok(())
}

#[test]
fn subscript_class_getitem_class_union() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
flag = True
class Identity1:
def __class_getitem__(cls, item: int) -> str:
return item
class Identity2:
def __class_getitem__(cls, item: int) -> int:
return item
if flag:
a = Identity1
else:
a = Identity2
b = a[0]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "Literal[Identity1, Identity2]");
assert_public_ty(&db, "/src/a.py", "b", "str | int");

Ok(())
}

#[test]
fn subscript_class_getitem_unbound_method_union() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
flag = True
if flag:
class Identity:
def __class_getitem__(self, x: int) -> str:
pass
else:
class Identity:
pass
a = Identity[42]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "str | Unknown");

assert_file_diagnostics(
&db,
"/src/a.py",
&["Object of type 'Literal[__class_getitem__] | Unbound' is not callable (due to union element 'Unbound')."],
);

Ok(())
}

#[test]
fn subscript_class_getitem_non_class_union() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
flag = True
if flag:
class Identity:
def __class_getitem__(self, x: int) -> str:
pass
else:
Identity = 1
a = Identity[42]
",
)?;

// TODO this should _probably_ emit `str | Unknown` instead of `Unknown`.
assert_public_ty(&db, "/src/a.py", "a", "Unknown");

assert_file_diagnostics(
&db,
"/src/a.py",
&["Cannot subscript object of type 'Literal[Identity] | Literal[1]' with no `__getitem__` method."],
);

Ok(())
}

#[test]
fn dunder_call() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
class Multiplier:
def __init__(self, factor: float):
self.factor = factor
def __call__(self, number: float) -> float:
return number * self.factor
a = Multiplier(2.0)(3.0)
class Unit:
...
b = Unit()(3.0)
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "float");
assert_public_ty(&db, "/src/a.py", "b", "Unknown");

assert_file_diagnostics(&db, "src/a.py", &["Object of type 'Unit' is not callable."]);

Ok(())
}

#[test]
fn boolean_or_expression() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit 11acc07

Please sign in to comment.