From 22733cb7c7eda64f535458c4e0cd47c71874d2b2 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Thu, 20 Jun 2024 11:49:38 +0100 Subject: [PATCH] red-knot(Salsa): Types without refinements (#11899) --- Cargo.lock | 1 + crates/ruff_python_semantic/Cargo.toml | 3 +- crates/ruff_python_semantic/src/db.rs | 125 ++- .../src/module/resolver.rs | 2 +- .../ruff_python_semantic/src/red_knot/mod.rs | 5 + .../src/red_knot/semantic_index.rs | 69 +- .../src/red_knot/semantic_index/ast_ids.rs | 18 +- .../src/red_knot/semantic_index/builder.rs | 126 ++- .../src/red_knot/semantic_index/definition.rs | 14 +- .../src/red_knot/semantic_index/symbol.rs | 147 +-- .../src/red_knot/types.rs | 684 +++++++++++++ .../src/red_knot/types/display.rs | 175 ++++ .../src/red_knot/types/infer.rs | 945 ++++++++++++++++++ 13 files changed, 2168 insertions(+), 146 deletions(-) create mode 100644 crates/ruff_python_semantic/src/red_knot/types.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/types/display.rs create mode 100644 crates/ruff_python_semantic/src/red_knot/types/infer.rs diff --git a/Cargo.lock b/Cargo.lock index 737829095ce25..796e985c90940 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2474,6 +2474,7 @@ dependencies = [ "anyhow", "bitflags 2.5.0", "hashbrown 0.14.5", + "indexmap", "is-macro", "ruff_db", "ruff_index", diff --git a/crates/ruff_python_semantic/Cargo.toml b/crates/ruff_python_semantic/Cargo.toml index 1b767e8e95178..c823bc29fbe1d 100644 --- a/crates/ruff_python_semantic/Cargo.toml +++ b/crates/ruff_python_semantic/Cargo.toml @@ -20,6 +20,7 @@ ruff_text_size = { workspace = true } bitflags = { workspace = true } is-macro = { workspace = true } +indexmap = { workspace = true, optional = true } salsa = { workspace = true, optional = true } smallvec = { workspace = true, optional = true } smol_str = { workspace = true } @@ -36,4 +37,4 @@ tempfile = { workspace = true } workspace = true [features] -red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec"] +red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec", "dep:indexmap"] diff --git a/crates/ruff_python_semantic/src/db.rs b/crates/ruff_python_semantic/src/db.rs index ae3d71ea01d4f..e8a0f7fd75080 100644 --- a/crates/ruff_python_semantic/src/db.rs +++ b/crates/ruff_python_semantic/src/db.rs @@ -6,20 +6,27 @@ use crate::module::resolver::{ file_to_module, internal::ModuleNameIngredient, internal::ModuleResolverSearchPaths, resolve_module_query, }; - -use crate::red_knot::semantic_index::symbol::ScopeId; -use crate::red_knot::semantic_index::{scopes_map, semantic_index, symbol_table}; +use crate::red_knot::semantic_index::symbol::{ + public_symbols_map, scopes_map, PublicSymbolId, ScopeId, +}; +use crate::red_knot::semantic_index::{root_scope, semantic_index, symbol_table}; +use crate::red_knot::types::{infer_types, public_symbol_ty}; #[salsa::jar(db=Db)] pub struct Jar( ModuleNameIngredient, ModuleResolverSearchPaths, ScopeId, + PublicSymbolId, symbol_table, resolve_module_query, file_to_module, scopes_map, + root_scope, semantic_index, + infer_types, + public_symbol_ty, + public_symbols_map, ); /// Database giving access to semantic information about a Python program. @@ -27,9 +34,13 @@ pub trait Db: SourceDb + DbWithJar + Upcast {} #[cfg(test)] pub(crate) mod tests { + use std::fmt::Formatter; + use std::marker::PhantomData; use std::sync::Arc; - use salsa::DebugWithDb; + use salsa::ingredient::Ingredient; + use salsa::storage::HasIngredientsFor; + use salsa::{AsId, DebugWithDb}; use ruff_db::file_system::{FileSystem, MemoryFileSystem, OsFileSystem}; use ruff_db::vfs::Vfs; @@ -86,7 +97,7 @@ pub(crate) mod tests { /// /// ## Panics /// If there are any pending salsa snapshots. - pub(crate) fn take_sale_events(&mut self) -> Vec { + pub(crate) fn take_salsa_events(&mut self) -> Vec { let inner = Arc::get_mut(&mut self.events).expect("no pending salsa snapshots"); let events = inner.get_mut().unwrap(); @@ -98,7 +109,7 @@ pub(crate) mod tests { /// ## Panics /// If there are any pending salsa snapshots. pub(crate) fn clear_salsa_events(&mut self) { - self.take_sale_events(); + self.take_salsa_events(); } } @@ -150,4 +161,106 @@ pub(crate) mod tests { #[allow(unused)] Os(OsFileSystem), } + + pub(crate) fn assert_will_run_function_query( + db: &Db, + to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient, + key: C::Key, + events: &[salsa::Event], + ) where + C: salsa::function::Configuration + + salsa::storage::IngredientsFor, + Jar: HasIngredientsFor, + Db: salsa::DbWithJar, + C::Key: AsId, + { + will_run_function_query(db, to_function, key, events, true); + } + + pub(crate) fn assert_will_not_run_function_query( + db: &Db, + to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient, + key: C::Key, + events: &[salsa::Event], + ) where + C: salsa::function::Configuration + + salsa::storage::IngredientsFor, + Jar: HasIngredientsFor, + Db: salsa::DbWithJar, + C::Key: AsId, + { + will_run_function_query(db, to_function, key, events, false); + } + + fn will_run_function_query( + db: &Db, + to_function: impl FnOnce(&C) -> &salsa::function::FunctionIngredient, + key: C::Key, + events: &[salsa::Event], + should_run: bool, + ) where + C: salsa::function::Configuration + + salsa::storage::IngredientsFor, + Jar: HasIngredientsFor, + Db: salsa::DbWithJar, + C::Key: AsId, + { + let (jar, _) = + <_ as salsa::storage::HasJar<::Jar>>::jar(db); + let ingredient = jar.ingredient(); + + let function_ingredient = to_function(ingredient); + + let ingredient_index = + as Ingredient>::ingredient_index( + function_ingredient, + ); + + let did_run = events.iter().any(|event| { + if let salsa::EventKind::WillExecute { database_key } = event.kind { + database_key.ingredient_index() == ingredient_index + && database_key.key_index() == key.as_id() + } else { + false + } + }); + + if should_run && !did_run { + panic!( + "Expected query {:?} to run but it didn't", + DebugIdx { + db: PhantomData::, + value_id: key.as_id(), + ingredient: function_ingredient, + } + ); + } else if !should_run && did_run { + panic!( + "Expected query {:?} not to run but it did", + DebugIdx { + db: PhantomData::, + value_id: key.as_id(), + ingredient: function_ingredient, + } + ); + } + } + + struct DebugIdx<'a, I, Db> + where + I: Ingredient, + { + value_id: salsa::Id, + ingredient: &'a I, + db: PhantomData, + } + + impl<'a, I, Db> std::fmt::Debug for DebugIdx<'a, I, Db> + where + I: Ingredient, + { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.ingredient.fmt_index(Some(self.value_id), f) + } + } } diff --git a/crates/ruff_python_semantic/src/module/resolver.rs b/crates/ruff_python_semantic/src/module/resolver.rs index 1113b97591530..8c7cf60526ac1 100644 --- a/crates/ruff_python_semantic/src/module/resolver.rs +++ b/crates/ruff_python_semantic/src/module/resolver.rs @@ -886,7 +886,7 @@ mod tests { let foo_module2 = resolve_module(&db, foo_module_name); assert!(!db - .take_sale_events() + .take_salsa_events() .iter() .any(|event| { matches!(event.kind, salsa::EventKind::WillExecute { .. }) })); diff --git a/crates/ruff_python_semantic/src/red_knot/mod.rs b/crates/ruff_python_semantic/src/red_knot/mod.rs index 9a21b4c4cf3c4..2db3e01b1e465 100644 --- a/crates/ruff_python_semantic/src/red_knot/mod.rs +++ b/crates/ruff_python_semantic/src/red_knot/mod.rs @@ -1,3 +1,8 @@ +use rustc_hash::FxHasher; +use std::hash::BuildHasherDefault; + pub mod ast_node_ref; mod node_key; pub mod semantic_index; +pub mod types; +pub(crate) type FxIndexSet = indexmap::set::IndexSet>; diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index.rs index 8764a00c7b07b..3ecbf68caefc8 100644 --- a/crates/ruff_python_semantic/src/red_knot/semantic_index.rs +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index.rs @@ -9,10 +9,10 @@ use ruff_index::{IndexSlice, IndexVec}; use ruff_python_ast as ast; use crate::red_knot::node_key::NodeKey; -use crate::red_knot::semantic_index::ast_ids::AstIds; +use crate::red_knot::semantic_index::ast_ids::{AstId, AstIds, ScopeClassId, ScopeFunctionId}; use crate::red_knot::semantic_index::builder::SemanticIndexBuilder; use crate::red_knot::semantic_index::symbol::{ - FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeSymbolId, ScopesMap, SymbolTable, + FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable, }; use crate::Db; @@ -21,7 +21,7 @@ mod builder; pub mod definition; pub mod symbol; -type SymbolMap = hashbrown::HashMap; +type SymbolMap = hashbrown::HashMap; /// Returns the semantic index for `file`. /// @@ -42,33 +42,22 @@ pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex { pub(crate) fn symbol_table(db: &dyn Db, scope: ScopeId) -> Arc { let index = semantic_index(db, scope.file(db)); - index.symbol_table(scope.scope_id(db)) -} - -/// Returns a mapping from file specific [`FileScopeId`] to a program-wide unique [`ScopeId`]. -#[salsa::tracked(return_ref)] -pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap { - let index = semantic_index(db, file); - - let scopes: IndexVec<_, _> = index - .scopes - .indices() - .map(|id| ScopeId::new(db, file, id)) - .collect(); - - ScopesMap::new(scopes) + index.symbol_table(scope.file_scope_id(db)) } /// Returns the root scope of `file`. -pub fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId { +#[salsa::tracked] +pub(crate) fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId { FileScopeId::root().to_scope_id(db, file) } /// Returns the symbol with the given name in `file`'s public scope or `None` if /// no symbol with the given name exists. -pub fn global_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option { +pub fn public_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option { let root_scope = root_scope(db, file); - root_scope.symbol(db, name) + let symbol_table = symbol_table(db, root_scope); + let local = symbol_table.symbol_id_by_name(name)?; + Some(local.to_public_symbol(db, file)) } /// The symbol tables for an entire file. @@ -90,6 +79,9 @@ pub struct SemanticIndex { /// Note: We should not depend on this map when analysing other files or /// changing a file invalidates all dependents. ast_ids: IndexVec, + + /// Map from scope to the node that introduces the scope. + scope_nodes: IndexVec, } impl SemanticIndex { @@ -97,7 +89,7 @@ impl SemanticIndex { /// /// Use the Salsa cached [`symbol_table`] query if you only need the /// symbol table for a single scope. - fn symbol_table(&self, scope_id: FileScopeId) -> Arc { + pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc { self.symbol_tables[scope_id].clone() } @@ -152,6 +144,10 @@ impl SemanticIndex { pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter { AncestorsIter::new(self, scope) } + + pub(crate) fn scope_node(&self, scope_id: FileScopeId) -> NodeWithScopeId { + self.scope_nodes[scope_id] + } } /// ID that uniquely identifies an expression inside a [`Scope`]. @@ -246,6 +242,28 @@ impl<'a> Iterator for ChildrenIter<'a> { } } +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub(crate) enum NodeWithScopeId { + Module, + Class(AstId), + ClassTypeParams(AstId), + Function(AstId), + FunctionTypeParams(AstId), +} + +impl NodeWithScopeId { + fn scope_kind(self) -> ScopeKind { + match self { + NodeWithScopeId::Module => ScopeKind::Module, + NodeWithScopeId::Class(_) => ScopeKind::Class, + NodeWithScopeId::Function(_) => ScopeKind::Function, + NodeWithScopeId::ClassTypeParams(_) | NodeWithScopeId::FunctionTypeParams(_) => { + ScopeKind::Annotation + } + } + } +} + impl FusedIterator for ChildrenIter<'_> {} #[cfg(test)] @@ -583,19 +601,14 @@ class C[T]: let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4"); let index = semantic_index(&db, file); - let root_table = index.symbol_table(FileScopeId::root()); let parsed = parsed_module(&db, file); let ast = parsed.syntax(); - let x_sym = root_table - .symbol_by_name("x") - .expect("x symbol should exist"); - let x_stmt = ast.body[0].as_assign_stmt().unwrap(); let x = &x_stmt.targets[0]; assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module); - assert_eq!(index.expression_scope_id(x), x_sym.scope()); + assert_eq!(index.expression_scope_id(x), FileScopeId::root()); let def = ast.body[1].as_function_def_stmt().unwrap(); let y_stmt = def.body[0].as_assign_stmt().unwrap(); diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index/ast_ids.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index/ast_ids.rs index 9d1fd1a9891ac..f615d5f44197f 100644 --- a/crates/ruff_python_semantic/src/red_knot/semantic_index/ast_ids.rs +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index/ast_ids.rs @@ -66,13 +66,13 @@ impl std::fmt::Debug for AstIds { } fn ast_ids(db: &dyn Db, scope: ScopeId) -> &AstIds { - semantic_index(db, scope.file(db)).ast_ids(scope.scope_id(db)) + semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db)) } /// Node that can be uniquely identified by an id in a [`FileScopeId`]. pub trait ScopeAstIdNode { /// The type of the ID uniquely identifying the node. - type Id; + type Id: Copy; /// Returns the ID that uniquely identifies the node in `scope`. /// @@ -91,7 +91,7 @@ pub trait ScopeAstIdNode { /// Extension trait for AST nodes that can be resolved by an `AstId`. pub trait AstIdNode { - type ScopeId; + type ScopeId: Copy; /// Resolves the AST id of the node. /// @@ -133,7 +133,7 @@ where /// Uniquely identifies an AST node in a file. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub struct AstId { +pub struct AstId { /// The node's scope. scope: FileScopeId, @@ -141,6 +141,16 @@ pub struct AstId { in_scope_id: L, } +impl AstId { + pub(super) fn new(scope: FileScopeId, in_scope_id: L) -> Self { + Self { scope, in_scope_id } + } + + pub(super) fn in_scope_id(self) -> L { + self.in_scope_id + } +} + /// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`]. #[newtype_index] pub struct ScopeExpressionId; diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index/builder.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index/builder.rs index ef1237b70e278..d2107a0012a02 100644 --- a/crates/ruff_python_semantic/src/red_knot/semantic_index/builder.rs +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index/builder.rs @@ -10,16 +10,16 @@ use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor}; use crate::name::Name; use crate::red_knot::node_key::NodeKey; use crate::red_knot::semantic_index::ast_ids::{ - AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId, + AstId, AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId, ScopeImportId, ScopeNamedExprId, }; use crate::red_knot::semantic_index::definition::{ Definition, ImportDefinition, ImportFromDefinition, }; use crate::red_knot::semantic_index::symbol::{ - FileScopeId, FileSymbolId, Scope, ScopeKind, ScopeSymbolId, SymbolFlags, SymbolTableBuilder, + FileScopeId, FileSymbolId, Scope, ScopedSymbolId, SymbolFlags, SymbolTableBuilder, }; -use crate::red_knot::semantic_index::SemanticIndex; +use crate::red_knot::semantic_index::{NodeWithScopeId, SemanticIndex}; pub(super) struct SemanticIndexBuilder<'a> { // Builder state @@ -33,6 +33,7 @@ pub(super) struct SemanticIndexBuilder<'a> { symbol_tables: IndexVec, ast_ids: IndexVec, expression_scopes: FxHashMap, + scope_nodes: IndexVec, } impl<'a> SemanticIndexBuilder<'a> { @@ -46,10 +47,11 @@ impl<'a> SemanticIndexBuilder<'a> { symbol_tables: IndexVec::new(), ast_ids: IndexVec::new(), expression_scopes: FxHashMap::default(), + scope_nodes: IndexVec::new(), }; builder.push_scope_with_parent( - ScopeKind::Module, + NodeWithScopeId::Module, &Name::new_static(""), None, None, @@ -68,18 +70,18 @@ impl<'a> SemanticIndexBuilder<'a> { fn push_scope( &mut self, - scope_kind: ScopeKind, + node: NodeWithScopeId, name: &Name, defining_symbol: Option, definition: Option, ) { let parent = self.current_scope(); - self.push_scope_with_parent(scope_kind, name, defining_symbol, definition, Some(parent)); + self.push_scope_with_parent(node, name, defining_symbol, definition, Some(parent)); } fn push_scope_with_parent( &mut self, - scope_kind: ScopeKind, + node: NodeWithScopeId, name: &Name, defining_symbol: Option, definition: Option, @@ -92,13 +94,17 @@ impl<'a> SemanticIndexBuilder<'a> { parent, defining_symbol, definition, - kind: scope_kind, + kind: node.scope_kind(), descendents: children_start..children_start, }; let scope_id = self.scopes.push(scope); self.symbol_tables.push(SymbolTableBuilder::new()); - self.ast_ids.push(AstIdsBuilder::new()); + let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new()); + let scope_node_id = self.scope_nodes.push(node); + + debug_assert_eq!(ast_id_scope, scope_id); + debug_assert_eq!(scope_id, scope_node_id); self.scope_stack.push(scope_id); } @@ -120,11 +126,10 @@ impl<'a> SemanticIndexBuilder<'a> { &mut self.ast_ids[scope_id] } - fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopeSymbolId { - let scope = self.current_scope(); + fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId { let symbol_table = self.current_symbol_table(); - symbol_table.add_or_update_symbol(name, scope, flags, None) + symbol_table.add_or_update_symbol(name, flags, None) } fn add_or_update_symbol_with_definition( @@ -132,27 +137,32 @@ impl<'a> SemanticIndexBuilder<'a> { name: Name, definition: Definition, - ) -> ScopeSymbolId { - let scope = self.current_scope(); + ) -> ScopedSymbolId { let symbol_table = self.current_symbol_table(); - symbol_table.add_or_update_symbol(name, scope, SymbolFlags::IS_DEFINED, Some(definition)) + symbol_table.add_or_update_symbol(name, SymbolFlags::IS_DEFINED, Some(definition)) } fn with_type_params( &mut self, name: &Name, - params: &Option>, - definition: Option, + with_params: &WithTypeParams, defining_symbol: FileSymbolId, nested: impl FnOnce(&mut Self) -> FileScopeId, ) -> FileScopeId { - if let Some(type_params) = params { + let type_params = with_params.type_parameters(); + + if let Some(type_params) = type_params { + let type_node = match with_params { + WithTypeParams::ClassDef { id, .. } => NodeWithScopeId::ClassTypeParams(*id), + WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id), + }; + self.push_scope( - ScopeKind::Annotation, + type_node, name, Some(defining_symbol), - definition, + Some(with_params.definition()), ); for type_param in &type_params.type_params { let name = match type_param { @@ -163,9 +173,10 @@ impl<'a> SemanticIndexBuilder<'a> { self.add_or_update_symbol(Name::new(name), SymbolFlags::IS_DEFINED); } } + let nested_scope = nested(self); - if params.is_some() { + if type_params.is_some() { self.pop_scope(); } @@ -198,10 +209,12 @@ impl<'a> SemanticIndexBuilder<'a> { ast_ids.shrink_to_fit(); symbol_tables.shrink_to_fit(); self.expression_scopes.shrink_to_fit(); + self.scope_nodes.shrink_to_fit(); SemanticIndex { symbol_tables, scopes: self.scopes, + scope_nodes: self.scope_nodes, ast_ids, expression_scopes: self.expression_scopes, } @@ -223,7 +236,8 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { self.visit_decorator(decorator); } let name = Name::new(&function_def.name.id); - let definition = Definition::FunctionDef(ScopeFunctionId(statement_id)); + let function_id = ScopeFunctionId(statement_id); + let definition = Definition::FunctionDef(function_id); let scope = self.current_scope(); let symbol = FileSymbolId::new( scope, @@ -232,8 +246,10 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { self.with_type_params( &name, - &function_def.type_params, - Some(definition), + &WithTypeParams::FunctionDef { + node: function_def, + id: AstId::new(scope, function_id), + }, symbol, |builder| { builder.visit_parameters(&function_def.parameters); @@ -242,7 +258,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { } builder.push_scope( - ScopeKind::Function, + NodeWithScopeId::Function(AstId::new(scope, function_id)), &name, Some(symbol), Some(definition), @@ -258,21 +274,36 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { } let name = Name::new(&class.name.id); - let definition = Definition::from(ScopeClassId(statement_id)); + let class_id = ScopeClassId(statement_id); + let definition = Definition::from(class_id); + let scope = self.current_scope(); let id = FileSymbolId::new( self.current_scope(), self.add_or_update_symbol_with_definition(name.clone(), definition), ); - self.with_type_params(&name, &class.type_params, Some(definition), id, |builder| { - if let Some(arguments) = &class.arguments { - builder.visit_arguments(arguments); - } + self.with_type_params( + &name, + &WithTypeParams::ClassDef { + node: class, + id: AstId::new(scope, class_id), + }, + id, + |builder| { + if let Some(arguments) = &class.arguments { + builder.visit_arguments(arguments); + } - builder.push_scope(ScopeKind::Class, &name, Some(id), Some(definition)); - builder.visit_body(&class.body); + builder.push_scope( + NodeWithScopeId::Class(AstId::new(scope, class_id)), + &name, + Some(id), + Some(definition), + ); + builder.visit_body(&class.body); - builder.pop_scope() - }); + builder.pop_scope() + }, + ); } ast::Stmt::Import(ast::StmtImport { names, .. }) => { for (i, alias) in names.iter().enumerate() { @@ -396,3 +427,30 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { } } } + +enum WithTypeParams<'a> { + ClassDef { + node: &'a ast::StmtClassDef, + id: AstId, + }, + FunctionDef { + node: &'a ast::StmtFunctionDef, + id: AstId, + }, +} + +impl<'a> WithTypeParams<'a> { + fn type_parameters(&self) -> Option<&'a ast::TypeParams> { + match self { + WithTypeParams::ClassDef { node, .. } => node.type_params.as_deref(), + WithTypeParams::FunctionDef { node, .. } => node.type_params.as_deref(), + } + } + + fn definition(&self) -> Definition { + match self { + WithTypeParams::ClassDef { id, .. } => Definition::ClassDef(id.in_scope_id()), + WithTypeParams::FunctionDef { id, .. } => Definition::FunctionDef(id.in_scope_id()), + } + } +} diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index/definition.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index/definition.rs index 97170b9e27003..9c91e99aa969b 100644 --- a/crates/ruff_python_semantic/src/red_knot/semantic_index/definition.rs +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index/definition.rs @@ -3,7 +3,7 @@ use crate::red_knot::semantic_index::ast_ids::{ ScopeImportFromId, ScopeImportId, ScopeNamedExprId, }; -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Definition { Import(ImportDefinition), ImportFrom(ImportFromDefinition), @@ -59,18 +59,18 @@ impl From for Definition { } } -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub struct ImportDefinition { - pub(super) import_id: ScopeImportId, + pub(crate) import_id: ScopeImportId, /// Index into [`ruff_python_ast::StmtImport::names`]. - pub(super) alias: u32, + pub(crate) alias: u32, } -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub struct ImportFromDefinition { - pub(super) import_id: ScopeImportFromId, + pub(crate) import_id: ScopeImportFromId, /// Index into [`ruff_python_ast::StmtImportFrom::names`]. - pub(super) name: u32, + pub(crate) name: u32, } diff --git a/crates/ruff_python_semantic/src/red_knot/semantic_index/symbol.rs b/crates/ruff_python_semantic/src/red_knot/semantic_index/symbol.rs index b543742d3f76b..2869c7fea4f06 100644 --- a/crates/ruff_python_semantic/src/red_knot/semantic_index/symbol.rs +++ b/crates/ruff_python_semantic/src/red_knot/semantic_index/symbol.rs @@ -15,24 +15,21 @@ use ruff_index::{newtype_index, IndexVec}; use crate::name::Name; use crate::red_knot::semantic_index::definition::Definition; -use crate::red_knot::semantic_index::{scopes_map, symbol_table, SymbolMap}; +use crate::red_knot::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap}; use crate::Db; #[derive(Eq, PartialEq, Debug)] pub struct Symbol { name: Name, flags: SymbolFlags, - scope: FileScopeId, - /// The nodes that define this symbol, in source order. definitions: SmallVec<[Definition; 4]>, } impl Symbol { - fn new(name: Name, scope: FileScopeId, definition: Option) -> Self { + fn new(name: Name, definition: Option) -> Self { Self { name, - scope, flags: SymbolFlags::empty(), definitions: definition.into_iter().collect(), } @@ -51,11 +48,6 @@ impl Symbol { &self.name } - /// The scope in which this symbol is defined. - pub fn scope(&self) -> FileScopeId { - self.scope - } - /// Is the symbol used in its containing scope? pub fn is_used(&self) -> bool { self.flags.contains(SymbolFlags::IS_USED) @@ -84,62 +76,72 @@ bitflags! { } /// ID that uniquely identifies a public symbol defined in a module's root scope. -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +#[salsa::tracked] pub struct PublicSymbolId { - scope: ScopeId, - symbol: ScopeSymbolId, -} - -impl PublicSymbolId { - pub(crate) fn new(scope: ScopeId, symbol: ScopeSymbolId) -> Self { - Self { scope, symbol } - } - - pub fn scope(self) -> ScopeId { - self.scope - } - - pub(crate) fn scope_symbol(self) -> ScopeSymbolId { - self.symbol - } -} - -impl From for ScopeSymbolId { - fn from(val: PublicSymbolId) -> Self { - val.scope_symbol() - } + #[id] + pub(crate) file: VfsFile, + #[id] + pub(crate) scoped_symbol_id: ScopedSymbolId, } /// ID that uniquely identifies a symbol in a file. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub struct FileSymbolId { scope: FileScopeId, - symbol: ScopeSymbolId, + scoped_symbol_id: ScopedSymbolId, } impl FileSymbolId { - pub(super) fn new(scope: FileScopeId, symbol: ScopeSymbolId) -> Self { - Self { scope, symbol } + pub(super) fn new(scope: FileScopeId, symbol: ScopedSymbolId) -> Self { + Self { + scope, + scoped_symbol_id: symbol, + } } pub fn scope(self) -> FileScopeId { self.scope } - pub(crate) fn symbol(self) -> ScopeSymbolId { - self.symbol + pub(crate) fn scoped_symbol_id(self) -> ScopedSymbolId { + self.scoped_symbol_id } } -impl From for ScopeSymbolId { +impl From for ScopedSymbolId { fn from(val: FileSymbolId) -> Self { - val.symbol() + val.scoped_symbol_id() } } /// Symbol ID that uniquely identifies a symbol inside a [`Scope`]. #[newtype_index] -pub(crate) struct ScopeSymbolId; +pub struct ScopedSymbolId; + +impl ScopedSymbolId { + /// Converts the symbol to a public symbol. + /// + /// # Panics + /// May panic if the symbol does not belong to `file` or is not a symbol of `file`'s root scope. + pub(crate) fn to_public_symbol(self, db: &dyn Db, file: VfsFile) -> PublicSymbolId { + let symbols = public_symbols_map(db, file); + symbols.public(self) + } +} + +/// Returns a mapping from [`FileScopeId`] to globally unique [`ScopeId`]. +#[salsa::tracked(return_ref)] +pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap { + let index = semantic_index(db, file); + + let scopes: IndexVec<_, _> = index + .scopes + .indices() + .map(|id| ScopeId::new(db, file, id)) + .collect(); + + ScopesMap { scopes } +} /// Maps from the file specific [`FileScopeId`] to the global [`ScopeId`] that can be used as a Salsa query parameter. /// @@ -152,13 +154,37 @@ pub(crate) struct ScopesMap { } impl ScopesMap { - pub(super) fn new(scopes: IndexVec) -> Self { - Self { scopes } + /// Gets the program-wide unique scope id for the given file specific `scope_id`. + fn get(&self, scope: FileScopeId) -> ScopeId { + self.scopes[scope] } +} - /// Gets the program-wide unique scope id for the given file specific `scope_id`. - fn get(&self, scope_id: FileScopeId) -> ScopeId { - self.scopes[scope_id] +#[salsa::tracked(return_ref)] +pub(crate) fn public_symbols_map(db: &dyn Db, file: VfsFile) -> PublicSymbolsMap { + let module_scope = root_scope(db, file); + let symbols = symbol_table(db, module_scope); + + let public_symbols: IndexVec<_, _> = symbols + .symbol_ids() + .map(|id| PublicSymbolId::new(db, file, id)) + .collect(); + + PublicSymbolsMap { + symbols: public_symbols, + } +} + +/// Maps [`LocalSymbolId`] of a file's root scope to the corresponding [`PublicSymbolId`] (Salsa ingredients). +#[derive(Eq, PartialEq, Debug)] +pub(crate) struct PublicSymbolsMap { + symbols: IndexVec, +} + +impl PublicSymbolsMap { + /// Resolve the [`PublicSymbolId`] for the module-level `symbol_id`. + fn public(&self, symbol_id: ScopedSymbolId) -> PublicSymbolId { + self.symbols[symbol_id] } } @@ -166,18 +192,10 @@ impl ScopesMap { #[salsa::tracked] pub struct ScopeId { #[allow(clippy::used_underscore_binding)] + #[id] pub file: VfsFile, - pub scope_id: FileScopeId, -} - -impl ScopeId { - /// Resolves the symbol named `name` in this scope. - pub fn symbol(self, db: &dyn Db, name: &str) -> Option { - let symbol_table = symbol_table(db, self); - let in_scope_id = symbol_table.symbol_id_by_name(name)?; - - Some(PublicSymbolId::new(self, in_scope_id)) - } + #[id] + pub file_scope_id: FileScopeId, } /// ID that uniquely identifies a scope inside of a module. @@ -239,7 +257,7 @@ pub enum ScopeKind { #[derive(Debug)] pub struct SymbolTable { /// The symbols in this scope. - symbols: IndexVec, + symbols: IndexVec, /// The symbols indexed by name. symbols_by_name: SymbolMap, @@ -257,12 +275,12 @@ impl SymbolTable { self.symbols.shrink_to_fit(); } - pub(crate) fn symbol(&self, symbol_id: impl Into) -> &Symbol { + pub(crate) fn symbol(&self, symbol_id: impl Into) -> &Symbol { &self.symbols[symbol_id.into()] } #[allow(unused)] - pub(crate) fn symbol_ids(&self) -> impl Iterator { + pub(crate) fn symbol_ids(&self) -> impl Iterator { self.symbols.indices() } @@ -277,8 +295,8 @@ impl SymbolTable { Some(self.symbol(id)) } - /// Returns the [`ScopeSymbolId`] of the symbol named `name`. - pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option { + /// Returns the [`ScopedSymbolId`] of the symbol named `name`. + pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option { let (id, ()) = self .symbols_by_name .raw_entry() @@ -320,10 +338,9 @@ impl SymbolTableBuilder { pub(super) fn add_or_update_symbol( &mut self, name: Name, - scope: FileScopeId, flags: SymbolFlags, definition: Option, - ) -> ScopeSymbolId { + ) -> ScopedSymbolId { let hash = SymbolTable::hash_name(&name); let entry = self .table @@ -343,7 +360,7 @@ impl SymbolTableBuilder { *entry.key() } RawEntryMut::Vacant(entry) => { - let mut symbol = Symbol::new(name, scope, definition); + let mut symbol = Symbol::new(name, definition); symbol.insert_flags(flags); let id = self.table.symbols.push(symbol); diff --git a/crates/ruff_python_semantic/src/red_knot/types.rs b/crates/ruff_python_semantic/src/red_knot/types.rs new file mode 100644 index 0000000000000..a3d7da4ad15cc --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/types.rs @@ -0,0 +1,684 @@ +use salsa::DebugWithDb; + +use ruff_db::parsed::parsed_module; +use ruff_db::vfs::VfsFile; +use ruff_index::newtype_index; +use ruff_python_ast as ast; + +use crate::name::Name; +use crate::red_knot::semantic_index::ast_ids::{AstIdNode, ScopeAstIdNode}; +use crate::red_knot::semantic_index::symbol::{FileScopeId, PublicSymbolId, ScopeId}; +use crate::red_knot::semantic_index::{ + public_symbol, root_scope, semantic_index, symbol_table, NodeWithScopeId, +}; +use crate::red_knot::types::infer::{TypeInference, TypeInferenceBuilder}; +use crate::red_knot::FxIndexSet; +use crate::Db; + +mod display; +mod infer; + +/// Infers the type of `expr`. +/// +/// Calling this function from a salsa query adds a dependency on [`semantic_index`] +/// which changes with every AST change. That's why you should only call +/// this function for the current file that's being analyzed and not for +/// a dependency (or the query reruns whenever a dependency change). +/// +/// Prefer [`public_symbol_ty`] when resolving the type of symbol from another file. +#[tracing::instrument(level = "debug", skip(db))] +pub(crate) fn expression_ty(db: &dyn Db, file: VfsFile, expression: &ast::Expr) -> Type { + let index = semantic_index(db, file); + let file_scope = index.expression_scope_id(expression); + let expression_id = expression.scope_ast_id(db, file, file_scope); + let scope = file_scope.to_scope_id(db, file); + + infer_types(db, scope).expression_ty(expression_id) +} + +/// Infers the type of a public symbol. +/// +/// This is a Salsa query to get symbol-level invalidation instead of file-level dependency invalidation. +/// Without this being a query, changing any public type of a module would invalidate the type inference +/// for the module scope of its dependents and the transitive dependents because. +/// +/// For example if we have +/// ```python +/// # a.py +/// import x from b +/// +/// # b.py +/// +/// x = 20 +/// ``` +/// +/// And x is now changed from `x = 20` to `x = 30`. The following happens: +/// +/// * The module level types of `b.py` change because `x` now is a `Literal[30]`. +/// * The module level types of `a.py` change because the imported symbol `x` now has a `Literal[30]` type +/// * The module level types of any dependents of `a.py` change because the imported symbol `x` now has a `Literal[30]` type +/// * And so on for all transitive dependencies. +/// +/// This being a query ensures that the invalidation short-circuits if the type of this symbol didn't change. +#[salsa::tracked] +pub(crate) fn public_symbol_ty(db: &dyn Db, symbol: PublicSymbolId) -> Type { + let _ = tracing::debug_span!("public_symbol_ty", "{:?}", symbol.debug(db)); + + let file = symbol.file(db); + let scope = root_scope(db, file); + + let inference = infer_types(db, scope); + inference.symbol_ty(symbol.scoped_symbol_id(db)) +} + +/// Shorthand for [`public_symbol_ty()`] that takes a symbol name instead of a [`PublicSymbolId`]. +pub fn public_symbol_ty_by_name(db: &dyn Db, file: VfsFile, name: &str) -> Option { + let symbol = public_symbol(db, file, name)?; + Some(public_symbol_ty(db, symbol)) +} + +/// Infers all types for `scope`. +#[salsa::tracked(return_ref)] +pub(crate) fn infer_types(db: &dyn Db, scope: ScopeId) -> TypeInference { + let file = scope.file(db); + // Using the index here is fine because the code below depends on the AST anyway. + // The isolation of the query is by the return inferred types. + let index = semantic_index(db, file); + + let scope_id = scope.file_scope_id(db); + let node = index.scope_node(scope_id); + + let mut context = TypeInferenceBuilder::new(db, scope, index); + + match node { + NodeWithScopeId::Module => { + let parsed = parsed_module(db.upcast(), file); + context.infer_module(parsed.syntax()); + } + NodeWithScopeId::Class(class_id) => { + let class = ast::StmtClassDef::lookup(db, file, class_id); + context.infer_class_body(class); + } + NodeWithScopeId::ClassTypeParams(class_id) => { + let class = ast::StmtClassDef::lookup(db, file, class_id); + context.infer_class_type_params(class); + } + NodeWithScopeId::Function(function_id) => { + let function = ast::StmtFunctionDef::lookup(db, file, function_id); + context.infer_function_body(function); + } + NodeWithScopeId::FunctionTypeParams(function_id) => { + let function = ast::StmtFunctionDef::lookup(db, file, function_id); + context.infer_function_type_params(function); + } + } + + context.finish() +} + +/// unique ID for a type +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum Type { + /// the dynamic type: a statically-unknown set of values + Any, + /// the empty set of values + Never, + /// unknown type (no annotation) + /// equivalent to Any, or to object in strict mode + Unknown, + /// name is not bound to any value + Unbound, + /// the None object (TODO remove this in favor of Instance(types.NoneType) + None, + /// a specific function object + Function(TypeId), + /// a specific module object + Module(TypeId), + /// a specific class object + Class(TypeId), + /// the set of Python objects with the given class in their __class__'s method resolution order + Instance(TypeId), + Union(TypeId), + Intersection(TypeId), + IntLiteral(i64), + // TODO protocols, callable types, overloads, generics, type vars +} + +impl Type { + pub const fn is_unbound(&self) -> bool { + matches!(self, Type::Unbound) + } + + pub const fn is_unknown(&self) -> bool { + matches!(self, Type::Unknown) + } + + pub fn member(&self, context: &TypingContext, name: &Name) -> Option { + match self { + Type::Any => Some(Type::Any), + Type::Never => todo!("attribute lookup on Never type"), + Type::Unknown => Some(Type::Unknown), + Type::Unbound => todo!("attribute lookup on Unbound type"), + Type::None => todo!("attribute lookup on None type"), + Type::Function(_) => todo!("attribute lookup on Function type"), + Type::Module(module) => module.member(context, name), + Type::Class(class) => class.class_member(context, name), + Type::Instance(_) => { + // TODO MRO? get_own_instance_member, get_instance_member + todo!("attribute lookup on Instance type") + } + Type::Union(union_id) => { + let _union = union_id.lookup(context); + // TODO perform the get_member on each type in the union + // TODO return the union of those results + // TODO if any of those results is `None` then include Unknown in the result union + todo!("attribute lookup on Union type") + } + Type::Intersection(_) => { + // TODO perform the get_member on each type in the intersection + // TODO return the intersection of those results + todo!("attribute lookup on Intersection type") + } + Type::IntLiteral(_) => { + // TODO raise error + Some(Type::Unknown) + } + } + } +} + +/// ID that uniquely identifies a type in a program. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct TypeId { + /// The scope in which this type is defined or was created. + scope: ScopeId, + /// The type's local ID in its scope. + scoped: L, +} + +impl TypeId +where + Id: Copy, +{ + pub fn scope(&self) -> ScopeId { + self.scope + } + + pub fn scoped_id(&self) -> Id { + self.scoped + } + + /// Resolves the type ID to the actual type. + pub(crate) fn lookup<'a>(self, context: &'a TypingContext) -> &'a Id::Ty + where + Id: ScopedTypeId, + { + let types = context.types(self.scope); + self.scoped.lookup_scoped(types) + } +} + +/// ID that uniquely identifies a type in a scope. +pub(crate) trait ScopedTypeId { + /// The type that this ID points to. + type Ty; + + /// Looks up the type in `index`. + /// + /// ## Panics + /// May panic if this type is from another scope than `index`, or might just return an invalid type. + fn lookup_scoped(self, index: &TypeInference) -> &Self::Ty; +} + +/// ID uniquely identifying a function type in a `scope`. +#[newtype_index] +pub struct ScopedFunctionTypeId; + +impl ScopedTypeId for ScopedFunctionTypeId { + type Ty = FunctionType; + + fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty { + types.function_ty(self) + } +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct FunctionType { + /// name of the function at definition + name: Name, + /// types of all decorators on this function + decorators: Vec, +} + +impl FunctionType { + fn name(&self) -> &str { + self.name.as_str() + } + + #[allow(unused)] + pub(crate) fn decorators(&self) -> &[Type] { + self.decorators.as_slice() + } +} + +#[newtype_index] +pub struct ScopedClassTypeId; + +impl ScopedTypeId for ScopedClassTypeId { + type Ty = ClassType; + + fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty { + types.class_ty(self) + } +} + +impl TypeId { + /// Returns the class member of this class named `name`. + /// + /// The member resolves to a member of the class itself or any of its bases. + fn class_member(self, context: &TypingContext, name: &Name) -> Option { + if let Some(member) = self.own_class_member(context, name) { + return Some(member); + } + + let class = self.lookup(context); + for base in &class.bases { + if let Some(member) = base.member(context, name) { + return Some(member); + } + } + + None + } + + /// Returns the inferred type of the class member named `name`. + fn own_class_member(self, context: &TypingContext, name: &Name) -> Option { + let class = self.lookup(context); + + let symbols = symbol_table(context.db, class.body_scope); + let symbol = symbols.symbol_id_by_name(name)?; + let types = context.types(class.body_scope); + + Some(types.symbol_ty(symbol)) + } +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct ClassType { + /// Name of the class at definition + name: Name, + + /// Types of all class bases + bases: Vec, + + body_scope: ScopeId, +} + +impl ClassType { + fn name(&self) -> &str { + self.name.as_str() + } + + #[allow(unused)] + pub(super) fn bases(&self) -> &[Type] { + self.bases.as_slice() + } +} + +#[newtype_index] +pub struct ScopedUnionTypeId; + +impl ScopedTypeId for ScopedUnionTypeId { + type Ty = UnionType; + + fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty { + types.union_ty(self) + } +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct UnionType { + // the union type includes values in any of these types + elements: FxIndexSet, +} + +struct UnionTypeBuilder<'a> { + elements: FxIndexSet, + context: &'a TypingContext<'a>, +} + +impl<'a> UnionTypeBuilder<'a> { + fn new(context: &'a TypingContext<'a>) -> Self { + Self { + context, + elements: FxIndexSet::default(), + } + } + + /// Adds a type to this union. + fn add(mut self, ty: Type) -> Self { + match ty { + Type::Union(union_id) => { + let union = union_id.lookup(self.context); + self.elements.extend(&union.elements); + } + _ => { + self.elements.insert(ty); + } + } + + self + } + + fn build(self) -> UnionType { + UnionType { + elements: self.elements, + } + } +} + +#[newtype_index] +pub struct ScopedIntersectionTypeId; + +impl ScopedTypeId for ScopedIntersectionTypeId { + type Ty = IntersectionType; + + fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty { + types.intersection_ty(self) + } +} + +// Negation types aren't expressible in annotations, and are most likely to arise from type +// narrowing along with intersections (e.g. `if not isinstance(...)`), so we represent them +// directly in intersections rather than as a separate type. This sacrifices some efficiency in the +// case where a Not appears outside an intersection (unclear when that could even happen, but we'd +// have to represent it as a single-element intersection if it did) in exchange for better +// efficiency in the within-intersection case. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct IntersectionType { + // the intersection type includes only values in all of these types + positive: FxIndexSet, + // the intersection type does not include any value in any of these types + negative: FxIndexSet, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct ScopedModuleTypeId; + +impl ScopedTypeId for ScopedModuleTypeId { + type Ty = ModuleType; + + fn lookup_scoped(self, types: &TypeInference) -> &Self::Ty { + types.module_ty() + } +} + +impl TypeId { + fn member(self, context: &TypingContext, name: &Name) -> Option { + context.public_symbol_ty(self.scope.file(context.db), name) + } +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct ModuleType { + file: VfsFile, +} + +/// Context in which to resolve types. +/// +/// This abstraction is necessary to support a uniform API that can be used +/// while in the process of building the type inference structure for a scope +/// but also when all types should be resolved by querying the db. +pub struct TypingContext<'a> { + db: &'a dyn Db, + + /// The Local type inference scope that is in the process of being built. + /// + /// Bypass the `db` when resolving the types for this scope. + local: Option<(ScopeId, &'a TypeInference)>, +} + +impl<'a> TypingContext<'a> { + /// Creates a context that resolves all types by querying the db. + #[allow(unused)] + pub(super) fn global(db: &'a dyn Db) -> Self { + Self { db, local: None } + } + + /// Creates a context that by-passes the `db` when resolving types from `scope_id` and instead uses `types`. + fn scoped(db: &'a dyn Db, scope_id: ScopeId, types: &'a TypeInference) -> Self { + Self { + db, + local: Some((scope_id, types)), + } + } + + /// Returns the [`TypeInference`] results (not guaranteed to be complete) for `scope_id`. + fn types(&self, scope_id: ScopeId) -> &'a TypeInference { + if let Some((scope, local_types)) = self.local { + if scope == scope_id { + return local_types; + } + } + + infer_types(self.db, scope_id) + } + + fn module_ty(&self, file: VfsFile) -> Type { + let scope = root_scope(self.db, file); + + Type::Module(TypeId { + scope, + scoped: ScopedModuleTypeId, + }) + } + + /// Resolves the public type of a symbol named `name` defined in `file`. + /// + /// This function calls [`public_symbol_ty`] if the local scope isn't the module scope of `file`. + /// It otherwise tries to resolve the symbol type locally. + fn public_symbol_ty(&self, file: VfsFile, name: &Name) -> Option { + let symbol = public_symbol(self.db, file, name)?; + + if let Some((scope, local_types)) = self.local { + if scope.file_scope_id(self.db) == FileScopeId::root() && scope.file(self.db) == file { + return Some(local_types.symbol_ty(symbol.scoped_symbol_id(self.db))); + } + } + + Some(public_symbol_ty(self.db, symbol)) + } +} + +#[cfg(test)] +mod tests { + use ruff_db::file_system::FileSystemPathBuf; + use ruff_db::parsed::parsed_module; + use ruff_db::vfs::system_path_to_file; + + use crate::db::tests::{ + assert_will_not_run_function_query, assert_will_run_function_query, TestDb, + }; + use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings}; + use crate::red_knot::semantic_index::root_scope; + use crate::red_knot::types::{ + expression_ty, infer_types, public_symbol_ty_by_name, TypingContext, + }; + + fn setup_db() -> TestDb { + let mut db = TestDb::new(); + set_module_resolution_settings( + &mut db, + ModuleResolutionSettings { + extra_paths: vec![], + workspace_root: FileSystemPathBuf::from("/src"), + site_packages: None, + custom_typeshed: None, + }, + ); + + db + } + + #[test] + fn local_inference() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system().write_file("/src/a.py", "x = 10")?; + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + + let parsed = parsed_module(&db, a); + + let statement = parsed.suite().first().unwrap().as_assign_stmt().unwrap(); + + let literal_ty = expression_ty(&db, a, &statement.value); + + assert_eq!( + format!("{}", literal_ty.display(&TypingContext::global(&db))), + "Literal[10]" + ); + + Ok(()) + } + + #[test] + fn dependency_public_symbol_type_change() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.memory_file_system().write_files([ + ("/src/a.py", "from foo import x"), + ("/src/foo.py", "x = 10\ndef foo(): ..."), + ])?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); + + assert_eq!( + x_ty.display(&TypingContext::global(&db)).to_string(), + "Literal[10]" + ); + + // Change `x` to a different value + db.memory_file_system() + .write_file("/src/foo.py", "x = 20\ndef foo(): ...")?; + + let foo = system_path_to_file(&db, "/src/foo.py").unwrap(); + foo.touch(&mut db); + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + + db.clear_salsa_events(); + let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); + + assert_eq!( + x_ty_2.display(&TypingContext::global(&db)).to_string(), + "Literal[20]" + ); + + let events = db.take_salsa_events(); + + let a_root_scope = root_scope(&db, a); + assert_will_run_function_query::( + &db, + |ty| &ty.function, + a_root_scope, + &events, + ); + + Ok(()) + } + + #[test] + fn dependency_non_public_symbol_change() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.memory_file_system().write_files([ + ("/src/a.py", "from foo import x"), + ("/src/foo.py", "x = 10\ndef foo(): y = 1"), + ])?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); + + assert_eq!( + x_ty.display(&TypingContext::global(&db)).to_string(), + "Literal[10]" + ); + + db.memory_file_system() + .write_file("/src/foo.py", "x = 10\ndef foo(): pass")?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + let foo = system_path_to_file(&db, "/src/foo.py").unwrap(); + + foo.touch(&mut db); + + db.clear_salsa_events(); + + let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); + + assert_eq!( + x_ty_2.display(&TypingContext::global(&db)).to_string(), + "Literal[10]" + ); + + let events = db.take_salsa_events(); + + let a_root_scope = root_scope(&db, a); + + assert_will_not_run_function_query::( + &db, + |ty| &ty.function, + a_root_scope, + &events, + ); + + Ok(()) + } + + #[test] + fn dependency_unrelated_public_symbol() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.memory_file_system().write_files([ + ("/src/a.py", "from foo import x"), + ("/src/foo.py", "x = 10\ny = 20"), + ])?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); + + assert_eq!( + x_ty.display(&TypingContext::global(&db)).to_string(), + "Literal[10]" + ); + + db.memory_file_system() + .write_file("/src/foo.py", "x = 10\ny = 30")?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + let foo = system_path_to_file(&db, "/src/foo.py").unwrap(); + + foo.touch(&mut db); + + db.clear_salsa_events(); + + let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); + + assert_eq!( + x_ty_2.display(&TypingContext::global(&db)).to_string(), + "Literal[10]" + ); + + let events = db.take_salsa_events(); + + let a_root_scope = root_scope(&db, a); + assert_will_not_run_function_query::( + &db, + |ty| &ty.function, + a_root_scope, + &events, + ); + Ok(()) + } +} diff --git a/crates/ruff_python_semantic/src/red_knot/types/display.rs b/crates/ruff_python_semantic/src/red_knot/types/display.rs new file mode 100644 index 0000000000000..804fbc2d6000f --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/types/display.rs @@ -0,0 +1,175 @@ +//! Display implementations for types. + +use std::fmt::{Display, Formatter}; + +use crate::red_knot::types::{IntersectionType, Type, TypingContext, UnionType}; + +impl Type { + pub fn display<'a>(&'a self, context: &'a TypingContext) -> DisplayType<'a> { + DisplayType { ty: self, context } + } +} + +#[derive(Copy, Clone)] +pub struct DisplayType<'a> { + ty: &'a Type, + context: &'a TypingContext<'a>, +} + +impl Display for DisplayType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.ty { + Type::Any => f.write_str("Any"), + Type::Never => f.write_str("Never"), + Type::Unknown => f.write_str("Unknown"), + Type::Unbound => f.write_str("Unbound"), + Type::None => f.write_str("None"), + Type::Module(module_id) => { + write!( + f, + "", + module_id + .scope + .file(self.context.db) + .path(self.context.db.upcast()) + ) + } + // TODO functions and classes should display using a fully qualified name + Type::Class(class_id) => { + let class = class_id.lookup(self.context); + + f.write_str("Literal[")?; + f.write_str(class.name())?; + f.write_str("]") + } + Type::Instance(class_id) => { + let class = class_id.lookup(self.context); + f.write_str(class.name()) + } + Type::Function(function_id) => { + let function = function_id.lookup(self.context); + f.write_str(function.name()) + } + Type::Union(union_id) => { + let union = union_id.lookup(self.context); + + union.display(self.context).fmt(f) + } + Type::Intersection(intersection_id) => { + let intersection = intersection_id.lookup(self.context); + + intersection.display(self.context).fmt(f) + } + Type::IntLiteral(n) => write!(f, "Literal[{n}]"), + } + } +} + +impl std::fmt::Debug for DisplayType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self, f) + } +} + +impl UnionType { + fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayUnionType<'a> { + DisplayUnionType { context, ty: self } + } +} + +struct DisplayUnionType<'a> { + ty: &'a UnionType, + context: &'a TypingContext<'a>, +} + +impl Display for DisplayUnionType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let union = self.ty; + + let (int_literals, other_types): (Vec, Vec) = union + .elements + .iter() + .copied() + .partition(|ty| matches!(ty, Type::IntLiteral(_))); + + let mut first = true; + if !int_literals.is_empty() { + f.write_str("Literal[")?; + let mut nums: Vec<_> = int_literals + .into_iter() + .filter_map(|ty| { + if let Type::IntLiteral(n) = ty { + Some(n) + } else { + None + } + }) + .collect(); + nums.sort_unstable(); + for num in nums { + if !first { + f.write_str(", ")?; + } + write!(f, "{num}")?; + first = false; + } + f.write_str("]")?; + } + + for ty in other_types { + if !first { + f.write_str(" | ")?; + }; + first = false; + write!(f, "{}", ty.display(self.context))?; + } + + Ok(()) + } +} + +impl std::fmt::Debug for DisplayUnionType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self, f) + } +} + +impl IntersectionType { + fn display<'a>(&'a self, context: &'a TypingContext<'a>) -> DisplayIntersectionType<'a> { + DisplayIntersectionType { ty: self, context } + } +} + +struct DisplayIntersectionType<'a> { + ty: &'a IntersectionType, + context: &'a TypingContext<'a>, +} + +impl Display for DisplayIntersectionType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut first = true; + for (neg, ty) in self + .ty + .positive + .iter() + .map(|ty| (false, ty)) + .chain(self.ty.negative.iter().map(|ty| (true, ty))) + { + if !first { + f.write_str(" & ")?; + }; + first = false; + if neg { + f.write_str("~")?; + }; + write!(f, "{}", ty.display(self.context))?; + } + Ok(()) + } +} + +impl std::fmt::Debug for DisplayIntersectionType<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self, f) + } +} diff --git a/crates/ruff_python_semantic/src/red_knot/types/infer.rs b/crates/ruff_python_semantic/src/red_knot/types/infer.rs new file mode 100644 index 0000000000000..67293971d7614 --- /dev/null +++ b/crates/ruff_python_semantic/src/red_knot/types/infer.rs @@ -0,0 +1,945 @@ +use std::sync::Arc; + +use rustc_hash::FxHashMap; + +use ruff_db::vfs::VfsFile; +use ruff_index::IndexVec; +use ruff_python_ast as ast; +use ruff_python_ast::{ExprContext, TypeParams}; + +use crate::module::resolver::resolve_module; +use crate::module::ModuleName; +use crate::name::Name; +use crate::red_knot::semantic_index::ast_ids::{ScopeAstIdNode, ScopeExpressionId}; +use crate::red_knot::semantic_index::definition::{ + Definition, ImportDefinition, ImportFromDefinition, +}; +use crate::red_knot::semantic_index::symbol::{ + FileScopeId, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable, +}; +use crate::red_knot::semantic_index::{symbol_table, ChildrenIter, SemanticIndex}; +use crate::red_knot::types::{ + ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, ScopedFunctionTypeId, + ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, UnionType, + UnionTypeBuilder, +}; +use crate::Db; + +/// The inferred types for a single scope. +#[derive(Debug, Eq, PartialEq, Default, Clone)] +pub(crate) struct TypeInference { + /// The type of the module if the scope is a module scope. + module_type: Option, + + /// The types of the defined classes in this scope. + class_types: IndexVec, + + /// The types of the defined functions in this scope. + function_types: IndexVec, + + union_types: IndexVec, + intersection_types: IndexVec, + + /// The types of every expression in this scope. + expression_tys: IndexVec, + + /// The public types of every symbol in this scope. + symbol_tys: IndexVec, +} + +impl TypeInference { + #[allow(unused)] + pub(super) fn expression_ty(&self, expression: ScopeExpressionId) -> Type { + self.expression_tys[expression] + } + + pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type { + self.symbol_tys[symbol] + } + + pub(super) fn module_ty(&self) -> &ModuleType { + self.module_type.as_ref().unwrap() + } + + pub(super) fn class_ty(&self, id: ScopedClassTypeId) -> &ClassType { + &self.class_types[id] + } + + pub(super) fn function_ty(&self, id: ScopedFunctionTypeId) -> &FunctionType { + &self.function_types[id] + } + + pub(super) fn union_ty(&self, id: ScopedUnionTypeId) -> &UnionType { + &self.union_types[id] + } + + pub(super) fn intersection_ty(&self, id: ScopedIntersectionTypeId) -> &IntersectionType { + &self.intersection_types[id] + } + + fn shrink_to_fit(&mut self) { + self.class_types.shrink_to_fit(); + self.function_types.shrink_to_fit(); + self.union_types.shrink_to_fit(); + self.intersection_types.shrink_to_fit(); + + self.expression_tys.shrink_to_fit(); + self.symbol_tys.shrink_to_fit(); + } +} + +/// Builder to infer all types in a [`ScopeId`]. +pub(super) struct TypeInferenceBuilder<'a> { + db: &'a dyn Db, + + // Cached lookups + index: &'a SemanticIndex, + scope: ScopeId, + file_scope_id: FileScopeId, + file_id: VfsFile, + symbol_table: Arc, + + /// The type inference results + types: TypeInference, + definition_tys: FxHashMap, + children_scopes: ChildrenIter<'a>, +} + +impl<'a> TypeInferenceBuilder<'a> { + /// Creates a new builder for inferring the types of `scope`. + pub(super) fn new(db: &'a dyn Db, scope: ScopeId, index: &'a SemanticIndex) -> Self { + let file_scope_id = scope.file_scope_id(db); + let file = scope.file(db); + let children_scopes = index.child_scopes(file_scope_id); + let symbol_table = index.symbol_table(file_scope_id); + + Self { + index, + file_scope_id, + file_id: file, + scope, + symbol_table, + + db, + types: TypeInference::default(), + definition_tys: FxHashMap::default(), + children_scopes, + } + } + + /// Infers the types of a `module`. + pub(super) fn infer_module(&mut self, module: &ast::ModModule) { + self.infer_body(&module.body); + } + + pub(super) fn infer_class_type_params(&mut self, class: &ast::StmtClassDef) { + if let Some(type_params) = class.type_params.as_deref() { + self.infer_type_parameters(type_params); + } + } + + pub(super) fn infer_class_body(&mut self, class: &ast::StmtClassDef) { + self.infer_body(&class.body); + } + + pub(super) fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) { + if let Some(type_params) = function.type_params.as_deref() { + self.infer_type_parameters(type_params); + } + } + + pub(super) fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) { + self.infer_body(&function.body); + } + + fn infer_body(&mut self, suite: &[ast::Stmt]) { + for statement in suite { + self.infer_statement(statement); + } + } + + fn infer_statement(&mut self, statement: &ast::Stmt) { + match statement { + ast::Stmt::FunctionDef(function) => self.infer_function_definition_statement(function), + ast::Stmt::ClassDef(class) => self.infer_class_definition_statement(class), + ast::Stmt::Expr(ast::StmtExpr { range: _, value }) => { + self.infer_expression(value); + } + ast::Stmt::If(if_statement) => self.infer_if_statement(if_statement), + ast::Stmt::Assign(assign) => self.infer_assignment_statement(assign), + ast::Stmt::AnnAssign(assign) => self.infer_annotated_assignment_statement(assign), + ast::Stmt::For(for_statement) => self.infer_for_statement(for_statement), + ast::Stmt::Import(import) => self.infer_import_statement(import), + ast::Stmt::ImportFrom(import) => self.infer_import_from_statement(import), + ast::Stmt::Break(_) | ast::Stmt::Continue(_) | ast::Stmt::Pass(_) => { + // No-op + } + _ => {} + } + } + + fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) { + let ast::StmtFunctionDef { + range: _, + is_async: _, + name, + type_params: _, + parameters: _, + returns, + body: _, + decorator_list, + } = function; + + let function_id = function.scope_ast_id(self.db, self.file_id, self.file_scope_id); + let decorator_tys = decorator_list + .iter() + .map(|decorator| self.infer_decorator(decorator)) + .collect(); + + // TODO: Infer parameters + + if let Some(return_ty) = returns { + self.infer_expression(return_ty); + } + + let function_ty = self.function_ty(FunctionType { + name: Name::new(&name.id), + decorators: decorator_tys, + }); + + // Skip over the function or type params child scope. + let (_, scope) = self.children_scopes.next().unwrap(); + + assert!(matches!( + scope.kind(), + ScopeKind::Function | ScopeKind::Annotation + )); + + self.definition_tys + .insert(Definition::FunctionDef(function_id), function_ty); + } + + fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) { + let ast::StmtClassDef { + range: _, + name, + type_params, + decorator_list, + arguments, + body: _, + } = class; + + let class_id = class.scope_ast_id(self.db, self.file_id, self.file_scope_id); + + for decorator in decorator_list { + self.infer_decorator(decorator); + } + + let bases = arguments + .as_deref() + .map(|arguments| self.infer_arguments(arguments)) + .unwrap_or(Vec::new()); + + // If the class has type parameters, then the class body scope is the first child scope of the type parameter's scope + // Otherwise the next scope must be the class definition scope. + let (class_body_scope_id, class_body_scope) = if type_params.is_some() { + let (type_params_scope, _) = self.children_scopes.next().unwrap(); + self.index.child_scopes(type_params_scope).next().unwrap() + } else { + self.children_scopes.next().unwrap() + }; + + assert_eq!(class_body_scope.kind(), ScopeKind::Class); + + let class_ty = self.class_ty(ClassType { + name: Name::new(name), + bases, + body_scope: class_body_scope_id.to_scope_id(self.db, self.file_id), + }); + + self.definition_tys + .insert(Definition::ClassDef(class_id), class_ty); + } + + fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) { + let ast::StmtIf { + range: _, + test, + body, + elif_else_clauses, + } = if_statement; + + self.infer_expression(test); + self.infer_body(body); + + for clause in elif_else_clauses { + let ast::ElifElseClause { + range: _, + test, + body, + } = clause; + + if let Some(test) = &test { + self.infer_expression(test); + } + + self.infer_body(body); + } + } + + fn infer_assignment_statement(&mut self, assignment: &ast::StmtAssign) { + let ast::StmtAssign { + range: _, + targets, + value, + } = assignment; + + let value_ty = self.infer_expression(value); + + for target in targets { + self.infer_expression(target); + } + + let assign_id = assignment.scope_ast_id(self.db, self.file_id, self.file_scope_id); + + // TODO: Handle multiple targets. + self.definition_tys + .insert(Definition::Assignment(assign_id), value_ty); + } + + fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) { + let ast::StmtAnnAssign { + range: _, + target, + annotation, + value, + simple: _, + } = assignment; + + if let Some(value) = value { + let _ = self.infer_expression(value); + } + + let annotation_ty = self.infer_expression(annotation); + self.infer_expression(target); + + self.definition_tys.insert( + Definition::AnnotatedAssignment(assignment.scope_ast_id( + self.db, + self.file_id, + self.file_scope_id, + )), + annotation_ty, + ); + } + + fn infer_for_statement(&mut self, for_statement: &ast::StmtFor) { + let ast::StmtFor { + range: _, + target, + iter, + body, + orelse, + is_async: _, + } = for_statement; + + self.infer_expression(iter); + self.infer_expression(target); + self.infer_body(body); + self.infer_body(orelse); + } + + fn infer_import_statement(&mut self, import: &ast::StmtImport) { + let ast::StmtImport { range: _, names } = import; + + let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id); + + for (i, alias) in names.iter().enumerate() { + let ast::Alias { + range: _, + name, + asname: _, + } = alias; + + let module_name = ModuleName::new(&name.id); + let module = module_name.and_then(|name| resolve_module(self.db, name)); + let module_ty = module + .map(|module| self.typing_context().module_ty(module.file())) + .unwrap_or(Type::Unknown); + + self.definition_tys.insert( + Definition::Import(ImportDefinition { + import_id, + alias: u32::try_from(i).unwrap(), + }), + module_ty, + ); + } + } + + fn infer_import_from_statement(&mut self, import: &ast::StmtImportFrom) { + let ast::StmtImportFrom { + range: _, + module, + names, + level: _, + } = import; + + let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id); + let module_name = ModuleName::new(module.as_deref().expect("Support relative imports")); + + let module = module_name.and_then(|module_name| resolve_module(self.db, module_name)); + let module_ty = module + .map(|module| self.typing_context().module_ty(module.file())) + .unwrap_or(Type::Unknown); + + for (i, alias) in names.iter().enumerate() { + let ast::Alias { + range: _, + name, + asname: _, + } = alias; + + let ty = module_ty + .member(&self.typing_context(), &Name::new(&name.id)) + .unwrap_or(Type::Unknown); + + self.definition_tys.insert( + Definition::ImportFrom(ImportFromDefinition { + import_id, + name: u32::try_from(i).unwrap(), + }), + ty, + ); + } + } + + fn infer_decorator(&mut self, decorator: &ast::Decorator) -> Type { + let ast::Decorator { + range: _, + expression, + } = decorator; + + self.infer_expression(expression) + } + + fn infer_arguments(&mut self, arguments: &ast::Arguments) -> Vec { + let mut types = Vec::with_capacity( + arguments + .args + .len() + .saturating_add(arguments.keywords.len()), + ); + + types.extend(arguments.args.iter().map(|arg| self.infer_expression(arg))); + + types.extend(arguments.keywords.iter().map( + |ast::Keyword { + range: _, + arg: _, + value, + }| self.infer_expression(value), + )); + + types + } + + fn infer_expression(&mut self, expression: &ast::Expr) -> Type { + let ty = match expression { + ast::Expr::NoneLiteral(ast::ExprNoneLiteral { range: _ }) => Type::None, + ast::Expr::NumberLiteral(literal) => self.infer_number_literal_expression(literal), + ast::Expr::Name(name) => self.infer_name_expression(name), + ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute), + ast::Expr::BinOp(binary) => self.infer_binary_expression(binary), + ast::Expr::Named(named) => self.infer_named_expression(named), + ast::Expr::If(if_expression) => self.infer_if_expression(if_expression), + + _ => todo!("expression type resolution for {:?}", expression), + }; + + self.types.expression_tys.push(ty); + + ty + } + + #[allow(clippy::unused_self)] + fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type { + let ast::ExprNumberLiteral { range: _, value } = literal; + + match value { + ast::Number::Int(n) => { + // TODO support big int literals + n.as_i64().map(Type::IntLiteral).unwrap_or(Type::Unknown) + } + // TODO builtins.float or builtins.complex + _ => Type::Unknown, + } + } + + fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type { + let ast::ExprNamed { + range: _, + target, + value, + } = named; + + let value_ty = self.infer_expression(value); + self.infer_expression(target); + + self.definition_tys.insert( + Definition::NamedExpr(named.scope_ast_id(self.db, self.file_id, self.file_scope_id)), + value_ty, + ); + + value_ty + } + + fn infer_if_expression(&mut self, if_expression: &ast::ExprIf) -> Type { + let ast::ExprIf { + range: _, + test, + body, + orelse, + } = if_expression; + + self.infer_expression(test); + + // TODO detect statically known truthy or falsy test + let body_ty = self.infer_expression(body); + let orelse_ty = self.infer_expression(orelse); + + let union = UnionTypeBuilder::new(&self.typing_context()) + .add(body_ty) + .add(orelse_ty) + .build(); + + self.union_ty(union) + } + + fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type { + let ast::ExprName { range: _, id, ctx } = name; + + match ctx { + ExprContext::Load => { + if let Some(symbol_id) = self + .index + .symbol_table(self.file_scope_id) + .symbol_id_by_name(id) + { + self.local_definition_ty(symbol_id) + } else { + let ancestors = self.index.ancestor_scopes(self.file_scope_id).skip(1); + + for (ancestor_id, _) in ancestors { + // TODO: Skip over class scopes unless the they are a immediately-nested type param scope. + // TODO: Support built-ins + + let symbol_table = + symbol_table(self.db, ancestor_id.to_scope_id(self.db, self.file_id)); + + if let Some(_symbol_id) = symbol_table.symbol_id_by_name(id) { + todo!("Return type for symbol from outer scope"); + } + } + Type::Unknown + } + } + ExprContext::Del => Type::None, + ExprContext::Invalid => Type::Unknown, + ExprContext::Store => Type::None, + } + } + + fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type { + let ast::ExprAttribute { + value, + attr, + range: _, + ctx, + } = attribute; + + let value_ty = self.infer_expression(value); + let member_ty = value_ty + .member(&self.typing_context(), &Name::new(&attr.id)) + .unwrap_or(Type::Unknown); + + match ctx { + ExprContext::Load => member_ty, + ExprContext::Store | ExprContext::Del => Type::None, + ExprContext::Invalid => Type::Unknown, + } + } + + fn infer_binary_expression(&mut self, binary: &ast::ExprBinOp) -> Type { + let ast::ExprBinOp { + left, + op, + right, + range: _, + } = binary; + + let left_ty = self.infer_expression(left); + let right_ty = self.infer_expression(right); + + match left_ty { + Type::Any => Type::Any, + Type::Unknown => Type::Unknown, + Type::IntLiteral(n) => { + match right_ty { + Type::IntLiteral(m) => { + match op { + ast::Operator::Add => n + .checked_add(m) + .map(Type::IntLiteral) + // TODO builtins.int + .unwrap_or(Type::Unknown), + ast::Operator::Sub => n + .checked_sub(m) + .map(Type::IntLiteral) + // TODO builtins.int + .unwrap_or(Type::Unknown), + ast::Operator::Mult => n + .checked_mul(m) + .map(Type::IntLiteral) + // TODO builtins.int + .unwrap_or(Type::Unknown), + ast::Operator::Div => n + .checked_div(m) + .map(Type::IntLiteral) + // TODO builtins.int + .unwrap_or(Type::Unknown), + ast::Operator::Mod => n + .checked_rem(m) + .map(Type::IntLiteral) + // TODO division by zero error + .unwrap_or(Type::Unknown), + _ => todo!("complete binop op support for IntLiteral"), + } + } + _ => todo!("complete binop right_ty support for IntLiteral"), + } + } + _ => todo!("complete binop support"), + } + } + + fn infer_type_parameters(&mut self, _type_parameters: &TypeParams) { + todo!("Infer type parameters") + } + + pub(super) fn finish(mut self) -> TypeInference { + let symbol_tys: IndexVec<_, _> = self + .index + .symbol_table(self.file_scope_id) + .symbol_ids() + .map(|symbol| self.local_definition_ty(symbol)) + .collect(); + + self.types.symbol_tys = symbol_tys; + self.types.shrink_to_fit(); + self.types + } + + fn union_ty(&mut self, ty: UnionType) -> Type { + Type::Union(TypeId { + scope: self.scope, + scoped: self.types.union_types.push(ty), + }) + } + + fn function_ty(&mut self, ty: FunctionType) -> Type { + Type::Function(TypeId { + scope: self.scope, + scoped: self.types.function_types.push(ty), + }) + } + + fn class_ty(&mut self, ty: ClassType) -> Type { + Type::Class(TypeId { + scope: self.scope, + scoped: self.types.class_types.push(ty), + }) + } + + fn typing_context(&self) -> TypingContext { + TypingContext::scoped(self.db, self.scope, &self.types) + } + + fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type { + let symbol = self.symbol_table.symbol(symbol); + let mut definitions = symbol + .definitions() + .iter() + .filter_map(|definition| self.definition_tys.get(definition).copied()); + + let Some(first) = definitions.next() else { + return Type::Unbound; + }; + + if let Some(second) = definitions.next() { + let context = self.typing_context(); + let mut builder = UnionTypeBuilder::new(&context); + builder = builder.add(first).add(second); + + for variant in definitions { + builder = builder.add(variant); + } + + self.union_ty(builder.build()) + } else { + first + } + } +} + +#[cfg(test)] +mod tests { + use ruff_db::file_system::FileSystemPathBuf; + use ruff_db::vfs::system_path_to_file; + + use crate::db::tests::TestDb; + use crate::module::resolver::{set_module_resolution_settings, ModuleResolutionSettings}; + use crate::name::Name; + use crate::red_knot::types::{public_symbol_ty_by_name, Type, TypingContext}; + + fn setup_db() -> TestDb { + let mut db = TestDb::new(); + + set_module_resolution_settings( + &mut db, + ModuleResolutionSettings { + extra_paths: Vec::new(), + workspace_root: FileSystemPathBuf::from("/src"), + site_packages: None, + custom_typeshed: None, + }, + ); + + db + } + + fn assert_public_ty(db: &TestDb, file_name: &str, symbol_name: &str, expected: &str) { + let file = system_path_to_file(db, file_name).expect("Expected file to exist."); + + let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown); + assert_eq!(ty.display(&TypingContext::global(db)).to_string(), expected); + } + + #[test] + fn follow_import_to_class() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system().write_files([ + ("src/a.py", "from b import C as D; E = D"), + ("src/b.py", "class C: pass"), + ])?; + + assert_public_ty(&db, "src/a.py", "E", "Literal[C]"); + + Ok(()) + } + + #[test] + fn resolve_base_class_by_name() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system().write_file( + "src/mod.py", + r#" +class Base: + pass + +class Sub(Base): + pass"#, + )?; + + let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist."); + let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist"); + + let Type::Class(class_id) = ty else { + panic!("Sub is not a Class") + }; + + let context = TypingContext::global(&db); + + let base_names: Vec<_> = class_id + .lookup(&context) + .bases() + .iter() + .map(|base_ty| format!("{}", base_ty.display(&context))) + .collect(); + + assert_eq!(base_names, vec!["Literal[Base]"]); + + Ok(()) + } + + #[test] + fn resolve_method() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system().write_file( + "src/mod.py", + " +class C: + def f(self): pass + ", + )?; + + let mod_file = system_path_to_file(&db, "src/mod.py").unwrap(); + let ty = public_symbol_ty_by_name(&db, mod_file, "C").unwrap(); + + let Type::Class(class_id) = ty else { + panic!("C is not a Class"); + }; + + let context = TypingContext::global(&db); + let member_ty = class_id.class_member(&context, &Name::new("f")); + + let Some(Type::Function(func_id)) = member_ty else { + panic!("C.f is not a Function"); + }; + + let function_ty = func_id.lookup(&context); + assert_eq!(function_ty.name(), "f"); + + Ok(()) + } + + #[test] + fn resolve_module_member() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system().write_files([ + ("src/a.py", "import b; D = b.C"), + ("src/b.py", "class C: pass"), + ])?; + + assert_public_ty(&db, "src/a.py", "D", "Literal[C]"); + + Ok(()) + } + + #[test] + fn resolve_literal() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system().write_file("src/a.py", "x = 1")?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[1]"); + + Ok(()) + } + + #[test] + fn resolve_union() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system().write_file( + "src/a.py", + " +if flag: + x = 1 +else: + x = 2 + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]"); + + Ok(()) + } + + #[test] + fn literal_int_arithmetic() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system().write_file( + "src/a.py", + " +a = 2 + 1 +b = a - 4 +c = a * b +d = c / 3 +e = 5 % 3 + ", + )?; + + assert_public_ty(&db, "src/a.py", "a", "Literal[3]"); + assert_public_ty(&db, "src/a.py", "b", "Literal[-1]"); + assert_public_ty(&db, "src/a.py", "c", "Literal[-3]"); + assert_public_ty(&db, "src/a.py", "d", "Literal[-1]"); + assert_public_ty(&db, "src/a.py", "e", "Literal[2]"); + + Ok(()) + } + + #[test] + fn walrus() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system() + .write_file("src/a.py", "x = (y := 1) + 1")?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[2]"); + assert_public_ty(&db, "src/a.py", "y", "Literal[1]"); + + Ok(()) + } + + #[test] + fn ifexpr() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system() + .write_file("src/a.py", "x = 1 if flag else 2")?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]"); + + Ok(()) + } + + #[test] + fn ifexpr_walrus() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system().write_file( + "src/a.py", + " +y = z = 0 +x = (y := 1) if flag else (z := 2) +a = y +b = z + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2]"); + assert_public_ty(&db, "src/a.py", "a", "Literal[0, 1]"); + assert_public_ty(&db, "src/a.py", "b", "Literal[0, 2]"); + + Ok(()) + } + + #[test] + fn ifexpr_nested() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system() + .write_file("src/a.py", "x = 1 if flag else 2 if flag2 else 3")?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[1, 2, 3]"); + + Ok(()) + } + + #[test] + fn none() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system() + .write_file("src/a.py", "x = 1 if flag else None")?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[1] | None"); + Ok(()) + } +}