diff --git a/compiler/rustc_infer/src/infer/combine.rs b/compiler/rustc_infer/src/infer/combine.rs index 72676b718fabe..a567b6acdbeeb 100644 --- a/compiler/rustc_infer/src/infer/combine.rs +++ b/compiler/rustc_infer/src/infer/combine.rs @@ -37,7 +37,10 @@ use rustc_middle::traits::ObligationCause; use rustc_middle::ty::error::{ExpectedFound, TypeError}; use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation}; use rustc_middle::ty::subst::SubstsRef; -use rustc_middle::ty::{self, InferConst, Ty, TyCtxt, TypeVisitable}; +use rustc_middle::ty::{ + self, FallibleTypeFolder, InferConst, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable, + TypeVisitable, +}; use rustc_middle::ty::{IntType, UintType}; use rustc_span::{Span, DUMMY_SP}; @@ -140,8 +143,6 @@ impl<'tcx> InferCtxt<'tcx> { let a = self.shallow_resolve(a); let b = self.shallow_resolve(b); - let a_is_expected = relation.a_is_expected(); - match (a.kind(), b.kind()) { ( ty::ConstKind::Infer(InferConst::Var(a_vid)), @@ -158,11 +159,11 @@ impl<'tcx> InferCtxt<'tcx> { } (ty::ConstKind::Infer(InferConst::Var(vid)), _) => { - return self.unify_const_variable(relation.param_env(), vid, b, a_is_expected); + return self.unify_const_variable(vid, b); } (_, ty::ConstKind::Infer(InferConst::Var(vid))) => { - return self.unify_const_variable(relation.param_env(), vid, a, !a_is_expected); + return self.unify_const_variable(vid, a); } (ty::ConstKind::Unevaluated(..), _) if self.tcx.lazy_normalization() => { // FIXME(#59490): Need to remove the leak check to accommodate @@ -223,10 +224,8 @@ impl<'tcx> InferCtxt<'tcx> { #[instrument(level = "debug", skip(self))] fn unify_const_variable( &self, - param_env: ty::ParamEnv<'tcx>, target_vid: ty::ConstVid<'tcx>, ct: ty::Const<'tcx>, - vid_is_expected: bool, ) -> RelateResult<'tcx, ty::Const<'tcx>> { let (for_universe, span) = { let mut inner = self.inner.borrow_mut(); @@ -239,8 +238,12 @@ impl<'tcx> InferCtxt<'tcx> { ConstVariableValue::Unknown { universe } => (universe, var_value.origin.span), } }; - let value = ConstInferUnifier { infcx: self, span, param_env, for_universe, target_vid } - .relate(ct, ct)?; + let value = ct.try_fold_with(&mut ConstInferUnifier { + infcx: self, + span, + for_universe, + target_vid, + })?; self.inner.borrow_mut().const_unification_table().union_value( target_vid, @@ -800,8 +803,6 @@ struct ConstInferUnifier<'cx, 'tcx> { span: Span, - param_env: ty::ParamEnv<'tcx>, - for_universe: ty::UniverseIndex, /// The vid of the const variable that is in the process of being @@ -810,61 +811,15 @@ struct ConstInferUnifier<'cx, 'tcx> { target_vid: ty::ConstVid<'tcx>, } -// We use `TypeRelation` here to propagate `RelateResult` upwards. -// -// Both inputs are expected to be the same. -impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> { - fn tcx(&self) -> TyCtxt<'tcx> { - self.infcx.tcx - } - - fn intercrate(&self) -> bool { - assert!(!self.infcx.intercrate); - false - } - - fn param_env(&self) -> ty::ParamEnv<'tcx> { - self.param_env - } - - fn tag(&self) -> &'static str { - "ConstInferUnifier" - } - - fn a_is_expected(&self) -> bool { - true - } - - fn mark_ambiguous(&mut self) { - bug!() - } - - fn relate_with_variance>( - &mut self, - _variance: ty::Variance, - _info: ty::VarianceDiagInfo<'tcx>, - a: T, - b: T, - ) -> RelateResult<'tcx, T> { - // We don't care about variance here. - self.relate(a, b) - } +impl<'tcx> FallibleTypeFolder<'tcx> for ConstInferUnifier<'_, 'tcx> { + type Error = TypeError<'tcx>; - fn binders( - &mut self, - a: ty::Binder<'tcx, T>, - b: ty::Binder<'tcx, T>, - ) -> RelateResult<'tcx, ty::Binder<'tcx, T>> - where - T: Relate<'tcx>, - { - Ok(a.rebind(self.relate(a.skip_binder(), b.skip_binder())?)) + fn tcx<'a>(&'a self) -> TyCtxt<'tcx> { + self.infcx.tcx } #[instrument(level = "debug", skip(self), ret)] - fn tys(&mut self, t: Ty<'tcx>, _t: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> { - debug_assert_eq!(t, _t); - + fn try_fold_ty(&mut self, t: Ty<'tcx>) -> Result, TypeError<'tcx>> { match t.kind() { &ty::Infer(ty::TyVar(vid)) => { let vid = self.infcx.inner.borrow_mut().type_variables().root_var(vid); @@ -872,7 +827,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> { match probe { TypeVariableValue::Known { value: u } => { debug!("ConstOccursChecker: known value {:?}", u); - self.tys(u, u) + u.try_fold_with(self) } TypeVariableValue::Unknown { universe } => { if self.for_universe.can_name(universe) { @@ -892,16 +847,15 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> { } } ty::Infer(ty::IntVar(_) | ty::FloatVar(_)) => Ok(t), - _ => relate::super_relate_tys(self, t, t), + _ => t.try_super_fold_with(self), } } - fn regions( + #[instrument(level = "debug", skip(self), ret)] + fn try_fold_region( &mut self, r: ty::Region<'tcx>, - _r: ty::Region<'tcx>, - ) -> RelateResult<'tcx, ty::Region<'tcx>> { - debug_assert_eq!(r, _r); + ) -> Result, TypeError<'tcx>> { debug!("ConstInferUnifier: r={:?}", r); match *r { @@ -930,14 +884,8 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> { } } - #[instrument(level = "debug", skip(self))] - fn consts( - &mut self, - c: ty::Const<'tcx>, - _c: ty::Const<'tcx>, - ) -> RelateResult<'tcx, ty::Const<'tcx>> { - debug_assert_eq!(c, _c); - + #[instrument(level = "debug", skip(self), ret)] + fn try_fold_const(&mut self, c: ty::Const<'tcx>) -> Result, TypeError<'tcx>> { match c.kind() { ty::ConstKind::Infer(InferConst::Var(vid)) => { // Check if the current unification would end up @@ -958,7 +906,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> { let var_value = self.infcx.inner.borrow_mut().const_unification_table().probe_value(vid); match var_value.val { - ConstVariableValue::Known { value: u } => self.consts(u, u), + ConstVariableValue::Known { value: u } => u.try_fold_with(self), ConstVariableValue::Unknown { universe } => { if self.for_universe.can_name(universe) { Ok(c) @@ -977,17 +925,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> { } } } - ty::ConstKind::Unevaluated(ty::UnevaluatedConst { def, substs }) => { - let substs = self.relate_with_variance( - ty::Variance::Invariant, - ty::VarianceDiagInfo::default(), - substs, - substs, - )?; - - Ok(self.tcx().mk_const(ty::UnevaluatedConst { def, substs }, c.ty())) - } - _ => relate::super_relate_consts(self, c, c), + _ => c.try_super_fold_with(self), } } }