From 8ca348deefcfb4f211b96c2de675d9a1cd241efa Mon Sep 17 00:00:00 2001 From: Raphael Gaschignard Date: Thu, 10 Oct 2024 16:26:25 +1000 Subject: [PATCH] Gracefully recover from a cycle in defered type inference --- .../src/types/infer.rs | 210 +++++++++++++----- 1 file changed, 153 insertions(+), 57 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index f894d6dc24133..bef42a6526262 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -96,6 +96,27 @@ fn infer_definition_types_cycle_recovery<'db>( inference } +/// Cycle recovery for [`infer_deferred_types()`]: for now, just [`Type::Unknown`] +/// TODO fixpoint iteration +fn infer_deferred_types_cycle_recovery<'db>( + db: &'db dyn Db, + _cycle: &salsa::Cycle, + input: Definition<'db>, +) -> TypeInference<'db> { + tracing::trace!("infer_deferred_types_cycle_recovery"); + let mut inference = TypeInference::default(); + let category = input.category(db); + if category.is_declaration() { + inference.declarations.insert(input, Type::Unknown); + } + if category.is_binding() { + inference.bindings.insert(input, Type::Unknown); + } + // TODO we don't fill in expression types for the cycle-participant definitions, which can + // later cause a panic when looking up an expression type. + inference +} + /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a symbol name use or public type of a symbol. #[salsa::tracked(return_ref, recovery_fn=infer_definition_types_cycle_recovery)] @@ -120,7 +141,7 @@ pub(crate) fn infer_definition_types<'db>( /// /// Deferred expressions are type expressions (annotations, base classes, aliases...) in a stub /// file, or in a file with `from __future__ import annotations`, or stringified annotations. -#[salsa::tracked(return_ref)] +#[salsa::tracked(return_ref, recovery_fn=infer_deferred_types_cycle_recovery)] pub(crate) fn infer_deferred_types<'db>( db: &'db dyn Db, definition: Definition<'db>, @@ -190,9 +211,9 @@ pub(crate) struct TypeInference<'db> { } impl<'db> TypeInference<'db> { - pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { - self.expressions[&expression] - } + // pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { + // self.expressions[&expression] + // } pub(crate) fn try_expression_ty(&self, expression: ScopedExpressionId) -> Option> { self.expressions.get(&expression).copied() @@ -334,9 +355,11 @@ impl<'db> TypeInferenceBuilder<'db> { /// Get the already-inferred type of an expression node. /// /// PANIC if no type has been inferred for this node. - fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> { - self.types - .expression_ty(expr.scoped_ast_id(self.db, self.scope)) + fn try_expression_ty(&self, expr: &ast::Expr) -> Option> { + match expr.scoped_ast_id(self.db, self.scope) { + Some(id) => self.types.try_expression_ty(id), + None => None, + } } /// Infers types in the given [`InferenceRegion`]. @@ -700,9 +723,15 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_definition(&mut self, node: impl Into) { - let definition = self.index.definition(node); - let result = infer_definition_types(self.db, definition); - self.extend(result); + match self.index.definition(node) { + Some(definition) => { + let result = infer_definition_types(self.db, definition); + self.extend(result); + } + None => { + tracing::warn!("Couldn't find definition"); + } + } } fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) { @@ -1001,12 +1030,17 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO(dhruvmanila): The correct type inference here is the return type of the __enter__ // method of the context manager. - let context_expr_ty = self.expression_ty(&with_item.context_expr); + let context_expr_ty = self.try_expression_ty(&with_item.context_expr).unwrap(); - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), context_expr_ty); - self.add_binding(target.into(), definition, context_expr_ty); + match target.scoped_ast_id(self.db, self.scope) { + Some(id) => { + self.types.expressions.insert(id, context_expr_ty); + self.add_binding(target.into(), definition, context_expr_ty); + } + _ => { + tracing::warn!("Couldn't find ID to infer with"); + } + } } fn infer_except_handler_definition( @@ -1173,11 +1207,18 @@ impl<'db> TypeInferenceBuilder<'db> { let expression = self.index.expression(assignment.value.as_ref()); let result = infer_expression_types(self.db, expression); self.extend(result); - let value_ty = self.expression_ty(&assignment.value); + let value_ty = self + .try_expression_ty(&assignment.value) + .unwrap_or(Type::Unknown); self.add_binding(assignment.into(), definition, value_ty); - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), value_ty); + match target.scoped_ast_id(self.db, self.scope) { + Some(id) => { + self.types.expressions.insert(id, value_ty); + } + None => { + tracing::warn!("Couldn't find ID for target"); + } + } } fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) { @@ -1369,7 +1410,7 @@ impl<'db> TypeInferenceBuilder<'db> { let expression = self.index.expression(iterable); let result = infer_expression_types(self.db, expression); self.extend(result); - let iterable_ty = self.expression_ty(iterable); + let iterable_ty = self.try_expression_ty(iterable).unwrap(); let loop_var_value_ty = if is_async { // TODO(Alex): async iterables/iterators! @@ -1380,10 +1421,15 @@ impl<'db> TypeInferenceBuilder<'db> { .unwrap_with_diagnostic(iterable.into(), self) }; - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), loop_var_value_ty); - self.add_binding(target.into(), definition, loop_var_value_ty); + match target.scoped_ast_id(self.db, self.scope) { + Some(id) => { + self.types.expressions.insert(id, loop_var_value_ty); + self.add_binding(target.into(), definition, loop_var_value_ty); + } + None => { + tracing::warn!("Failed to find target ID"); + } + } } fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) { @@ -1706,12 +1752,29 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression), ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from), ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression), - ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), + ast::Expr::IpyEscapeCommand(_) => { + // todo!("Implement Ipy escape command support"), + return Type::Unknown; + } }; - let expr_id = expression.scoped_ast_id(self.db, self.scope); - let previous = self.types.expressions.insert(expr_id, ty); - assert_eq!(previous, None); + match expression.scoped_ast_id(self.db, self.scope) { + Some(expr_id) => { + let previous = self.types.expressions.insert(expr_id, ty); + match previous { + None => {} + Some(Type::Unknown) => { + tracing::warn!("Already had included an unknown type"); + } + other => { + assert_eq!(other, None); + } + } + } + None => { + tracing::warn!("Could not find ID for expression"); + } + } ty } @@ -2034,10 +2097,20 @@ impl<'db> TypeInferenceBuilder<'db> { .parent_scope_id(self.scope.file_scope_id(self.db)) .expect("A comprehension should never be the top-level scope") .to_scope_id(self.db, self.file); - result.expression_ty(iterable.scoped_ast_id(self.db, lookup_scope)) + if let Some(id) = iterable.scoped_ast_id(self.db, lookup_scope) { + result.try_expression_ty(id).unwrap_or(Type::Unknown) + } else { + tracing::warn!("Couldn't find AST ID for iterable"); + Type::Unknown + } } else { self.extend(result); - result.expression_ty(iterable.scoped_ast_id(self.db, self.scope)) + if let Some(id) = iterable.scoped_ast_id(self.db, self.scope) { + result.try_expression_ty(id).unwrap_or(Type::Unknown) + } else { + tracing::warn!("Couldn't find AST ID for iterable"); + Type::Unknown + } }; let target_ty = if is_async { @@ -2049,17 +2122,26 @@ impl<'db> TypeInferenceBuilder<'db> { .unwrap_with_diagnostic(iterable.into(), self) }; - self.types - .expressions - .insert(target.scoped_ast_id(self.db, self.scope), target_ty); + match target.scoped_ast_id(self.db, self.scope) { + Some(id) => { + self.types.expressions.insert(id, target_ty); + } + None => { + tracing::warn!("Couldn't find AST ID for expression"); + } + } self.add_binding(target.into(), definition, target_ty); } fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type<'db> { - let definition = self.index.definition(named); - let result = infer_definition_types(self.db, definition); - self.extend(result); - result.binding_ty(definition) + if let Some(definition) = self.index.definition(named) { + let result = infer_definition_types(self.db, definition); + self.extend(result); + result.binding_ty(definition) + } else { + tracing::warn!("Couldn't find definition"); + Type::Unknown + } } fn infer_named_expression_definition( @@ -2258,11 +2340,13 @@ impl<'db> TypeInferenceBuilder<'db> { match ctx { ExprContext::Load => { let use_def = self.index.use_def_map(file_scope_id); - let symbol = self - .index - .symbol_table(file_scope_id) - .symbol_id_by_name(id) - .expect("Expected the symbol table to create a symbol for every Name node"); + let Some(symbol) = self.index.symbol_table(file_scope_id).symbol_id_by_name(id) + else { + tracing::warn!( + "Expected the symbol table to create a symbol for every Name node" + ); + return Type::Unknown; + }; // if we're inferring types of deferred expressions, always treat them as public symbols let (definitions, may_be_unbound) = if self.is_deferred() { ( @@ -2270,11 +2354,15 @@ impl<'db> TypeInferenceBuilder<'db> { use_def.public_may_be_unbound(symbol), ) } else { - let use_id = name.scoped_use_id(self.db, self.scope); - ( - use_def.bindings_at_use(use_id), - use_def.use_may_be_unbound(use_id), - ) + if let Some(use_id) = name.scoped_use_id(self.db, self.scope) { + ( + use_def.bindings_at_use(use_id), + use_def.use_may_be_unbound(use_id), + ) + } else { + tracing::warn!("Failed to find name"); + return Type::Unknown; + } }; let unbound_ty = if may_be_unbound { @@ -2521,8 +2609,8 @@ impl<'db> TypeInferenceBuilder<'db> { .tuple_windows::<(_, _)>() .zip(ops.iter()) .map(|((left, right), op)| { - let left_ty = self.expression_ty(left); - let right_ty = self.expression_ty(right); + let left_ty = self.try_expression_ty(left).unwrap_or(Type::Unknown); + let right_ty = self.try_expression_ty(right).unwrap_or(Type::Unknown); self.infer_binary_type_comparison(left_ty, *op, right_ty) .unwrap_or_else(|| { @@ -2990,11 +3078,13 @@ impl<'db> TypeInferenceBuilder<'db> { let ty = match expression { ast::Expr::Name(name) => { - debug_assert!( - name.ctx.is_load(), - "name in a type expression is always 'load' but got: '{:?}'", - name.ctx - ); + if !name.ctx.is_load() { + tracing::warn!( + "name in a type expression is always 'load' but got: '{:?}'", + name.ctx + ); + return Type::Unknown; + } self.infer_name_expression(name).to_instance(self.db) } @@ -3122,9 +3212,15 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), }; - let expr_id = expression.scoped_ast_id(self.db, self.scope); - let previous = self.types.expressions.insert(expr_id, ty); - assert!(previous.is_none()); + match expression.scoped_ast_id(self.db, self.scope) { + Some(expr_id) => { + let previous = self.types.expressions.insert(expr_id, ty); + assert!(previous.is_none()); + } + None => { + tracing::warn!("Could not find AST ID for expression"); + } + } ty }