Skip to content

Commit

Permalink
feat: add Param type to Function
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra committed Apr 3, 2024
1 parent b68a952 commit 66cc367
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 53 deletions.
26 changes: 13 additions & 13 deletions crates/mun_compiler/src/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,82 +29,82 @@ mod tests {

#[test]
fn test_syntax_error() {
insta::assert_display_snapshot!(compilation_errors("\n\nfn main(\n struct Foo\n"));
insta::assert_snapshot!(compilation_errors("\n\nfn main(\n struct Foo\n"));
}

#[test]
fn test_unresolved_value_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet b = a;\n\nlet d = c;\n}"
));
}

#[test]
fn test_unresolved_type_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet a = Foo{};\n\nlet b = Bar{};\n}"
));
}

#[test]
fn test_leaked_private_type_error_function() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nstruct Foo;\n pub fn Bar() -> Foo { Foo } \n fn main() {}"
));
}

#[test]
fn test_expected_function_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet a = Foo();\n\nlet b = Bar();\n}"
));
}

#[test]
fn test_mismatched_type_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet a: f64 = false;\n\nlet b: bool = 22;\n}"
));
}

#[test]
fn test_duplicate_definition_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn foo(){}\n\nfn foo(){}\n\nstruct Bar;\n\nstruct Bar;\n\nfn BAZ(){}\n\nstruct BAZ;"
));
}

#[test]
fn test_possibly_uninitialized_variable_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nfn main() {\nlet a;\nif 5>6 {\na = 5\n}\nlet b = a;\n}"
));
}

#[test]
fn test_access_unknown_field_error() {
insta::assert_display_snapshot!(compilation_errors(
insta::assert_snapshot!(compilation_errors(
"\n\nstruct Foo {\ni: bool\n}\n\nfn main() {\nlet a = Foo { i: false };\nlet b = a.t;\n}"
));
}

#[test]
fn test_free_type_alias_error() {
insta::assert_display_snapshot!(compilation_errors("\n\ntype Foo;"));
insta::assert_snapshot!(compilation_errors("\n\ntype Foo;"));
}

#[test]
fn test_type_alias_target_undeclared_error() {
insta::assert_display_snapshot!(compilation_errors("\n\ntype Foo = UnknownType;"));
insta::assert_snapshot!(compilation_errors("\n\ntype Foo = UnknownType;"));
}

#[test]
fn test_cyclic_type_alias_error() {
insta::assert_display_snapshot!(compilation_errors("\n\ntype Foo = Foo;"));
insta::assert_snapshot!(compilation_errors("\n\ntype Foo = Foo;"));
}

#[test]
fn test_expected_function() {
insta::assert_display_snapshot!(compilation_errors("\n\nfn foo() { let a = 3; a(); }"));
insta::assert_snapshot!(compilation_errors("\n\nfn foo() { let a = 3; a(); }"));
}
}
56 changes: 53 additions & 3 deletions crates/mun_hir/src/code_model/function.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{iter::once, sync::Arc};

use mun_syntax::ast::TypeAscriptionOwner;
use mun_syntax::{ast, ast::TypeAscriptionOwner};

use super::Module;
use crate::{
Expand All @@ -11,8 +11,8 @@ use crate::{
resolve::HasResolver,
type_ref::{LocalTypeRefId, TypeRefMap, TypeRefSourceMap},
visibility::RawVisibility,
Body, DefDatabase, DiagnosticSink, FileId, HasVisibility, HirDatabase, InferenceResult, Name,
Ty, Visibility,
Body, DefDatabase, DiagnosticSink, FileId, HasSource, HasVisibility, HirDatabase, InFile,
InferenceResult, Name, Ty, Visibility,
};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
Expand Down Expand Up @@ -138,6 +138,20 @@ impl Function {
db.type_for_def(self.into(), Namespace::Values)
}

/// Returns the parameters of the function.
pub fn params(self, db: &dyn HirDatabase) -> Vec<Param> {
db.callable_sig(self.into())
.params()
.iter()
.enumerate()
.map(|(idx, ty)| Param {
func: self,
ty: ty.clone(),
idx,
})
.collect()
}

pub fn ret_type(self, db: &dyn HirDatabase) -> Ty {
let resolver = self.id.resolver(db.upcast());
let data = self.data(db.upcast());
Expand Down Expand Up @@ -166,6 +180,42 @@ impl Function {
}
}

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Param {
func: Function,
/// The index in parameter list, including self parameter.
idx: usize,
ty: Ty,
}

impl Param {
/// Returns the function to which this parameter belongs
pub fn parent_fn(&self) -> Function {
self.func
}

/// Returns the index of this parameter in the parameter list (including
/// self)
pub fn index(&self) -> usize {
self.idx
}

/// Returns the type of this parameter.
pub fn ty(&self) -> &Ty {
&self.ty
}

/// Returns the source of the parameter.
pub fn source(&self, db: &dyn HirDatabase) -> Option<InFile<ast::Param>> {
let InFile { file_id, value } = self.func.source(db.upcast());
let params = value.param_list()?;
params
.params()
.nth(self.idx)
.map(|value| InFile { file_id, value })
}
}

impl HasVisibility for Function {
fn visibility(&self, db: &dyn HirDatabase) -> Visibility {
self.data(db.upcast())
Expand Down
27 changes: 22 additions & 5 deletions crates/mun_hir/src/item_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ use std::{
sync::Arc,
};

use mun_syntax::{ast, AstNode};
use mun_syntax::ast;

use crate::{
arena::{Arena, Idx},
path::ImportAlias,
source_id::FileAstId,
source_id::{AstIdNode, FileAstId},
type_ref::{LocalTypeRefId, TypeRefMap},
visibility::RawVisibility,
DefDatabase, FileId, InFile, Name, Path,
Expand Down Expand Up @@ -112,6 +112,7 @@ impl ItemVisibilities {
struct ItemTreeData {
imports: Arena<Import>,
functions: Arena<Function>,
params: Arena<Param>,
structs: Arena<Struct>,
fields: Arena<Field>,
type_aliases: Arena<TypeAlias>,
Expand All @@ -122,7 +123,7 @@ struct ItemTreeData {

/// Trait implemented by all item nodes in the item tree.
pub trait ItemTreeNode: Clone {
type Source: AstNode + Into<ast::ModuleItem>;
type Source: AstIdNode + Into<ast::ModuleItem>;

/// Returns the AST id for this instance
fn ast_id(&self) -> FileAstId<Self::Source>;
Expand Down Expand Up @@ -244,7 +245,7 @@ macro_rules! impl_index {
};
}

impl_index!(fields: Field);
impl_index!(fields: Field, params: Param);

static VIS_PUB: RawVisibility = RawVisibility::Public;
static VIS_PRIV: RawVisibility = RawVisibility::This;
Expand Down Expand Up @@ -302,11 +303,22 @@ pub struct Function {
pub visibility: RawVisibilityId,
pub is_extern: bool,
pub types: TypeRefMap,
pub params: Box<[LocalTypeRefId]>,
pub params: IdRange<Param>,
pub ret_type: LocalTypeRefId,
pub ast_id: FileAstId<ast::FunctionDef>,
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Param {
pub type_ref: LocalTypeRefId,
pub ast_id: ParamAstId,
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ParamAstId {
Param(FileAstId<ast::Param>),
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Struct {
pub name: Name,
Expand Down Expand Up @@ -390,6 +402,11 @@ impl<T> IdRange<T> {
_p: PhantomData,
}
}

/// Returns true if the index range is empty
pub fn is_empty(&self) -> bool {
self.range.is_empty()
}
}

impl<T> Iterator for IdRange<T> {
Expand Down
23 changes: 18 additions & 5 deletions crates/mun_hir/src/item_tree/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use smallvec::SmallVec;

use super::{
diagnostics, AssociatedItem, Field, Fields, Function, IdRange, Impl, ItemTree, ItemTreeData,
ItemTreeNode, ItemVisibilities, LocalItemTreeId, ModItem, RawVisibilityId, Struct, TypeAlias,
ItemTreeNode, ItemVisibilities, LocalItemTreeId, ModItem, Param, ParamAstId, RawVisibilityId,
Struct, TypeAlias,
};
use crate::{
arena::{Idx, RawId},
Expand Down Expand Up @@ -156,13 +157,19 @@ impl Context {
let mut types = TypeRefMap::builder();

// Lower all the params
let mut params = Vec::new();
let start_param_idx = self.next_param_idx();
if let Some(param_list) = func.param_list() {
for param in param_list.params() {
let ast_id = self.source_ast_id_map.ast_id(&param);
let type_ref = types.alloc_from_node_opt(param.ascribed_type().as_ref());
params.push(type_ref);
self.data.params.alloc(Param {
type_ref,
ast_id: ParamAstId::Param(ast_id),
});
}
}
let end_param_idx = self.next_param_idx();
let params = IdRange::new(start_param_idx..end_param_idx);

// Lowers the return type
let ret_type = match func.ret_type().and_then(|rt| rt.type_ref()) {
Expand All @@ -177,9 +184,9 @@ impl Context {
let res = Function {
name,
visibility,
types,
is_extern,
params: params.into_boxed_slice(),
types,
params,
ret_type,
ast_id,
};
Expand Down Expand Up @@ -313,6 +320,12 @@ impl Context {
let idx: u32 = self.data.fields.len().try_into().expect("too many fields");
Idx::from_raw(RawId::from(idx))
}

/// Returns the `Idx` of the next `Param`
fn next_param_idx(&self) -> Idx<Param> {
let idx: u32 = self.data.params.len().try_into().expect("too many params");
Idx::from_raw(RawId::from(idx))
}
}

/// Lowers a record field (e.g. `a:i32`)
Expand Down
10 changes: 7 additions & 3 deletions crates/mun_hir/src/item_tree/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{fmt, fmt::Write};

use crate::{
item_tree::{
Fields, Function, Impl, Import, ItemTree, LocalItemTreeId, ModItem, RawVisibilityId,
Fields, Function, Impl, Import, ItemTree, LocalItemTreeId, ModItem, Param, RawVisibilityId,
Struct, TypeAlias,
},
path::ImportAlias,
Expand Down Expand Up @@ -181,8 +181,12 @@ impl Printer<'_> {
write!(self, "(")?;
if !params.is_empty() {
self.indented(|this| {
for param in params.iter().copied() {
this.print_type_ref(param, types)?;
for param in params.clone() {
let Param {
type_ref,
ast_id: _,
} = &this.tree[param];
this.print_type_ref(*type_ref, types)?;
writeln!(this, ",")?;
}
Ok(())
Expand Down
Loading

0 comments on commit 66cc367

Please sign in to comment.