Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some basic subscript type inference #13562

Merged
merged 5 commits into from
Sep 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 228 additions & 5 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,42 @@ impl<'db> TypeInferenceBuilder<'db> {
);
}

/// Report that a literal index does not address any element of a tuple.
///
/// Forwards an `index-out-of-bounds` diagnostic for `node` to
/// `self.add_diagnostic`. The message contains the offending `index`, the
/// rendered form of `tuple_ty`, and the tuple's `length`.
pub(super) fn tuple_index_out_of_bounds_diagnostic(
    &mut self,
    node: AnyNodeRef,
    tuple_ty: Type<'db>,
    length: usize,
    index: i64,
) {
    // Copy the database handle up front so the message arguments below do not
    // need to reach back into `self`.
    let db = self.db;
    self.add_diagnostic(
        node,
        "index-out-of-bounds",
        format_args!(
            "Index {index} is out of bounds for tuple of type '{}' with length {length}.",
            tuple_ty.display(db)
        ),
    );
}

/// Report that a literal index does not address any character of a string.
///
/// Forwards an `index-out-of-bounds` diagnostic for `node` to
/// `self.add_diagnostic`. The message contains the offending `index`, the
/// rendered form of `string_ty`, and the string's `length` as supplied by
/// the caller.
pub(super) fn string_index_out_of_bounds_diagnostic(
    &mut self,
    node: AnyNodeRef,
    string_ty: Type<'db>,
    length: usize,
    index: i64,
) {
    // Copy the database handle up front so the message arguments below do not
    // need to reach back into `self`.
    let db = self.db;
    self.add_diagnostic(
        node,
        "index-out-of-bounds",
        format_args!(
            "Index {index} is out of bounds for string '{}' with length {length}.",
            string_ty.display(db)
        ),
    );
}

fn infer_for_statement_definition(
&mut self,
target: &ast::ExprName,
Expand Down Expand Up @@ -2380,11 +2416,127 @@ impl<'db> TypeInferenceBuilder<'db> {
ctx: _,
} = subscript;

self.infer_expression(slice);
self.infer_expression(value);

// TODO actual subscript support
Type::Unknown
let value_ty = self.infer_expression(value);
let slice_ty = self.infer_expression(slice);

match (value_ty, slice_ty) {
// Ex) Given `("a", "b", "c", "d")[1]`, return `"b"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int >= 0 => {
let elements = tuple_ty.elements(self.db);
usize::try_from(int)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python allows negative indexing, which indexes from the end of the container:

>>> t = (1, 2, 3)
>>> t[-1]
3

So we should handle that here. I would say we don't have to do it in this PR, and we can just add a TODO, but I don't like that in the meantime this would mean we'd wrongly emit an out-of-bounds error.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll handle it here, thanks.

.ok()
.and_then(|index| elements.get(index).copied())
.unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `("a", "b", "c", "d")[-1]`, return `"d"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int < 0 => {
let elements = tuple_ty.elements(self.db);
int.checked_neg()
.and_then(|int| usize::try_from(int).ok())
.and_then(|index| elements.len().checked_sub(index))
.and_then(|index| elements.get(index).copied())
.unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `("a", "b", "c", "d")[True]`, return `"b"`
(Type::Tuple(tuple_ty), Type::BooleanLiteral(bool)) => {
let elements = tuple_ty.elements(self.db);
let int = i64::from(bool);
elements.get(usize::from(bool)).copied().unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `"value"[1]`, return `"a"`
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int >= 0 => {
let literal_value = literal_ty.value(self.db);
usize::try_from(int)
.ok()
.and_then(|index| literal_value.chars().nth(index))
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
literal_value.chars().count(),
int,
);
Type::Unknown
})
}
// Ex) Given `"value"[-1]`, return `"e"`
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int < 0 => {
let literal_value = literal_ty.value(self.db);
int.checked_neg()
.and_then(|int| usize::try_from(int).ok())
.and_then(|index| index.checked_sub(1))
.and_then(|index| literal_value.chars().rev().nth(index))
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
literal_value.chars().count(),
int,
);
Type::Unknown
})
}
// Ex) Given `"value"[True]`, return `"a"`
(Type::StringLiteral(literal_ty), Type::BooleanLiteral(bool)) => {
let literal_value = literal_ty.value(self.db);
let int = i64::from(bool);
literal_value
.chars()
.nth(usize::from(bool))
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
literal_value.chars().count(),
int,
);
Type::Unknown
})
}
_ => Type::Unknown,
}
}

fn infer_slice_expression(&mut self, slice: &ast::ExprSlice) -> Type<'db> {
Expand Down Expand Up @@ -6399,6 +6551,77 @@ mod tests {
Ok(())
}

/// Subscripting a tuple with integer literals.
///
/// In-range indices (positive and negative) must yield the element's literal
/// type. Out-of-range indices must yield `Unknown` and produce one
/// out-of-bounds diagnostic each.
#[test]
fn subscript_tuple() -> anyhow::Result<()> {
    let mut db = setup_db();

    db.write_dedented(
        "/src/a.py",
        "
t = (1, 'a', 'b')

a = t[0]
b = t[1]
c = t[-1]
d = t[-2]
e = t[4]
f = t[-4]
",
    )?;

    // (symbol, expected public type), checked in declaration order.
    let expected_types = [
        ("a", "Literal[1]"),
        ("b", "Literal[\"a\"]"),
        ("c", "Literal[\"b\"]"),
        ("d", "Literal[\"a\"]"),
        ("e", "Unknown"),
        ("f", "Unknown"),
    ];
    for (symbol, expected) in expected_types {
        assert_public_ty(&db, "/src/a.py", symbol, expected);
    }

    // One diagnostic per out-of-range subscript (`t[4]` and `t[-4]`).
    let expected_diagnostics = [
        "Index 4 is out of bounds for tuple of type 'tuple[Literal[1], Literal[\"a\"], Literal[\"b\"]]' with length 3.",
        "Index -4 is out of bounds for tuple of type 'tuple[Literal[1], Literal[\"a\"], Literal[\"b\"]]' with length 3.",
    ];
    assert_file_diagnostics(&db, "src/a.py", &expected_diagnostics);

    Ok(())
}

/// Subscripting a string literal with integer literals.
///
/// In-range indices (positive and negative) must yield the single-character
/// literal type. Out-of-range indices must yield `Unknown` and produce one
/// out-of-bounds diagnostic each.
#[test]
fn subscript_literal_string() -> anyhow::Result<()> {
    let mut db = setup_db();

    db.write_dedented(
        "/src/a.py",
        "
s = 'abcde'

a = s[0]
b = s[1]
c = s[-1]
d = s[-2]
e = s[8]
f = s[-8]
",
    )?;

    // (symbol, expected public type), checked in declaration order.
    let expected_types = [
        ("a", "Literal[\"a\"]"),
        ("b", "Literal[\"b\"]"),
        ("c", "Literal[\"e\"]"),
        ("d", "Literal[\"d\"]"),
        ("e", "Unknown"),
        ("f", "Unknown"),
    ];
    for (symbol, expected) in expected_types {
        assert_public_ty(&db, "/src/a.py", symbol, expected);
    }

    // One diagnostic per out-of-range subscript (`s[8]` and `s[-8]`).
    let expected_diagnostics = [
        "Index 8 is out of bounds for string 'Literal[\"abcde\"]' with length 5.",
        "Index -8 is out of bounds for string 'Literal[\"abcde\"]' with length 5.",
    ];
    assert_file_diagnostics(&db, "src/a.py", &expected_diagnostics);

    Ok(())
}

#[test]
fn boolean_or_expression() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down
Loading