From 66cc3674761510760ce0e97066f4b4b3d6e28ed8 Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Wed, 3 Apr 2024 16:37:35 +0200 Subject: [PATCH] feat: add Param type to Function --- crates/mun_compiler/src/diagnostics.rs | 26 +++--- crates/mun_hir/src/code_model/function.rs | 56 +++++++++++- crates/mun_hir/src/item_tree.rs | 27 ++++-- crates/mun_hir/src/item_tree/lower.rs | 23 +++-- crates/mun_hir/src/item_tree/pretty.rs | 10 ++- crates/mun_hir/src/source_id.rs | 102 +++++++++++++++++----- crates/mun_language_server/Cargo.toml | 2 +- 7 files changed, 193 insertions(+), 53 deletions(-) diff --git a/crates/mun_compiler/src/diagnostics.rs b/crates/mun_compiler/src/diagnostics.rs index 4867b819..3eb5213c 100644 --- a/crates/mun_compiler/src/diagnostics.rs +++ b/crates/mun_compiler/src/diagnostics.rs @@ -29,82 +29,82 @@ mod tests { #[test] fn test_syntax_error() { - insta::assert_display_snapshot!(compilation_errors("\n\nfn main(\n struct Foo\n")); + insta::assert_snapshot!(compilation_errors("\n\nfn main(\n struct Foo\n")); } #[test] fn test_unresolved_value_error() { - insta::assert_display_snapshot!(compilation_errors( + insta::assert_snapshot!(compilation_errors( "\n\nfn main() {\nlet b = a;\n\nlet d = c;\n}" )); } #[test] fn test_unresolved_type_error() { - insta::assert_display_snapshot!(compilation_errors( + insta::assert_snapshot!(compilation_errors( "\n\nfn main() {\nlet a = Foo{};\n\nlet b = Bar{};\n}" )); } #[test] fn test_leaked_private_type_error_function() { - insta::assert_display_snapshot!(compilation_errors( + insta::assert_snapshot!(compilation_errors( "\n\nstruct Foo;\n pub fn Bar() -> Foo { Foo } \n fn main() {}" )); } #[test] fn test_expected_function_error() { - insta::assert_display_snapshot!(compilation_errors( + insta::assert_snapshot!(compilation_errors( "\n\nfn main() {\nlet a = Foo();\n\nlet b = Bar();\n}" )); } #[test] fn test_mismatched_type_error() { - insta::assert_display_snapshot!(compilation_errors( + insta::assert_snapshot!(compilation_errors( "\n\nfn main() {\nlet a: f64 = false;\n\nlet b: bool = 22;\n}" )); } #[test] fn test_duplicate_definition_error() { - insta::assert_display_snapshot!(compilation_errors( + insta::assert_snapshot!(compilation_errors( "\n\nfn foo(){}\n\nfn foo(){}\n\nstruct Bar;\n\nstruct Bar;\n\nfn BAZ(){}\n\nstruct BAZ;" )); } #[test] fn test_possibly_uninitialized_variable_error() { - insta::assert_display_snapshot!(compilation_errors( + insta::assert_snapshot!(compilation_errors( "\n\nfn main() {\nlet a;\nif 5>6 {\na = 5\n}\nlet b = a;\n}" )); } #[test] fn test_access_unknown_field_error() { - insta::assert_display_snapshot!(compilation_errors( + insta::assert_snapshot!(compilation_errors( "\n\nstruct Foo {\ni: bool\n}\n\nfn main() {\nlet a = Foo { i: false };\nlet b = a.t;\n}" )); } #[test] fn test_free_type_alias_error() { - insta::assert_display_snapshot!(compilation_errors("\n\ntype Foo;")); + insta::assert_snapshot!(compilation_errors("\n\ntype Foo;")); } #[test] fn test_type_alias_target_undeclared_error() { - insta::assert_display_snapshot!(compilation_errors("\n\ntype Foo = UnknownType;")); + insta::assert_snapshot!(compilation_errors("\n\ntype Foo = UnknownType;")); } #[test] fn test_cyclic_type_alias_error() { - insta::assert_display_snapshot!(compilation_errors("\n\ntype Foo = Foo;")); + insta::assert_snapshot!(compilation_errors("\n\ntype Foo = Foo;")); } #[test] fn test_expected_function() { - insta::assert_display_snapshot!(compilation_errors("\n\nfn foo() { let a = 3; a(); }")); + insta::assert_snapshot!(compilation_errors("\n\nfn foo() { let a = 3; a(); }")); } } diff --git a/crates/mun_hir/src/code_model/function.rs b/crates/mun_hir/src/code_model/function.rs index 6a1a0b8b..44e83091 100644 --- a/crates/mun_hir/src/code_model/function.rs +++ b/crates/mun_hir/src/code_model/function.rs @@ -1,6 +1,6 @@ use std::{iter::once, sync::Arc}; -use mun_syntax::ast::TypeAscriptionOwner; +use mun_syntax::{ast, ast::TypeAscriptionOwner}; use super::Module; use crate::{ @@ -11,8 +11,8 @@ use crate::{ resolve::HasResolver, type_ref::{LocalTypeRefId, TypeRefMap, TypeRefSourceMap}, visibility::RawVisibility, - Body, DefDatabase, DiagnosticSink, FileId, HasVisibility, HirDatabase, InferenceResult, Name, - Ty, Visibility, + Body, DefDatabase, DiagnosticSink, FileId, HasSource, HasVisibility, HirDatabase, InFile, + InferenceResult, Name, Ty, Visibility, }; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] @@ -138,6 +138,20 @@ impl Function { db.type_for_def(self.into(), Namespace::Values) } + /// Returns the parameters of the function. + pub fn params(self, db: &dyn HirDatabase) -> Vec { + db.callable_sig(self.into()) + .params() + .iter() + .enumerate() + .map(|(idx, ty)| Param { + func: self, + ty: ty.clone(), + idx, + }) + .collect() + } + pub fn ret_type(self, db: &dyn HirDatabase) -> Ty { let resolver = self.id.resolver(db.upcast()); let data = self.data(db.upcast()); @@ -166,6 +180,42 @@ impl Function { } } +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct Param { + func: Function, + /// The index in parameter list, including self parameter. + idx: usize, + ty: Ty, +} + +impl Param { + /// Returns the function to which this parameter belongs + pub fn parent_fn(&self) -> Function { + self.func + } + + /// Returns the index of this parameter in the parameter list (including + /// self) + pub fn index(&self) -> usize { + self.idx + } + + /// Returns the type of this parameter. + pub fn ty(&self) -> &Ty { + &self.ty + } + + /// Returns the source of the parameter. + pub fn source(&self, db: &dyn HirDatabase) -> Option> { + let InFile { file_id, value } = self.func.source(db.upcast()); + let params = value.param_list()?; + params + .params() + .nth(self.idx) + .map(|value| InFile { file_id, value }) + } +} + impl HasVisibility for Function { fn visibility(&self, db: &dyn HirDatabase) -> Visibility { self.data(db.upcast()) diff --git a/crates/mun_hir/src/item_tree.rs b/crates/mun_hir/src/item_tree.rs index 378ea690..58c5be3f 100644 --- a/crates/mun_hir/src/item_tree.rs +++ b/crates/mun_hir/src/item_tree.rs @@ -13,12 +13,12 @@ use std::{ sync::Arc, }; -use mun_syntax::{ast, AstNode}; +use mun_syntax::ast; use crate::{ arena::{Arena, Idx}, path::ImportAlias, - source_id::FileAstId, + source_id::{AstIdNode, FileAstId}, type_ref::{LocalTypeRefId, TypeRefMap}, visibility::RawVisibility, DefDatabase, FileId, InFile, Name, Path, @@ -112,6 +112,7 @@ impl ItemVisibilities { struct ItemTreeData { imports: Arena, functions: Arena, + params: Arena, structs: Arena, fields: Arena, type_aliases: Arena, @@ -122,7 +123,7 @@ struct ItemTreeData { /// Trait implemented by all item nodes in the item tree. pub trait ItemTreeNode: Clone { - type Source: AstNode + Into; + type Source: AstIdNode + Into; /// Returns the AST id for this instance fn ast_id(&self) -> FileAstId; @@ -244,7 +245,7 @@ macro_rules! impl_index { }; } -impl_index!(fields: Field); +impl_index!(fields: Field, params: Param); static VIS_PUB: RawVisibility = RawVisibility::Public; static VIS_PRIV: RawVisibility = RawVisibility::This; @@ -302,11 +303,22 @@ pub struct Function { pub visibility: RawVisibilityId, pub is_extern: bool, pub types: TypeRefMap, - pub params: Box<[LocalTypeRefId]>, + pub params: IdRange, pub ret_type: LocalTypeRefId, pub ast_id: FileAstId, } +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Param { + pub type_ref: LocalTypeRefId, + pub ast_id: ParamAstId, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum ParamAstId { + Param(FileAstId), +} + #[derive(Debug, Clone, Eq, PartialEq)] pub struct Struct { pub name: Name, @@ -390,6 +402,11 @@ impl IdRange { _p: PhantomData, } } + + /// Returns true if the index range is empty + pub fn is_empty(&self) -> bool { + self.range.is_empty() + } } impl Iterator for IdRange { diff --git a/crates/mun_hir/src/item_tree/lower.rs b/crates/mun_hir/src/item_tree/lower.rs index 736a42d8..5ea75f6f 100644 --- a/crates/mun_hir/src/item_tree/lower.rs +++ b/crates/mun_hir/src/item_tree/lower.rs @@ -9,7 +9,8 @@ use smallvec::SmallVec; use super::{ diagnostics, AssociatedItem, Field, Fields, Function, IdRange, Impl, ItemTree, ItemTreeData, - ItemTreeNode, ItemVisibilities, LocalItemTreeId, ModItem, RawVisibilityId, Struct, TypeAlias, + ItemTreeNode, ItemVisibilities, LocalItemTreeId, ModItem, Param, ParamAstId, RawVisibilityId, + Struct, TypeAlias, }; use crate::{ arena::{Idx, RawId}, @@ -156,13 +157,19 @@ impl Context { let mut types = TypeRefMap::builder(); // Lower all the params - let mut params = Vec::new(); + let start_param_idx = self.next_param_idx(); if let Some(param_list) = func.param_list() { for param in param_list.params() { + let ast_id = self.source_ast_id_map.ast_id(¶m); let type_ref = types.alloc_from_node_opt(param.ascribed_type().as_ref()); - params.push(type_ref); + self.data.params.alloc(Param { + type_ref, + ast_id: ParamAstId::Param(ast_id), + }); } } + let end_param_idx = self.next_param_idx(); + let params = IdRange::new(start_param_idx..end_param_idx); // Lowers the return type let ret_type = match func.ret_type().and_then(|rt| rt.type_ref()) { @@ -177,9 +184,9 @@ impl Context { let res = Function { name, visibility, - types, is_extern, - params: params.into_boxed_slice(), + types, + params, ret_type, ast_id, }; @@ -313,6 +320,12 @@ impl Context { let idx: u32 = self.data.fields.len().try_into().expect("too many fields"); Idx::from_raw(RawId::from(idx)) } + + /// Returns the `Idx` of the next `Param` + fn next_param_idx(&self) -> Idx { + let idx: u32 = self.data.params.len().try_into().expect("too many params"); + Idx::from_raw(RawId::from(idx)) + } } /// Lowers a record field (e.g. `a:i32`) diff --git a/crates/mun_hir/src/item_tree/pretty.rs b/crates/mun_hir/src/item_tree/pretty.rs index 8b7988d8..08b44221 100644 --- a/crates/mun_hir/src/item_tree/pretty.rs +++ b/crates/mun_hir/src/item_tree/pretty.rs @@ -2,7 +2,7 @@ use std::{fmt, fmt::Write}; use crate::{ item_tree::{ - Fields, Function, Impl, Import, ItemTree, LocalItemTreeId, ModItem, RawVisibilityId, + Fields, Function, Impl, Import, ItemTree, LocalItemTreeId, ModItem, Param, RawVisibilityId, Struct, TypeAlias, }, path::ImportAlias, @@ -181,8 +181,12 @@ impl Printer<'_> { write!(self, "(")?; if !params.is_empty() { self.indented(|this| { - for param in params.iter().copied() { - this.print_type_ref(param, types)?; + for param in params.clone() { + let Param { + type_ref, + ast_id: _, + } = &this.tree[param]; + this.print_type_ref(*type_ref, types)?; writeln!(this, ",")?; } Ok(()) diff --git a/crates/mun_hir/src/source_id.rs b/crates/mun_hir/src/source_id.rs index 6fab8ee1..8586b650 100644 --- a/crates/mun_hir/src/source_id.rs +++ b/crates/mun_hir/src/source_id.rs @@ -4,7 +4,7 @@ use std::{ sync::Arc, }; -use mun_syntax::{ast, AstNode, AstPtr, SyntaxNode, SyntaxNodePtr}; +use mun_syntax::{ast, AstNode, AstPtr, SyntaxNode, SyntaxNodePtr, WalkEvent}; use crate::{ arena::{Arena, Idx}, @@ -13,12 +13,14 @@ use crate::{ FileId, }; +type ErasedFileAstId = Idx; + /// `AstId` points to an AST node in any file. /// /// It is stable across reparses, and can be used as salsa key/value. pub(crate) type AstId = InFile>; -impl AstId { +impl AstId { pub fn to_node(self, db: &dyn AstDatabase) -> N { let root = db.parse(self.file_id); db.ast_id_map(self.file_id) @@ -28,29 +30,51 @@ impl AstId { } #[derive(Clone, Debug)] -pub struct FileAstId { +pub struct FileAstId { raw: ErasedFileAstId, _ty: PhantomData N>, } -impl Copy for FileAstId {} +impl Copy for FileAstId {} -impl PartialEq for FileAstId { +impl PartialEq for FileAstId { fn eq(&self, other: &Self) -> bool { self.raw == other.raw } } -impl Eq for FileAstId {} -impl Hash for FileAstId { +impl Eq for FileAstId {} +impl Hash for FileAstId { fn hash(&self, hasher: &mut H) { self.raw.hash(hasher); } } -impl FileAstId { - pub(crate) fn with_file_id(self, file_id: FileId) -> AstId { - AstId::new(file_id, self) - } +/// A trait that is implemented for all nodes that can be represented as a +/// `FileAstId`. +pub trait AstIdNode: AstNode {} + +macro_rules! register_ast_id_node { + (impl AstIdNode for $($ident:ident),+ ) => { + $( + impl AstIdNode for ast::$ident {} + )+ + fn should_alloc_id(kind: mun_syntax::SyntaxKind) -> bool { + $( + ast::$ident::can_cast(kind) + )||+ + } + }; +} + +register_ast_id_node! { + impl AstIdNode for + ModuleItem, + Use, + FunctionDef, + StructDef, + Impl, + TypeAliasDef, + Param } /// Maps items' `SyntaxNode`s to `ErasedFileAstId`s and back. @@ -59,15 +83,13 @@ pub struct AstIdMap { arena: Arena, } -type ErasedFileAstId = Idx; - impl AstIdMap { pub(crate) fn ast_id_map_query(db: &dyn AstDatabase, file_id: FileId) -> Arc { let map = AstIdMap::from_source(db.parse(file_id).tree().syntax()); Arc::new(map) } - pub(crate) fn ast_id(&self, item: &N) -> FileAstId { + pub(crate) fn ast_id(&self, item: &N) -> FileAstId { let ptr = SyntaxNodePtr::new(item.syntax()); let raw = match self.arena.iter().find(|(_id, i)| **i == ptr) { Some((it, _)) => it, @@ -88,22 +110,31 @@ impl AstIdMap { /// `node` must be the root of a syntax tree. fn from_source(node: &SyntaxNode) -> AstIdMap { assert!(node.parent().is_none()); - let mut res = AstIdMap::default(); + + // Make sure the root node is allocated + if !should_alloc_id(node.kind()) { + res.alloc(node); + } + // By walking the tree in breadth-first order we make sure that parents // get lower ids then children. That is, adding a new child does not // change parent's id. This means that, say, adding a new function to a // trait does not change ids of top-level items, which helps caching. - bfs(node, |it| { - if let Some(module_item) = ast::ModuleItem::cast(it) { - res.alloc(module_item.syntax()); + bdfs(node, |it| { + if should_alloc_id(it.kind()) { + res.alloc(&it); + TreeOrder::BreadthFirst + } else { + TreeOrder::DepthFirst } }); + res } /// Returns the `AstPtr` of the given id. - pub(crate) fn get(&self, id: FileAstId) -> AstPtr { + pub(crate) fn get(&self, id: FileAstId) -> AstPtr { self.arena[id.raw].clone().try_cast::().unwrap() } @@ -113,14 +144,39 @@ impl AstIdMap { } } -/// Walks the subtree in bfs order, calling `f` for each node. -fn bfs(node: &SyntaxNode, mut f: impl FnMut(SyntaxNode)) { +#[derive(Copy, Clone, PartialEq, Eq)] +enum TreeOrder { + BreadthFirst, + DepthFirst, +} + +/// Walks the subtree in bdfs order, calling `f` for each node. +/// +/// ### What is bdfs order? +/// +/// It is a mix of breadth-first and depth first orders. Nodes for which `f` +/// returns [`TreeOrder::BreadthFirst`] are visited breadth-first, all the other +/// nodes are explored [`TreeOrder::DepthFirst`]. +/// +/// In other words, the size of the bfs queue is bound by the number of "true" +/// nodes. +fn bdfs(node: &SyntaxNode, mut f: impl FnMut(SyntaxNode) -> TreeOrder) { let mut curr_layer = vec![node.clone()]; let mut next_layer = vec![]; while !curr_layer.is_empty() { curr_layer.drain(..).for_each(|node| { - next_layer.extend(node.children()); - f(node); + let mut preorder = node.preorder(); + while let Some(event) = preorder.next() { + match event { + WalkEvent::Enter(node) => { + if f(node.clone()) == TreeOrder::BreadthFirst { + next_layer.extend(node.children()); + preorder.skip_subtree(); + } + } + WalkEvent::Leave(_) => {} + } + } }); std::mem::swap(&mut curr_layer, &mut next_layer); } diff --git a/crates/mun_language_server/Cargo.toml b/crates/mun_language_server/Cargo.toml index d206d247..4afc1f37 100644 --- a/crates/mun_language_server/Cargo.toml +++ b/crates/mun_language_server/Cargo.toml @@ -23,7 +23,7 @@ mun_paths = { version = "0.6.0-dev", path="../mun_paths" } anyhow = { version = "1.0", default-features = false, features=["std"] } crossbeam-channel = { version = "0.5.9", default-features = false } log = { version = "0.4", default-features = false } -lsp-types = { version = "0.95.0", default-features = false } +lsp-types = { version = "=0.95", default-features = false } lsp-server = { version = "0.7.5", default-features = false } parking_lot = { version = "0.12.1", default-features = false } ra_ap_text_edit = { version = "0.0.190", default-features = false }