Skip to content

Commit

Permalink
Gracefully recover from a cycle in defered type inference
Browse files Browse the repository at this point in the history
  • Loading branch information
rtpg committed Oct 10, 2024
1 parent e3af8dc commit 8ca348d
Showing 1 changed file with 153 additions and 57 deletions.
210 changes: 153 additions & 57 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,27 @@ fn infer_definition_types_cycle_recovery<'db>(
inference
}

/// Cycle recovery for [`infer_deferred_types()`]: for now, just [`Type::Unknown`]
/// TODO fixpoint iteration
fn infer_deferred_types_cycle_recovery<'db>(
db: &'db dyn Db,
_cycle: &salsa::Cycle,
input: Definition<'db>,
) -> TypeInference<'db> {
tracing::trace!("infer_deferred_types_cycle_recovery");
let mut inference = TypeInference::default();
let category = input.category(db);
if category.is_declaration() {
inference.declarations.insert(input, Type::Unknown);
}
if category.is_binding() {
inference.bindings.insert(input, Type::Unknown);
}
// TODO we don't fill in expression types for the cycle-participant definitions, which can
// later cause a panic when looking up an expression type.
inference
}

/// Infer all types for a [`Definition`] (including sub-expressions).
/// Use when resolving a symbol name use or public type of a symbol.
#[salsa::tracked(return_ref, recovery_fn=infer_definition_types_cycle_recovery)]
Expand All @@ -120,7 +141,7 @@ pub(crate) fn infer_definition_types<'db>(
///
/// Deferred expressions are type expressions (annotations, base classes, aliases...) in a stub
/// file, or in a file with `from __future__ import annotations`, or stringified annotations.
#[salsa::tracked(return_ref)]
#[salsa::tracked(return_ref, recovery_fn=infer_deferred_types_cycle_recovery)]
pub(crate) fn infer_deferred_types<'db>(
db: &'db dyn Db,
definition: Definition<'db>,
Expand Down Expand Up @@ -190,9 +211,9 @@ pub(crate) struct TypeInference<'db> {
}

impl<'db> TypeInference<'db> {
pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> {
self.expressions[&expression]
}
// pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> {
// self.expressions[&expression]
// }

pub(crate) fn try_expression_ty(&self, expression: ScopedExpressionId) -> Option<Type<'db>> {
self.expressions.get(&expression).copied()
Expand Down Expand Up @@ -334,9 +355,11 @@ impl<'db> TypeInferenceBuilder<'db> {
/// Get the already-inferred type of an expression node.
///
/// PANIC if no type has been inferred for this node.
fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> {
self.types
.expression_ty(expr.scoped_ast_id(self.db, self.scope))
fn try_expression_ty(&self, expr: &ast::Expr) -> Option<Type<'db>> {
match expr.scoped_ast_id(self.db, self.scope) {
Some(id) => self.types.try_expression_ty(id),
None => None,
}
}

/// Infers types in the given [`InferenceRegion`].
Expand Down Expand Up @@ -700,9 +723,15 @@ impl<'db> TypeInferenceBuilder<'db> {
}

fn infer_definition(&mut self, node: impl Into<DefinitionNodeKey>) {
let definition = self.index.definition(node);
let result = infer_definition_types(self.db, definition);
self.extend(result);
match self.index.definition(node) {
Some(definition) => {
let result = infer_definition_types(self.db, definition);
self.extend(result);
}
None => {
tracing::warn!("Couldn't find definition");
}
}
}

fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) {
Expand Down Expand Up @@ -1001,12 +1030,17 @@ impl<'db> TypeInferenceBuilder<'db> {

// TODO(dhruvmanila): The correct type inference here is the return type of the __enter__
// method of the context manager.
let context_expr_ty = self.expression_ty(&with_item.context_expr);
let context_expr_ty = self.try_expression_ty(&with_item.context_expr).unwrap();

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), context_expr_ty);
self.add_binding(target.into(), definition, context_expr_ty);
match target.scoped_ast_id(self.db, self.scope) {
Some(id) => {
self.types.expressions.insert(id, context_expr_ty);
self.add_binding(target.into(), definition, context_expr_ty);
}
_ => {
tracing::warn!("Couldn't find ID to infer with");
}
}
}

fn infer_except_handler_definition(
Expand Down Expand Up @@ -1173,11 +1207,18 @@ impl<'db> TypeInferenceBuilder<'db> {
let expression = self.index.expression(assignment.value.as_ref());
let result = infer_expression_types(self.db, expression);
self.extend(result);
let value_ty = self.expression_ty(&assignment.value);
let value_ty = self
.try_expression_ty(&assignment.value)
.unwrap_or(Type::Unknown);
self.add_binding(assignment.into(), definition, value_ty);
self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), value_ty);
match target.scoped_ast_id(self.db, self.scope) {
Some(id) => {
self.types.expressions.insert(id, value_ty);
}
None => {
tracing::warn!("Couldn't find ID for target");
}
}
}

fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) {
Expand Down Expand Up @@ -1369,7 +1410,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
self.extend(result);
let iterable_ty = self.expression_ty(iterable);
let iterable_ty = self.try_expression_ty(iterable).unwrap();

let loop_var_value_ty = if is_async {
// TODO(Alex): async iterables/iterators!
Expand All @@ -1380,10 +1421,15 @@ impl<'db> TypeInferenceBuilder<'db> {
.unwrap_with_diagnostic(iterable.into(), self)
};

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), loop_var_value_ty);
self.add_binding(target.into(), definition, loop_var_value_ty);
match target.scoped_ast_id(self.db, self.scope) {
Some(id) => {
self.types.expressions.insert(id, loop_var_value_ty);
self.add_binding(target.into(), definition, loop_var_value_ty);
}
None => {
tracing::warn!("Failed to find target ID");
}
}
}

fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) {
Expand Down Expand Up @@ -1706,12 +1752,29 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from),
ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression),
ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"),
ast::Expr::IpyEscapeCommand(_) => {
// todo!("Implement Ipy escape command support"),
return Type::Unknown;
}
};

let expr_id = expression.scoped_ast_id(self.db, self.scope);
let previous = self.types.expressions.insert(expr_id, ty);
assert_eq!(previous, None);
match expression.scoped_ast_id(self.db, self.scope) {
Some(expr_id) => {
let previous = self.types.expressions.insert(expr_id, ty);
match previous {
None => {}
Some(Type::Unknown) => {
tracing::warn!("Already had included an unknown type");
}
other => {
assert_eq!(other, None);
}
}
}
None => {
tracing::warn!("Could not find ID for expression");
}
}

ty
}
Expand Down Expand Up @@ -2034,10 +2097,20 @@ impl<'db> TypeInferenceBuilder<'db> {
.parent_scope_id(self.scope.file_scope_id(self.db))
.expect("A comprehension should never be the top-level scope")
.to_scope_id(self.db, self.file);
result.expression_ty(iterable.scoped_ast_id(self.db, lookup_scope))
if let Some(id) = iterable.scoped_ast_id(self.db, lookup_scope) {
result.try_expression_ty(id).unwrap_or(Type::Unknown)
} else {
tracing::warn!("Couldn't find AST ID for iterable");
Type::Unknown
}
} else {
self.extend(result);
result.expression_ty(iterable.scoped_ast_id(self.db, self.scope))
if let Some(id) = iterable.scoped_ast_id(self.db, self.scope) {
result.try_expression_ty(id).unwrap_or(Type::Unknown)
} else {
tracing::warn!("Couldn't find AST ID for iterable");
Type::Unknown
}
};

let target_ty = if is_async {
Expand All @@ -2049,17 +2122,26 @@ impl<'db> TypeInferenceBuilder<'db> {
.unwrap_with_diagnostic(iterable.into(), self)
};

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), target_ty);
match target.scoped_ast_id(self.db, self.scope) {
Some(id) => {
self.types.expressions.insert(id, target_ty);
}
None => {
tracing::warn!("Couldn't find AST ID for expression");
}
}
self.add_binding(target.into(), definition, target_ty);
}

fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type<'db> {
let definition = self.index.definition(named);
let result = infer_definition_types(self.db, definition);
self.extend(result);
result.binding_ty(definition)
if let Some(definition) = self.index.definition(named) {
let result = infer_definition_types(self.db, definition);
self.extend(result);
result.binding_ty(definition)
} else {
tracing::warn!("Couldn't find definition");
Type::Unknown
}
}

fn infer_named_expression_definition(
Expand Down Expand Up @@ -2258,23 +2340,29 @@ impl<'db> TypeInferenceBuilder<'db> {
match ctx {
ExprContext::Load => {
let use_def = self.index.use_def_map(file_scope_id);
let symbol = self
.index
.symbol_table(file_scope_id)
.symbol_id_by_name(id)
.expect("Expected the symbol table to create a symbol for every Name node");
let Some(symbol) = self.index.symbol_table(file_scope_id).symbol_id_by_name(id)
else {
tracing::warn!(
"Expected the symbol table to create a symbol for every Name node"
);
return Type::Unknown;
};
// if we're inferring types of deferred expressions, always treat them as public symbols
let (definitions, may_be_unbound) = if self.is_deferred() {
(
use_def.public_bindings(symbol),
use_def.public_may_be_unbound(symbol),
)
} else {
let use_id = name.scoped_use_id(self.db, self.scope);
(
use_def.bindings_at_use(use_id),
use_def.use_may_be_unbound(use_id),
)
if let Some(use_id) = name.scoped_use_id(self.db, self.scope) {
(
use_def.bindings_at_use(use_id),
use_def.use_may_be_unbound(use_id),
)
} else {
tracing::warn!("Failed to find name");
return Type::Unknown;
}
};

let unbound_ty = if may_be_unbound {
Expand Down Expand Up @@ -2521,8 +2609,8 @@ impl<'db> TypeInferenceBuilder<'db> {
.tuple_windows::<(_, _)>()
.zip(ops.iter())
.map(|((left, right), op)| {
let left_ty = self.expression_ty(left);
let right_ty = self.expression_ty(right);
let left_ty = self.try_expression_ty(left).unwrap_or(Type::Unknown);
let right_ty = self.try_expression_ty(right).unwrap_or(Type::Unknown);

self.infer_binary_type_comparison(left_ty, *op, right_ty)
.unwrap_or_else(|| {
Expand Down Expand Up @@ -2990,11 +3078,13 @@ impl<'db> TypeInferenceBuilder<'db> {

let ty = match expression {
ast::Expr::Name(name) => {
debug_assert!(
name.ctx.is_load(),
"name in a type expression is always 'load' but got: '{:?}'",
name.ctx
);
if !name.ctx.is_load() {
tracing::warn!(
"name in a type expression is always 'load' but got: '{:?}'",
name.ctx
);
return Type::Unknown;
}

self.infer_name_expression(name).to_instance(self.db)
}
Expand Down Expand Up @@ -3122,9 +3212,15 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"),
};

let expr_id = expression.scoped_ast_id(self.db, self.scope);
let previous = self.types.expressions.insert(expr_id, ty);
assert!(previous.is_none());
match expression.scoped_ast_id(self.db, self.scope) {
Some(expr_id) => {
let previous = self.types.expressions.insert(expr_id, ty);
assert!(previous.is_none());
}
None => {
tracing::warn!("Could not find AST ID for expression");
}
}

ty
}
Expand Down

0 comments on commit 8ca348d

Please sign in to comment.