diff --git a/crates/mun_compiler/src/snapshots/mun_compiler__diagnostics__tests__mismatched_type_error.snap b/crates/mun_compiler/src/snapshots/mun_compiler__diagnostics__tests__mismatched_type_error.snap index e54d6a9fa..2f816dcd8 100644 --- a/crates/mun_compiler/src/snapshots/mun_compiler__diagnostics__tests__mismatched_type_error.snap +++ b/crates/mun_compiler/src/snapshots/mun_compiler__diagnostics__tests__mismatched_type_error.snap @@ -12,6 +12,6 @@ error: mismatched type --> main.mun:6:14 | 6 | let b: bool = 22; - | ^^ expected `bool`, found `int` + | ^^ expected `bool`, found `{integer}` | diff --git a/crates/mun_hir/Cargo.toml b/crates/mun_hir/Cargo.toml index 191fffe31..1c21b799e 100644 --- a/crates/mun_hir/Cargo.toml +++ b/crates/mun_hir/Cargo.toml @@ -16,7 +16,7 @@ mun_target={path="../mun_target"} rustc-hash = "1.1" once_cell = "0.2" relative-path = "0.4.0" -ena = "0.13" +ena = "0.14" drop_bomb = "0.1.4" either = "1.5.3" diff --git a/crates/mun_hir/src/lib.rs b/crates/mun_hir/src/lib.rs index b5ffe0547..30e250f31 100644 --- a/crates/mun_hir/src/lib.rs +++ b/crates/mun_hir/src/lib.rs @@ -29,6 +29,7 @@ mod resolve; mod source_id; mod ty; mod type_ref; +mod utils; #[cfg(test)] mod mock; diff --git a/crates/mun_hir/src/ty.rs b/crates/mun_hir/src/ty.rs index 65aa24005..dfc545c0f 100644 --- a/crates/mun_hir/src/ty.rs +++ b/crates/mun_hir/src/ty.rs @@ -5,16 +5,18 @@ mod primitives; mod resolve; use crate::display::{HirDisplay, HirFormatter}; -use crate::ty::infer::TypeVarId; +use crate::ty::infer::InferTy; use crate::ty::lower::fn_sig_for_struct_constructor; +use crate::utils::make_mut_slice; use crate::{HirDatabase, Struct, StructMemoryKind}; pub(crate) use infer::infer_query; pub use infer::InferenceResult; pub(crate) use lower::{callable_item_sig, fn_sig_for_fn, type_for_def, CallableDef, TypableDef}; pub use primitives::{FloatTy, IntTy}; pub use resolve::ResolveBitness; -use std::fmt; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use std::{fmt, mem}; #[cfg(test)] mod tests; @@ -27,7 +29,7 @@ pub enum Ty { Apply(ApplicationTy), /// A type variable used during type checking. Not to be confused with a type parameter. - Infer(TypeVarId), + Infer(InferTy), /// A placeholder for a type which could not be computed; this is propagated to avoid useless /// error messages. Doubles as a placeholder where type variables are inserted before type @@ -167,6 +169,20 @@ impl Substs { } } +impl Deref for Substs { + type Target = [Ty]; + + fn deref(&self) -> &[Ty] { + &self.0 + } +} + +impl DerefMut for Substs { + fn deref_mut(&mut self) -> &mut [Ty] { + make_mut_slice(&mut self.0) + } +} + /// A function signature as seen by type inference: Several parameter types and /// one return type. #[derive(Clone, PartialEq, Eq, Debug)] @@ -208,7 +224,11 @@ impl HirDisplay for Ty { Ty::Apply(a_ty) => a_ty.hir_fmt(f), Ty::Unknown => write!(f, "{{unknown}}"), Ty::Empty => write!(f, "nothing"), - Ty::Infer(tv) => write!(f, "'{}", tv.0), + Ty::Infer(tv) => match tv { + InferTy::TypeVar(tv) => write!(f, "'{}", tv.0), + InferTy::IntVar(_) => write!(f, "{{integer}}"), + InferTy::FloatVar(_) => write!(f, "{{float}}"), + }, } } } @@ -246,3 +266,28 @@ impl HirDisplay for &Ty { HirDisplay::hir_fmt(*self, f) } } + +impl Ty { + fn walk_mut(&mut self, f: &mut impl FnMut(&mut Ty)) { + match self { + Ty::Apply(ty) => { + for t in ty.parameters.iter_mut() { + t.walk_mut(f); + } + } + Ty::Empty | Ty::Infer(_) | Ty::Unknown => {} + } + f(self) + } + + fn fold(mut self, f: &mut impl FnMut(Ty) -> Ty) -> Self + where + Self: Sized, + { + self.walk_mut(&mut |ty_mut| { + let ty = mem::replace(ty_mut, Ty::Unknown); + *ty_mut = f(ty); + }); + self + } +} diff --git a/crates/mun_hir/src/ty/infer.rs b/crates/mun_hir/src/ty/infer.rs index 2b75f8ac8..285d4cb70 100644 --- a/crates/mun_hir/src/ty/infer.rs +++ b/crates/mun_hir/src/ty/infer.rs @@ -13,7 +13,7 @@ use crate::{ ty::op, ty::{Ty, TypableDef}, type_ref::TypeRefId, - BinaryOp, Function, HirDatabase, Name, Path, TypeCtor, + ApplicationTy, BinaryOp, Function, HirDatabase, Name, Path, TypeCtor, }; use rustc_hash::FxHashSet; use std::ops::Index; @@ -21,9 +21,11 @@ use std::sync::Arc; mod place_expr; mod type_variable; +mod unify; -use crate::expr::{LiteralFloatKind, LiteralIntKind}; +use crate::expr::{LiteralFloat, LiteralFloatKind, LiteralInt, LiteralIntKind}; use crate::ty::primitives::{FloatTy, IntTy}; +use std::mem; pub use type_variable::TypeVarId; #[macro_export] @@ -96,6 +98,32 @@ pub fn infer_query(db: &impl HirDatabase, def: DefWithBody) -> Arc type_variable::TypeVarId { + match self { + InferTy::TypeVar(ty) | InferTy::IntVar(ty) | InferTy::FloatVar(ty) => ty, + } + } + + fn fallback_value(self) -> Ty { + match self { + InferTy::TypeVar(..) => Ty::Unknown, + InferTy::IntVar(..) => Ty::simple(TypeCtor::Int(IntTy::int())), + InferTy::FloatVar(..) => Ty::simple(TypeCtor::Float(FloatTy::float())), + } + } +} + enum ActiveLoop { Loop(Ty, Expectation), While, @@ -176,25 +204,6 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { } } -impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { - /// Unify the specified types, returns true if successful; false otherwise. - fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool { - if ty1 == ty2 { - return true; - } - - self.unify_inner_trivial(&ty1, &ty2) - } - - /// This function performs trivial unifications. Returns true if a unification took place; - fn unify_inner_trivial(&mut self, ty1: &Ty, ty2: &Ty) -> bool { - match (ty1, ty2) { - (Ty::Unknown, _) | (_, Ty::Unknown) => true, - _ => false, - } - } -} - impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { /// Collect all the parameter patterns from the body. After calling this method the `return_ty` /// will have a valid value, also all parameters are added inferred. @@ -235,7 +244,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { /// Infers the type of the `tgt_expr` fn infer_expr(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty { let ty = self.infer_expr_inner(tgt_expr, expected, &CheckParams::default()); - if !expected.is_none() && ty != expected.ty { + if !self.unify(&ty, &expected.ty) { self.diagnostics.push(InferenceDiagnostic::MismatchedTypes { expected: expected.ty.clone(), found: ty.clone(), @@ -243,7 +252,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { }); }; - ty + self.resolve_ty_as_far_as_possible(ty) } /// Infer type of expression with possibly implicit coerce to the expected type. Return the type @@ -256,7 +265,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { /// Performs implicit coercion of the specified `Ty` to an expected type. Returns the type after /// possible coercion. Adds a diagnostic message if coercion failed. fn coerce_expr_ty(&mut self, expr: ExprId, ty: Ty, expected: &Expectation) -> Ty { - if !self.coerce(&ty, &expected.ty) { + let ty = if !self.coerce(&ty, &expected.ty) { self.diagnostics.push(InferenceDiagnostic::MismatchedTypes { expected: expected.ty.clone(), found: ty.clone(), @@ -267,7 +276,9 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { ty } else { expected.ty.clone() - } + }; + + self.resolve_ty_as_far_as_possible(ty) } /// Infer the type of the given expression. Returns the type of the expression. @@ -327,30 +338,27 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { Expr::Literal(lit) => match lit { Literal::String(_) => Ty::Unknown, Literal::Bool(_) => Ty::simple(TypeCtor::Bool), - Literal::Int(ty) => { - // TODO: Add inferencing support - let ty = if let LiteralIntKind::Suffixed(suffix) = ty.kind { - IntTy { - bitness: suffix.bitness, - signedness: suffix.signedness, - } - } else { - IntTy::int() - }; - - Ty::simple(TypeCtor::Int(ty)) - } - Literal::Float(ty) => { - // TODO: Add inferencing support - let ty = if let LiteralFloatKind::Suffixed(suffix) = ty.kind { - FloatTy { - bitness: suffix.bitness, - } - } else { - FloatTy::float() - }; - Ty::simple(TypeCtor::Float(ty)) - } + Literal::Int(LiteralInt { + kind: LiteralIntKind::Suffixed(suffix), + .. + }) => Ty::simple(TypeCtor::Int(IntTy { + bitness: suffix.bitness, + signedness: suffix.signedness, + })), + Literal::Float(LiteralFloat { + kind: LiteralFloatKind::Suffixed(suffix), + .. + }) => Ty::simple(TypeCtor::Float(FloatTy { + bitness: suffix.bitness, + })), + Literal::Int(LiteralInt { + kind: LiteralIntKind::Unsuffixed, + .. + }) => self.type_variables.new_integer_var(), + Literal::Float(LiteralFloat { + kind: LiteralFloatKind::Unsuffixed, + .. + }) => self.type_variables.new_float_var(), }, Expr::Return { expr } => { if let Some(expr) = expr { @@ -428,39 +436,53 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { } } Expr::UnaryOp { expr, op } => { - let ty = + let inner_ty = self.infer_expr_inner(*expr, &Expectation::none(), &CheckParams::default()); - if let Some(simple) = ty.as_simple() { - match op { - UnaryOp::Not => { - match simple { - TypeCtor::Bool | TypeCtor::Int(_) => ty, - _ => { - self.diagnostics.push( - InferenceDiagnostic::CannotApplyUnaryOp { id: *expr, ty }, - ); - Ty::Unknown - } - } + match op { + UnaryOp::Not => match &inner_ty { + Ty::Apply(ApplicationTy { + ctor: TypeCtor::Bool, + .. + }) + | Ty::Apply(ApplicationTy { + ctor: TypeCtor::Int(_), + .. + }) + | Ty::Infer(InferTy::IntVar(..)) => inner_ty, + _ => { + self.diagnostics + .push(InferenceDiagnostic::CannotApplyUnaryOp { + id: *expr, + ty: inner_ty, + }); + Ty::Unknown } - UnaryOp::Neg => { - match simple { - TypeCtor::Float(_) | TypeCtor::Int(_) => ty, - _ => { - self.diagnostics.push( - InferenceDiagnostic::CannotApplyUnaryOp { id: *expr, ty }, - ); - Ty::Unknown - } - } + }, + UnaryOp::Neg => match &inner_ty { + Ty::Apply(ApplicationTy { + ctor: TypeCtor::Float(_), + .. + }) + | Ty::Apply(ApplicationTy { + ctor: TypeCtor::Int(_), + .. + }) + | Ty::Infer(InferTy::IntVar(..)) + | Ty::Infer(InferTy::FloatVar(..)) => inner_ty, + _ => { + self.diagnostics + .push(InferenceDiagnostic::CannotApplyUnaryOp { + id: *expr, + ty: inner_ty, + }); + Ty::Unknown } - } - } else { - Ty::Unknown + }, } } // Expr::Block { statements: _, tail: _ } => {} }; + let ty = self.resolve_ty_as_far_as_possible(ty); self.set_expr_type(tgt_expr, ty.clone()); ty } @@ -687,19 +709,25 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { //let mut tv_stack = Vec::new(); let mut expr_types = std::mem::take(&mut self.type_of_expr); for (expr, ty) in expr_types.iter_mut() { - //let resolved = self.resolve_ty_completely(&mut tv_stack, mem::replace(ty, Ty::Unknown)); - if *ty == Ty::Unknown { + let was_unknown = ty == &mut Ty::Unknown; + let resolved = self + .type_variables + .resolve_ty_completely(mem::replace(ty, Ty::Unknown)); + if !was_unknown && resolved == Ty::Unknown { self.report_expr_inference_failure(expr); } - //*ty = resolved; + *ty = resolved; } let mut pat_types = std::mem::take(&mut self.type_of_pat); for (pat, ty) in pat_types.iter_mut() { - //let resolved = self.resolve_ty_completely(&mut tv_stack, mem::replace(ty, Ty::Unknown)); - if *ty == Ty::Unknown { + let was_unknown = ty == &mut Ty::Unknown; + let resolved = self + .type_variables + .resolve_ty_completely(mem::replace(ty, Ty::Unknown)); + if !was_unknown && resolved == Ty::Unknown { self.report_pat_inference_failure(pat); } - //*ty = resolved; + *ty = resolved; } InferenceResult { // method_resolutions: self.method_resolutions, @@ -773,6 +801,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { decl_ty }; + let ty = self.resolve_ty_as_far_as_possible(ty); self.infer_pat(*pat, ty); } Statement::Expr(expr) => { @@ -827,7 +856,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { }; // Verify that it matches what we expected - let ty = if !expected.is_none() && ty != expected.ty { + let ty = if !self.unify(&ty, &expected.ty) { self.diagnostics.push(InferenceDiagnostic::MismatchedTypes { expected: expected.ty.clone(), found: ty, @@ -885,12 +914,18 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { // self.diagnostics.push(InferenceDiagnostic::PatInferenceFailed { // pat // }); + // Currently this should never happen because we can only infer `int` and `float` types + // which always have a fallback value. + panic!("pattern failed inferencing"); } pub fn report_expr_inference_failure(&mut self, _expr: ExprId) { // self.diagnostics.push(InferenceDiagnostic::ExprInferenceFailed { // expr // }); + // Currently this should never happen because we can only infer `int` and `float` types + // which always have a fallback value. + panic!("expression failed inferencing"); } } diff --git a/crates/mun_hir/src/ty/infer/coerce.rs b/crates/mun_hir/src/ty/infer/coerce.rs index c2d715516..f005a6b00 100644 --- a/crates/mun_hir/src/ty/infer/coerce.rs +++ b/crates/mun_hir/src/ty/infer/coerce.rs @@ -5,7 +5,9 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { /// Unify two types, but may coerce the first one to the second using implicit coercion rules if /// needed. pub(super) fn coerce(&mut self, from_ty: &Ty, to_ty: &Ty) -> bool { - self.coerce_inner(from_ty.clone(), &to_ty) + let from_ty = self.replace_if_possible(from_ty).into_owned(); + let to_ty = self.replace_if_possible(to_ty); + self.coerce_inner(from_ty, &to_ty) } /// Merge two types from different branches, with possible implicit coerce. @@ -23,7 +25,7 @@ impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { match (&from_ty, to_ty) { (ty_app!(TypeCtor::Never), ..) => return true, _ => { - if self.unify_inner_trivial(&from_ty, &to_ty) { + if self.type_variables.unify_inner_trivial(&from_ty, &to_ty) { return true; } } diff --git a/crates/mun_hir/src/ty/infer/type_variable.rs b/crates/mun_hir/src/ty/infer/type_variable.rs index 01912509f..cd994e9c6 100644 --- a/crates/mun_hir/src/ty/infer/type_variable.rs +++ b/crates/mun_hir/src/ty/infer/type_variable.rs @@ -1,9 +1,6 @@ -use crate::Ty; -use drop_bomb::DropBomb; -use ena::snapshot_vec::{SnapshotVec, SnapshotVecDelegate}; -use ena::unify::{InPlace, InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; -use std::borrow::Cow; -use std::fmt; +use crate::{ty::infer::InferTy, ty_app, Ty, TypeCtor}; +use ena::unify::{InPlaceUnificationTable, NoError, UnifyKey, UnifyValue}; +use std::{borrow::Cow, fmt}; /// The ID of a type variable. #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] @@ -78,7 +75,6 @@ impl UnifyValue for TypeVarValue { #[derive(Default)] pub struct TypeVariableTable { - values: SnapshotVec, eq_relations: InPlaceUnificationTable, } @@ -94,23 +90,96 @@ struct Instantiate { struct Delegate; impl TypeVariableTable { - /// Creates a new generic infer type variable - pub fn new_type_var(&mut self) -> TypeVarId { - let eq_key = self.eq_relations.new_key(TypeVarValue::Unknown); - let index = self.values.push(TypeVariableData {}); - assert_eq!(eq_key.0, index as u32); - eq_key + /// Constructs a new generic type variable type + pub fn new_type_var(&mut self) -> Ty { + Ty::Infer(InferTy::TypeVar( + self.eq_relations.new_key(TypeVarValue::Unknown), + )) + } + + /// Constructs a new type variable that is used to represent *some* integer type + pub fn new_integer_var(&mut self) -> Ty { + Ty::Infer(InferTy::IntVar( + self.eq_relations.new_key(TypeVarValue::Unknown), + )) + } + + /// Constructs a new type variable that is used to represent *some* floating-point type + pub fn new_float_var(&mut self) -> Ty { + Ty::Infer(InferTy::FloatVar( + self.eq_relations.new_key(TypeVarValue::Unknown), + )) + } + + /// Unifies the two types. If one or more type variables are involved instantiate or equate the + /// variables with each other. + pub fn unify(&mut self, a: &Ty, b: &Ty) -> bool { + self.unify_inner(a, b) + } + + /// Unifies the two types. If one or more type variables are involved instantiate or equate the + /// variables with each other. + fn unify_inner(&mut self, a: &Ty, b: &Ty) -> bool { + if a == b { + return true; + } + + // First resolve both types as much as possible + let a = self.replace_if_possible(a); + let b = self.replace_if_possible(b); + + self.unify_inner_trivial(&a, &b) + } + + /// Handles unificiation of trivial cases. + pub(crate) fn unify_inner_trivial(&mut self, a: &Ty, b: &Ty) -> bool { + match (a, b) { + // Ignore unificiation if dealing with unknown types, there are no guarentees in that case. + (Ty::Unknown, _) | (_, Ty::Unknown) => true, + + // In case of two unknowns of the same type, equate them + (Ty::Infer(InferTy::TypeVar(tv_a)), Ty::Infer(InferTy::TypeVar(tv_b))) + | (Ty::Infer(InferTy::IntVar(tv_a)), Ty::Infer(InferTy::IntVar(tv_b))) + | (Ty::Infer(InferTy::FloatVar(tv_a)), Ty::Infer(InferTy::FloatVar(tv_b))) => { + self.equate(*tv_a, *tv_b); + true + } + + // Instantiate the variable if unifying with a concrete type + (Ty::Infer(InferTy::TypeVar(tv)), other) | (other, Ty::Infer(InferTy::TypeVar(tv))) => { + self.instantiate(*tv, other.clone()); + true + } + + // Instantiate the variable if unifying an unknown integer type with a concrete integer type + (Ty::Infer(InferTy::IntVar(tv)), other @ ty_app!(TypeCtor::Int(_))) + | (other @ ty_app!(TypeCtor::Int(_)), Ty::Infer(InferTy::IntVar(tv))) => { + self.instantiate(*tv, other.clone()); + true + } + + // Instantiate the variable if unifying an unknown float type with a concrete float type + (Ty::Infer(InferTy::FloatVar(tv)), other @ ty_app!(TypeCtor::Float(_))) + | (other @ ty_app!(TypeCtor::Float(_)), Ty::Infer(InferTy::FloatVar(tv))) => { + self.instantiate(*tv, other.clone()); + true + } + + // Was not able to unify the types + _ => false, + } } /// Records that `a == b` - pub fn equate(&mut self, a: TypeVarId, b: TypeVarId) { + fn equate(&mut self, a: TypeVarId, b: TypeVarId) { debug_assert!(self.eq_relations.probe_value(a).is_unknown()); debug_assert!(self.eq_relations.probe_value(b).is_unknown()); self.eq_relations.union(a, b); } - /// Instantiates `tv` with the type `ty`. - pub fn instantiate(&mut self, tv: TypeVarId, ty: Ty) { + /// Instantiates `tv` with the type `ty`. Instantiation is the process of associating a concrete + /// type with a type variable which in turn will resolve all equated type variables. + fn instantiate(&mut self, tv: TypeVarId, ty: Ty) { debug_assert!( self.eq_relations.probe_value(tv).is_unknown(), "instantiating type variable `{:?}` twice: new-value = {:?}, old-value={:?}", @@ -121,106 +190,112 @@ impl TypeVariableTable { self.eq_relations.union_value(tv, TypeVarValue::Known(ty)); } - /// If `ty` is a type-inference variable, and it has been instantiated, then return the - /// instantiated type; otherwise returns `ty`. + /// If `ty` is a type variable, and it has been instantiated, then return the instantiated type; + /// otherwise returns `ty`. pub fn replace_if_possible<'t>(&mut self, ty: &'t Ty) -> Cow<'t, Ty> { - let ty = Cow::Borrowed(ty); - match &*ty { - Ty::Infer(tv) => match self.eq_relations.probe_value(*tv).known() { - Some(known_ty) => Cow::Owned(known_ty.clone()), - _ => ty, - }, - _ => ty, + let mut ty = Cow::Borrowed(ty); + + // The type variable could resolve to an int/float variable. Therefore try to resolve up to + // three times; each type of variable shouldn't occur more than once + for _i in 0..3 { + match &*ty { + Ty::Infer(tv) => { + let inner = tv.to_inner(); + match self.eq_relations.inlined_probe_value(inner).known() { + Some(known_ty) => ty = Cow::Owned(known_ty.clone()), + _ => return ty, + } + } + _ => return ty, + } } + + ty } - /// Returns indices of all variables that are not yet instantiated. - pub fn unsolved_variables(&mut self) -> Vec { - (0..self.values.len()) - .filter_map(|i| { - let tv = TypeVarId::from_index(i as u32); - match self.eq_relations.probe_value(tv) { - TypeVarValue::Unknown { .. } => Some(tv), - TypeVarValue::Known { .. } => None, - } - }) - .collect() + /// Resolves the type as far as currently possible, replacing type variables by their known + /// types. All types returned by the `infer_*` functions should be resolved as far as possible, + /// i.e. contain no type variables with known type. + pub(crate) fn resolve_ty_as_far_as_possible(&mut self, ty: Ty) -> Ty { + self.resolve_ty_as_far_as_possible_inner(&mut Vec::new(), ty) } - /// Returns true if the table still contains unresolved type variables - pub fn has_unsolved_variables(&mut self) -> bool { - (0..self.values.len()).any(|i| { - let tv = TypeVarId::from_index(i as u32); - match self.eq_relations.probe_value(tv) { - TypeVarValue::Unknown { .. } => true, - TypeVarValue::Known { .. } => false, + pub(crate) fn resolve_ty_as_far_as_possible_inner( + &mut self, + tv_stack: &mut Vec, + ty: Ty, + ) -> Ty { + ty.fold(&mut |ty| match ty { + Ty::Infer(tv) => { + let inner = tv.to_inner(); + if tv_stack.contains(&inner) { + return tv.fallback_value(); + } + if let Some(known_ty) = self.eq_relations.inlined_probe_value(inner).known() { + tv_stack.push(inner); + let result = + self.resolve_ty_as_far_as_possible_inner(tv_stack, known_ty.clone()); + tv_stack.pop(); + result + } else { + ty + } } + _ => ty, }) } -} -pub struct Snapshot { - snapshot: ena::snapshot_vec::Snapshot, - eq_snapshot: ena::unify::Snapshot>, - bomb: DropBomb, -} - -impl TypeVariableTable { - /// Creates a snapshot of the type variable state. This snapshot must later be committed - /// (`commit`) or rolled back (`rollback_to()`). Nested snapshots are permitted but must be - /// processed in a stack-like fashion. - pub fn snapshot(&mut self) -> Snapshot { - Snapshot { - snapshot: self.values.start_snapshot(), - eq_snapshot: self.eq_relations.snapshot(), - bomb: DropBomb::new("Snapshot must be committed or rolled back"), - } + /// Resolves the type completely; type variables without known type are replaced by Ty::Unknown. + pub(crate) fn resolve_ty_completely(&mut self, ty: Ty) -> Ty { + self.resolve_ty_completely_inner(&mut Vec::new(), ty) } - /// Undoes all changes since the snapshot was created. Any snapshot created since that point - /// must already have been committed or rolled back. - pub fn rollback_to(&mut self, s: Snapshot) { - let Snapshot { - snapshot, - eq_snapshot, - mut bomb, - } = s; - self.values.rollback_to(snapshot); - self.eq_relations.rollback_to(eq_snapshot); - bomb.defuse(); - } - - /// Commits all changes since the snapshot was created, making them permanent (unless this - /// snapshot was created within another snapshot). Any snapshot created since that point - /// must already have been committed or rolled back. - pub fn commit(&mut self, s: Snapshot) { - let Snapshot { - snapshot, - eq_snapshot, - mut bomb, - } = s; - self.values.commit(snapshot); - self.eq_relations.commit(eq_snapshot); - bomb.defuse(); + pub(crate) fn resolve_ty_completely_inner( + &mut self, + tv_stack: &mut Vec, + ty: Ty, + ) -> Ty { + ty.fold(&mut |ty| match ty { + Ty::Infer(tv) => { + let inner = tv.to_inner(); + if tv_stack.contains(&inner) { + return tv.fallback_value(); + } + if let Some(known_ty) = self.eq_relations.inlined_probe_value(inner).known() { + // known_ty may contain other variables that are known by now + tv_stack.push(inner); + let result = self.resolve_ty_completely_inner(tv_stack, known_ty.clone()); + tv_stack.pop(); + result + } else { + tv.fallback_value() + } + } + _ => ty, + }) } -} -impl SnapshotVecDelegate for Delegate { - type Value = TypeVariableData; - type Undo = Instantiate; - - fn reverse(_values: &mut Vec, _action: Instantiate) { - // We don't actually have to *do* anything to reverse an - // instantiation; the value for a variable is stored in the - // `eq_relations` and hence its rollback code will handle - // it. In fact, we could *almost* just remove the - // `SnapshotVec` entirely, except that we would have to - // reproduce *some* of its logic, since we want to know which - // type variables have been instantiated since the snapshot - // was started, so we can implement `types_escaping_snapshot`. - // - // (If we extended the `UnificationTable` to let us see which - // values have been unified and so forth, that might also - // suffice.) - } + // /// Returns indices of all variables that are not yet instantiated. + // pub fn unsolved_variables(&mut self) -> Vec { + // (0..self.values.len()) + // .filter_map(|i| { + // let tv = TypeVarId::from_index(i as u32); + // match self.eq_relations.probe_value(tv) { + // TypeVarValue::Unknown { .. } => Some(tv), + // TypeVarValue::Known { .. } => None, + // } + // }) + // .collect() + // } + // + // /// Returns true if the table still contains unresolved type variables + // pub fn has_unsolved_variables(&mut self) -> bool { + // (0..self.values.len()).any(|i| { + // let tv = TypeVarId::from_index(i as u32); + // match self.eq_relations.probe_value(tv) { + // TypeVarValue::Unknown { .. } => true, + // TypeVarValue::Known { .. } => false, + // } + // }) + // } } diff --git a/crates/mun_hir/src/ty/infer/unify.rs b/crates/mun_hir/src/ty/infer/unify.rs new file mode 100644 index 000000000..c3dba01bd --- /dev/null +++ b/crates/mun_hir/src/ty/infer/unify.rs @@ -0,0 +1,24 @@ +use crate::ty::infer::InferenceResultBuilder; +use crate::{HirDatabase, Ty}; +use std::borrow::Cow; + +impl<'a, D: HirDatabase> InferenceResultBuilder<'a, D> { + /// If `ty` is a type variable, and it has been instantiated, then return the instantiated type; + /// otherwise returns `ty`. + pub(crate) fn replace_if_possible<'b>(&mut self, ty: &'b Ty) -> Cow<'b, Ty> { + self.type_variables.replace_if_possible(ty) + } + + /// Unifies the two types. If one or more type variables are involved instantiate or equate the + /// variables with each other. + pub(crate) fn unify(&mut self, a: &Ty, b: &Ty) -> bool { + self.type_variables.unify(a, b) + } + + /// Resolves the type as far as currently possible, replacing type variables by their known + /// types. All types returned by the `infer_*` functions should be resolved as far as possible, + /// i.e. contain no type variables with known type. + pub(crate) fn resolve_ty_as_far_as_possible(&mut self, ty: Ty) -> Ty { + self.type_variables.resolve_ty_as_far_as_possible(ty) + } +} diff --git a/crates/mun_hir/src/ty/op.rs b/crates/mun_hir/src/ty/op.rs index e05f41afc..4a5db4219 100644 --- a/crates/mun_hir/src/ty/op.rs +++ b/crates/mun_hir/src/ty/op.rs @@ -1,3 +1,4 @@ +use crate::ty::infer::InferTy; use crate::{ApplicationTy, ArithOp, BinaryOp, CmpOp, Ty, TypeCtor}; /// Given a binary operation and the type on the left of that operation, returns the expected type @@ -21,6 +22,7 @@ pub(super) fn binary_op_rhs_expectation(op: BinaryOp, lhs_ty: Ty) -> Ty { } _ => Ty::Unknown, }, + Ty::Infer(InferTy::IntVar(..)) | Ty::Infer(InferTy::FloatVar(..)) => lhs_ty, _ => Ty::Unknown, }, BinaryOp::Assignment { @@ -47,6 +49,7 @@ pub(super) fn binary_op_rhs_expectation(op: BinaryOp, lhs_ty: Ty) -> Ty { TypeCtor::Bool | TypeCtor::Int(_) => lhs_ty, _ => Ty::Unknown, }, + Ty::Infer(InferTy::IntVar(..)) => lhs_ty, _ => Ty::Unknown, }, BinaryOp::CmpOp(CmpOp::Ord { .. }) @@ -56,6 +59,7 @@ pub(super) fn binary_op_rhs_expectation(op: BinaryOp, lhs_ty: Ty) -> Ty { TypeCtor::Int(_) | TypeCtor::Float(_) => lhs_ty, _ => Ty::Unknown, }, + Ty::Infer(InferTy::IntVar(..)) | Ty::Infer(InferTy::FloatVar(..)) => lhs_ty, _ => Ty::Unknown, }, } @@ -70,6 +74,7 @@ pub(super) fn binary_op_return_ty(op: BinaryOp, rhs_ty: Ty) -> Ty { TypeCtor::Int(_) | TypeCtor::Float(_) => rhs_ty, _ => Ty::Unknown, }, + Ty::Infer(InferTy::IntVar(..)) | Ty::Infer(InferTy::FloatVar(..)) => rhs_ty, _ => Ty::Unknown, }, BinaryOp::CmpOp(_) | BinaryOp::LogicOp(_) => Ty::simple(TypeCtor::Bool), diff --git a/crates/mun_hir/src/ty/snapshots/tests__infer_literals.snap b/crates/mun_hir/src/ty/snapshots/tests__infer_literals.snap new file mode 100644 index 000000000..9abdb5d06 --- /dev/null +++ b/crates/mun_hir/src/ty/snapshots/tests__infer_literals.snap @@ -0,0 +1,14 @@ +--- +source: crates/mun_hir/src/ty/tests.rs +expression: "fn integer() -> i32 {\n 0\n }\n\n fn large_unsigned_integer() -> u128 {\n 0\n }\n\n fn with_let() -> u16 {\n let b = 4;\n let a = 4;\n a\n }" +--- +[20; 37) '{ ... }': i32 +[30; 31) '0': i32 +[79; 96) '{ ... }': u128 +[89; 90) '0': u128 +[123; 178) '{ ... }': u16 +[137; 138) 'b': int +[141; 142) '4': int +[156; 157) 'a': u16 +[160; 161) '4': u16 +[171; 172) 'a': u16 diff --git a/crates/mun_hir/src/ty/snapshots/tests__invalid_unary_ops.snap b/crates/mun_hir/src/ty/snapshots/tests__invalid_unary_ops.snap index 0bb46c57e..6c6b6f3ef 100644 --- a/crates/mun_hir/src/ty/snapshots/tests__invalid_unary_ops.snap +++ b/crates/mun_hir/src/ty/snapshots/tests__invalid_unary_ops.snap @@ -3,9 +3,7 @@ source: crates/mun_hir/src/ty/tests.rs expression: "fn bar(a: float, b: bool) {\n a = !a; // mismatched type\n b = -b; // mismatched type\n}" --- [37; 38): cannot apply unary operator -[36; 38): mismatched type [68; 69): cannot apply unary operator -[67; 69): mismatched type [7; 8) 'a': float [17; 18) 'b': bool [26; 91) '{ ...type }': nothing diff --git a/crates/mun_hir/src/ty/tests.rs b/crates/mun_hir/src/ty/tests.rs index a364d7f96..90e710f30 100644 --- a/crates/mun_hir/src/ty/tests.rs +++ b/crates/mun_hir/src/ty/tests.rs @@ -20,6 +20,27 @@ fn comparison_not_implemented_for_struct() { ) } +#[test] +fn infer_literals() { + infer_snapshot( + r" + fn integer() -> i32 { + 0 + } + + fn large_unsigned_integer() -> u128 { + 0 + } + + fn with_let() -> u16 { + let b = 4; + let a = 4; + a + } + ", + ) +} + #[test] fn infer_suffix_literals() { infer_snapshot( diff --git a/crates/mun_hir/src/utils.rs b/crates/mun_hir/src/utils.rs new file mode 100644 index 000000000..4f9be983c --- /dev/null +++ b/crates/mun_hir/src/utils.rs @@ -0,0 +1,10 @@ +use std::sync::Arc; + +/// Helper for mutating `Arc<[T]>` (i.e. `Arc::make_mut` for Arc slices). +/// The underlying values are cloned if there are other strong references. +pub(crate) fn make_mut_slice(a: &mut Arc<[T]>) -> &mut [T] { + if Arc::get_mut(a).is_none() { + *a = a.iter().cloned().collect(); + } + Arc::get_mut(a).unwrap() +}