Skip to content

Commit

Permalink
Add some basic subscript type inference (#13562)
Browse files Browse the repository at this point in the history
## Summary

Just for tuples and strings -- the easiest cases. I think most of the
rest require generic support?
  • Loading branch information
charliermarsh authored Sep 30, 2024
1 parent 32c746b commit c9c748a
Showing 1 changed file with 228 additions and 5 deletions.
233 changes: 228 additions & 5 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,42 @@ impl<'db> TypeInferenceBuilder<'db> {
);
}

/// Emit a diagnostic declaring that an index is out of bounds for a tuple.
pub(super) fn tuple_index_out_of_bounds_diagnostic(
&mut self,
node: AnyNodeRef,
tuple_ty: Type<'db>,
length: usize,
index: i64,
) {
self.add_diagnostic(
node,
"index-out-of-bounds",
format_args!(
"Index {index} is out of bounds for tuple of type '{}' with length {length}.",
tuple_ty.display(self.db)
),
);
}

/// Emit a diagnostic declaring that an index is out of bounds for a string.
pub(super) fn string_index_out_of_bounds_diagnostic(
&mut self,
node: AnyNodeRef,
string_ty: Type<'db>,
length: usize,
index: i64,
) {
self.add_diagnostic(
node,
"index-out-of-bounds",
format_args!(
"Index {index} is out of bounds for string '{}' with length {length}.",
string_ty.display(self.db)
),
);
}

fn infer_for_statement_definition(
&mut self,
target: &ast::ExprName,
Expand Down Expand Up @@ -2389,11 +2425,127 @@ impl<'db> TypeInferenceBuilder<'db> {
ctx: _,
} = subscript;

self.infer_expression(slice);
self.infer_expression(value);

// TODO actual subscript support
Type::Unknown
let value_ty = self.infer_expression(value);
let slice_ty = self.infer_expression(slice);

match (value_ty, slice_ty) {
// Ex) Given `("a", "b", "c", "d")[1]`, return `"b"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int >= 0 => {
let elements = tuple_ty.elements(self.db);
usize::try_from(int)
.ok()
.and_then(|index| elements.get(index).copied())
.unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `("a", "b", "c", "d")[-1]`, return `"c"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int)) if int < 0 => {
let elements = tuple_ty.elements(self.db);
int.checked_neg()
.and_then(|int| usize::try_from(int).ok())
.and_then(|index| elements.len().checked_sub(index))
.and_then(|index| elements.get(index).copied())
.unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `("a", "b", "c", "d")[True]`, return `"b"`
(Type::Tuple(tuple_ty), Type::BooleanLiteral(bool)) => {
let elements = tuple_ty.elements(self.db);
let int = i64::from(bool);
elements.get(usize::from(bool)).copied().unwrap_or_else(|| {
self.tuple_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
elements.len(),
int,
);
Type::Unknown
})
}
// Ex) Given `"value"[1]`, return `"a"`
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int >= 0 => {
let literal_value = literal_ty.value(self.db);
usize::try_from(int)
.ok()
.and_then(|index| literal_value.chars().nth(index))
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
literal_value.chars().count(),
int,
);
Type::Unknown
})
}
// Ex) Given `"value"[-1]`, return `"e"`
(Type::StringLiteral(literal_ty), Type::IntLiteral(int)) if int < 0 => {
let literal_value = literal_ty.value(self.db);
int.checked_neg()
.and_then(|int| usize::try_from(int).ok())
.and_then(|index| index.checked_sub(1))
.and_then(|index| literal_value.chars().rev().nth(index))
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
literal_value.chars().count(),
int,
);
Type::Unknown
})
}
// Ex) Given `"value"[True]`, return `"a"`
(Type::StringLiteral(literal_ty), Type::BooleanLiteral(bool)) => {
let literal_value = literal_ty.value(self.db);
let int = i64::from(bool);
literal_value
.chars()
.nth(usize::from(bool))
.map(|ch| {
Type::StringLiteral(StringLiteralType::new(
self.db,
ch.to_string().into_boxed_str(),
))
})
.unwrap_or_else(|| {
self.string_index_out_of_bounds_diagnostic(
(&**value).into(),
value_ty,
literal_value.chars().count(),
int,
);
Type::Unknown
})
}
_ => Type::Unknown,
}
}

fn infer_slice_expression(&mut self, slice: &ast::ExprSlice) -> Type<'db> {
Expand Down Expand Up @@ -6425,6 +6577,77 @@ mod tests {
Ok(())
}

#[test]
fn subscript_tuple() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
t = (1, 'a', 'b')
a = t[0]
b = t[1]
c = t[-1]
d = t[-2]
e = t[4]
f = t[-4]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "Literal[1]");
assert_public_ty(&db, "/src/a.py", "b", "Literal[\"a\"]");
assert_public_ty(&db, "/src/a.py", "c", "Literal[\"b\"]");
assert_public_ty(&db, "/src/a.py", "d", "Literal[\"a\"]");
assert_public_ty(&db, "/src/a.py", "e", "Unknown");
assert_public_ty(&db, "/src/a.py", "f", "Unknown");

assert_file_diagnostics(
&db,
"src/a.py",
&["Index 4 is out of bounds for tuple of type 'tuple[Literal[1], Literal[\"a\"], Literal[\"b\"]]' with length 3.", "Index -4 is out of bounds for tuple of type 'tuple[Literal[1], Literal[\"a\"], Literal[\"b\"]]' with length 3."],
);

Ok(())
}

#[test]
fn subscript_literal_string() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
s = 'abcde'
a = s[0]
b = s[1]
c = s[-1]
d = s[-2]
e = s[8]
f = s[-8]
",
)?;

assert_public_ty(&db, "/src/a.py", "a", "Literal[\"a\"]");
assert_public_ty(&db, "/src/a.py", "b", "Literal[\"b\"]");
assert_public_ty(&db, "/src/a.py", "c", "Literal[\"e\"]");
assert_public_ty(&db, "/src/a.py", "d", "Literal[\"d\"]");
assert_public_ty(&db, "/src/a.py", "e", "Unknown");
assert_public_ty(&db, "/src/a.py", "f", "Unknown");

assert_file_diagnostics(
&db,
"src/a.py",
&[
"Index 8 is out of bounds for string 'Literal[\"abcde\"]' with length 5.",
"Index -8 is out of bounds for string 'Literal[\"abcde\"]' with length 5.",
],
);

Ok(())
}

#[test]
fn boolean_or_expression() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit c9c748a

Please sign in to comment.