Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Enhancing Diagnostics for Compare Expression Inference #13819

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@ reveal_type(a) # revealed: bool
b = 0 not in 10 # error: "Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`"
reveal_type(b) # revealed: bool

c = object() < 5 # error: "Operator `<` is not supported for types `object` and `Literal[5]`"
c = object() < 5 # error: "Operator `<` is not supported for types `object` and `int`"
reveal_type(c) # revealed: Unknown

# TODO should error, need to check if __lt__ signature is valid for right operand
d = 5 < object()
# TODO: should be `Unknown`
reveal_type(d) # revealed: bool

int_literal_or_str_literal = (1 if flag else "foo")
e = 42 in int_literal_or_str_literal # error: "Operator `in` is not supported for types `Literal[42]` and `Literal[1]`, in comparing `Literal[42]` with `Literal[1] | Literal["foo"]`"
reveal_type(e) # revealed: bool

# TODO: should error, need to check if __lt__ signature is valid for right operand
# error may be "Operator `<` is not supported for types `int` and `str`, in comparing `tuple[Literal[1], Literal[2]]` with `tuple[Literal[1], Literal["hello"]]`
f = (1, 2) < (1, "hello")
reveal_type(f) # revealed: @Todo
```
172 changes: 109 additions & 63 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2672,18 +2672,28 @@ impl<'db> TypeInferenceBuilder<'db> {
let right_ty = self.expression_ty(right);

self.infer_binary_type_comparison(left_ty, *op, right_ty)
.unwrap_or_else(|| {
.unwrap_or_else(|error| {
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
self.add_diagnostic(
AnyNodeRef::ExprCompare(compare),
"operator-unsupported",
format_args!(
"Operator `{}` is not supported for types `{}` and `{}`",
op,
left_ty.display(self.db),
right_ty.display(self.db)
"Operator `{}` is not supported for types `{}` and `{}`{}",
error.op,
error.left_ty.display(self.db),
error.right_ty.display(self.db),
if (left_ty, right_ty) == (error.left_ty, error.right_ty) {
String::new()
} else {
format!(
", in comparing `{}` with `{}`",
left_ty.display(self.db),
right_ty.display(self.db)
)
}
),
);

match op {
// `in, not in, is, is not` always return bool instances
ast::CmpOp::In
Expand All @@ -2710,7 +2720,7 @@ impl<'db> TypeInferenceBuilder<'db> {
left: Type<'db>,
op: ast::CmpOp,
right: Type<'db>,
) -> Option<Type<'db>> {
) -> Result<Type<'db>, OperatorUnsupportedError<'db>> {
// Note: identity (is, is not) for equal builtin types is unreliable and not part of the
// language spec.
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal
Expand All @@ -2721,39 +2731,43 @@ impl<'db> TypeInferenceBuilder<'db> {
for element in union.elements(self.db) {
builder = builder.add(self.infer_binary_type_comparison(*element, op, other)?);
}
Some(builder.build())
Ok(builder.build())
}
(other, Type::Union(union)) => {
let mut builder = UnionBuilder::new(self.db);
for element in union.elements(self.db) {
builder = builder.add(self.infer_binary_type_comparison(other, op, *element)?);
}
Some(builder.build())
Ok(builder.build())
}

(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),
ast::CmpOp::Lt => Ok(Type::BooleanLiteral(n < m)),
ast::CmpOp::LtE => Ok(Type::BooleanLiteral(n <= m)),
ast::CmpOp::Gt => Ok(Type::BooleanLiteral(n > m)),
ast::CmpOp::GtE => Ok(Type::BooleanLiteral(n >= m)),
ast::CmpOp::Is => {
if n == m {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
Ok(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if n == m {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
Ok(Type::BooleanLiteral(true))
}
}
// Undefined for (int, int)
ast::CmpOp::In | ast::CmpOp::NotIn => None,
ast::CmpOp::In | ast::CmpOp::NotIn => Err(OperatorUnsupportedError {
op,
left_ty: left,
right_ty: right,
}),
},
(Type::IntLiteral(_), Type::Instance(_)) => {
self.infer_binary_type_comparison(KnownClass::Int.to_instance(self.db), op, right)
Expand Down Expand Up @@ -2784,26 +2798,26 @@ impl<'db> TypeInferenceBuilder<'db> {
let s1 = salsa_s1.value(self.db);
let s2 = salsa_s2.value(self.db);
match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(s1 == s2)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(s1 != s2)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(s1 < s2)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(s1 <= s2)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(s1 > s2)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(s1 >= s2)),
ast::CmpOp::In => Some(Type::BooleanLiteral(s2.contains(s1.as_ref()))),
ast::CmpOp::NotIn => Some(Type::BooleanLiteral(!s2.contains(s1.as_ref()))),
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(s1 == s2)),
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(s1 != s2)),
ast::CmpOp::Lt => Ok(Type::BooleanLiteral(s1 < s2)),
ast::CmpOp::LtE => Ok(Type::BooleanLiteral(s1 <= s2)),
ast::CmpOp::Gt => Ok(Type::BooleanLiteral(s1 > s2)),
ast::CmpOp::GtE => Ok(Type::BooleanLiteral(s1 >= s2)),
ast::CmpOp::In => Ok(Type::BooleanLiteral(s2.contains(s1.as_ref()))),
ast::CmpOp::NotIn => Ok(Type::BooleanLiteral(!s2.contains(s1.as_ref()))),
ast::CmpOp::Is => {
if s1 == s2 {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
Ok(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if s1 == s2 {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
Ok(Type::BooleanLiteral(true))
}
}
}
Expand All @@ -2826,30 +2840,30 @@ impl<'db> TypeInferenceBuilder<'db> {
let b1 = &**salsa_b1.value(self.db);
let b2 = &**salsa_b2.value(self.db);
match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(b1 == b2)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(b1 != b2)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(b1 < b2)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(b1 <= b2)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(b1 > b2)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(b1 >= b2)),
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(b1 == b2)),
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(b1 != b2)),
ast::CmpOp::Lt => Ok(Type::BooleanLiteral(b1 < b2)),
ast::CmpOp::LtE => Ok(Type::BooleanLiteral(b1 <= b2)),
ast::CmpOp::Gt => Ok(Type::BooleanLiteral(b1 > b2)),
ast::CmpOp::GtE => Ok(Type::BooleanLiteral(b1 >= b2)),
ast::CmpOp::In => {
Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some()))
Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some()))
}
ast::CmpOp::NotIn => {
Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none()))
Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none()))
}
ast::CmpOp::Is => {
if b1 == b2 {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
Ok(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if b1 == b2 {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
Ok(Type::BooleanLiteral(true))
}
}
}
Expand Down Expand Up @@ -2887,7 +2901,7 @@ impl<'db> TypeInferenceBuilder<'db> {
).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");

match eq_result {
Type::Todo => return Some(Type::Todo),
Type::Todo => return Ok(Type::Todo),
ty => match ty.bool(self.db) {
Truthiness::AlwaysTrue => eq_count += 1,
Truthiness::AlwaysFalse => not_eq_count += 1,
Expand All @@ -2897,11 +2911,11 @@ impl<'db> TypeInferenceBuilder<'db> {
}

if eq_count >= 1 {
Some(Type::BooleanLiteral(op.is_in()))
Ok(Type::BooleanLiteral(op.is_in()))
} else if not_eq_count == rhs_elements.len() {
Some(Type::BooleanLiteral(op.is_not_in()))
Ok(Type::BooleanLiteral(op.is_not_in()))
} else {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
}
}
ast::CmpOp::Is | ast::CmpOp::IsNot => {
Expand All @@ -2912,7 +2926,7 @@ impl<'db> TypeInferenceBuilder<'db> {
"infer_binary_type_comparison should never return None for `CmpOp::Eq`",
);

Some(match eq_result {
Ok(match eq_result {
Type::Todo => Type::Todo,
ty => match ty.bool(self.db) {
Truthiness::AlwaysFalse => Type::BooleanLiteral(op.is_is_not()),
Expand All @@ -2925,16 +2939,19 @@ impl<'db> TypeInferenceBuilder<'db> {

// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => {
perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__")
}
ast::CmpOp::Lt => perform_rich_comparison(
self.db,
left_class_ty,
right_class_ty,
RichCompareOperator::Lt,
),
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods
_ => Some(Type::Todo),
_ => Ok(Type::Todo),
},
// TODO: handle more types
_ => match op {
ast::CmpOp::Is | ast::CmpOp::IsNot => Some(KnownClass::Bool.to_instance(self.db)),
_ => Some(Type::Todo),
ast::CmpOp::Is | ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)),
_ => Ok(Type::Todo),
},
}
}
Expand All @@ -2949,7 +2966,7 @@ impl<'db> TypeInferenceBuilder<'db> {
left: &[Type<'db>],
op: RichCompareOperator,
right: &[Type<'db>],
) -> Option<Type<'db>> {
) -> Result<Type<'db>, OperatorUnsupportedError<'db>> {
// Compare paired elements from left and right slices
for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) {
let eq_result = self
Expand All @@ -2958,7 +2975,7 @@ impl<'db> TypeInferenceBuilder<'db> {

match eq_result {
// If propagation is required, return the result as is
Type::Todo => return Some(Type::Todo),
Type::Todo => return Ok(Type::Todo),
ty => match ty.bool(self.db) {
// Types are equal, continue to the next pair
Truthiness::AlwaysTrue => continue,
Expand All @@ -2968,7 +2985,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral.
// In this case, we simply return a bool instance.
Truthiness::Ambiguous => return Some(KnownClass::Bool.to_instance(self.db)),
Truthiness::Ambiguous => return Ok(KnownClass::Bool.to_instance(self.db)),
},
}
}
Expand All @@ -2978,7 +2995,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// We return a comparison of the slice lengths based on the operator.
let (left_len, right_len) = (left.len(), right.len());

Some(Type::BooleanLiteral(match op {
Ok(Type::BooleanLiteral(match op {
RichCompareOperator::Eq => left_len == right_len,
RichCompareOperator::Ne => left_len != right_len,
RichCompareOperator::Lt => left_len < right_len,
Expand Down Expand Up @@ -3452,6 +3469,26 @@ impl From<RichCompareOperator> for ast::CmpOp {
}
}

impl RichCompareOperator {
fn dunder_name(self) -> &'static str {
AlexWaygood marked this conversation as resolved.
Show resolved Hide resolved
match self {
RichCompareOperator::Eq => "__eq__",
RichCompareOperator::Ne => "__ne__",
RichCompareOperator::Lt => "__lt__",
RichCompareOperator::Le => "__le__",
RichCompareOperator::Gt => "__gt__",
RichCompareOperator::Ge => "__ge__",
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct OperatorUnsupportedError<'db> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit: perhaps this should be called BinaryOperatorUnsupportedError, since there are also unsupported unary operations, but this error struct is only suitable for carrying information about a binary operator.

Copy link
Contributor

@carljm carljm Oct 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, should be CompareOperatorUnsupportedError, since this is actually specific to ast.CmpOp. Or maybe for brevity could just be CompareUnsupportedError.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it’s coming from the rule name
( https://github.com/astral-sh/ruff/pull/13819/files#diff-65c2c229c88f4021638c996a7496384000d9e7b53b08426b34e92f120bd30b06R2783)
But you're right, CompareUnsupportedError makes more sense! I think the rule name could be updated as well, but it’s a pretty minor thing.

op: ast::CmpOp,
left_ty: Type<'db>,
right_ty: Type<'db>,
}

fn format_import_from_module(level: u32, module: Option<&str>) -> String {
format!(
"{}{}",
Expand Down Expand Up @@ -3532,26 +3569,35 @@ fn perform_rich_comparison<'db>(
db: &'db dyn Db,
left: ClassType<'db>,
right: ClassType<'db>,
dunder_name: &str,
) -> Option<Type<'db>> {
op: RichCompareOperator,
) -> Result<Type<'db>, OperatorUnsupportedError<'db>> {
// The following resource has details about the rich comparison algorithm:
// https://snarky.ca/unravelling-rich-comparison-operators/
//
// TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the
// l.h.s.
// TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined

let dunder = left.class_member(db, dunder_name);
let dunder = left.class_member(db, op.dunder_name());
if !dunder.is_unbound() {
// TODO: this currently gives the return type even if the arg types are invalid
// (e.g. int.__lt__ with string instance should be None, currently bool)
return dunder
.call(db, &[Type::Instance(left), Type::Instance(right)])
.return_ty(db);
.return_ty(db)
.ok_or(OperatorUnsupportedError {
AlexWaygood marked this conversation as resolved.
Show resolved Hide resolved
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
});
}

// TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
None
Err(OperatorUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
})
}

#[cfg(test)]
Expand Down
Loading