@@ -37,7 +37,10 @@ use rustc_middle::traits::ObligationCause;
3737use rustc_middle:: ty:: error:: { ExpectedFound , TypeError } ;
3838use rustc_middle:: ty:: relate:: { self , Relate , RelateResult , TypeRelation } ;
3939use rustc_middle:: ty:: subst:: SubstsRef ;
40- use rustc_middle:: ty:: { self , InferConst , Ty , TyCtxt , TypeVisitable } ;
40+ use rustc_middle:: ty:: {
41+ self , FallibleTypeFolder , InferConst , Ty , TyCtxt , TypeFoldable , TypeSuperFoldable ,
42+ TypeVisitable ,
43+ } ;
4144use rustc_middle:: ty:: { IntType , UintType } ;
4245use rustc_span:: { Span , DUMMY_SP } ;
4346
@@ -140,8 +143,6 @@ impl<'tcx> InferCtxt<'tcx> {
140143 let a = self . shallow_resolve ( a) ;
141144 let b = self . shallow_resolve ( b) ;
142145
143- let a_is_expected = relation. a_is_expected ( ) ;
144-
145146 match ( a. kind ( ) , b. kind ( ) ) {
146147 (
147148 ty:: ConstKind :: Infer ( InferConst :: Var ( a_vid) ) ,
@@ -158,11 +159,11 @@ impl<'tcx> InferCtxt<'tcx> {
158159 }
159160
160161 ( ty:: ConstKind :: Infer ( InferConst :: Var ( vid) ) , _) => {
161- return self . unify_const_variable ( relation . param_env ( ) , vid, b, a_is_expected ) ;
162+ return self . unify_const_variable ( vid, b) ;
162163 }
163164
164165 ( _, ty:: ConstKind :: Infer ( InferConst :: Var ( vid) ) ) => {
165- return self . unify_const_variable ( relation . param_env ( ) , vid, a, !a_is_expected ) ;
166+ return self . unify_const_variable ( vid, a) ;
166167 }
167168 ( ty:: ConstKind :: Unevaluated ( ..) , _) if self . tcx . lazy_normalization ( ) => {
168169 // FIXME(#59490): Need to remove the leak check to accommodate
@@ -223,10 +224,8 @@ impl<'tcx> InferCtxt<'tcx> {
223224 #[ instrument( level = "debug" , skip( self ) ) ]
224225 fn unify_const_variable (
225226 & self ,
226- param_env : ty:: ParamEnv < ' tcx > ,
227227 target_vid : ty:: ConstVid < ' tcx > ,
228228 ct : ty:: Const < ' tcx > ,
229- vid_is_expected : bool ,
230229 ) -> RelateResult < ' tcx , ty:: Const < ' tcx > > {
231230 let ( for_universe, span) = {
232231 let mut inner = self . inner . borrow_mut ( ) ;
@@ -239,8 +238,12 @@ impl<'tcx> InferCtxt<'tcx> {
239238 ConstVariableValue :: Unknown { universe } => ( universe, var_value. origin . span ) ,
240239 }
241240 } ;
242- let value = ConstInferUnifier { infcx : self , span, param_env, for_universe, target_vid }
243- . relate ( ct, ct) ?;
241+ let value = ct. try_fold_with ( & mut ConstInferUnifier {
242+ infcx : self ,
243+ span,
244+ for_universe,
245+ target_vid,
246+ } ) ?;
244247
245248 self . inner . borrow_mut ( ) . const_unification_table ( ) . union_value (
246249 target_vid,
@@ -800,8 +803,6 @@ struct ConstInferUnifier<'cx, 'tcx> {
800803
801804 span : Span ,
802805
803- param_env : ty:: ParamEnv < ' tcx > ,
804-
805806 for_universe : ty:: UniverseIndex ,
806807
807808 /// The vid of the const variable that is in the process of being
@@ -810,69 +811,23 @@ struct ConstInferUnifier<'cx, 'tcx> {
810811 target_vid : ty:: ConstVid < ' tcx > ,
811812}
812813
813- // We use `TypeRelation` here to propagate `RelateResult` upwards.
814- //
815- // Both inputs are expected to be the same.
816- impl < ' tcx > TypeRelation < ' tcx > for ConstInferUnifier < ' _ , ' tcx > {
817- fn tcx ( & self ) -> TyCtxt < ' tcx > {
818- self . infcx . tcx
819- }
820-
821- fn intercrate ( & self ) -> bool {
822- assert ! ( !self . infcx. intercrate) ;
823- false
824- }
825-
826- fn param_env ( & self ) -> ty:: ParamEnv < ' tcx > {
827- self . param_env
828- }
829-
830- fn tag ( & self ) -> & ' static str {
831- "ConstInferUnifier"
832- }
833-
834- fn a_is_expected ( & self ) -> bool {
835- true
836- }
837-
838- fn mark_ambiguous ( & mut self ) {
839- bug ! ( )
840- }
841-
842- fn relate_with_variance < T : Relate < ' tcx > > (
843- & mut self ,
844- _variance : ty:: Variance ,
845- _info : ty:: VarianceDiagInfo < ' tcx > ,
846- a : T ,
847- b : T ,
848- ) -> RelateResult < ' tcx , T > {
849- // We don't care about variance here.
850- self . relate ( a, b)
851- }
814+ impl < ' tcx > FallibleTypeFolder < ' tcx > for ConstInferUnifier < ' _ , ' tcx > {
815+ type Error = TypeError < ' tcx > ;
852816
853- fn binders < T > (
854- & mut self ,
855- a : ty:: Binder < ' tcx , T > ,
856- b : ty:: Binder < ' tcx , T > ,
857- ) -> RelateResult < ' tcx , ty:: Binder < ' tcx , T > >
858- where
859- T : Relate < ' tcx > ,
860- {
861- Ok ( a. rebind ( self . relate ( a. skip_binder ( ) , b. skip_binder ( ) ) ?) )
817+ fn tcx < ' a > ( & ' a self ) -> TyCtxt < ' tcx > {
818+ self . infcx . tcx
862819 }
863820
864821 #[ instrument( level = "debug" , skip( self ) , ret) ]
865- fn tys ( & mut self , t : Ty < ' tcx > , _t : Ty < ' tcx > ) -> RelateResult < ' tcx , Ty < ' tcx > > {
866- debug_assert_eq ! ( t, _t) ;
867-
822+ fn try_fold_ty ( & mut self , t : Ty < ' tcx > ) -> Result < Ty < ' tcx > , TypeError < ' tcx > > {
868823 match t. kind ( ) {
869824 & ty:: Infer ( ty:: TyVar ( vid) ) => {
870825 let vid = self . infcx . inner . borrow_mut ( ) . type_variables ( ) . root_var ( vid) ;
871826 let probe = self . infcx . inner . borrow_mut ( ) . type_variables ( ) . probe ( vid) ;
872827 match probe {
873828 TypeVariableValue :: Known { value : u } => {
874829 debug ! ( "ConstOccursChecker: known value {:?}" , u) ;
875- self . tys ( u , u )
830+ u . try_fold_with ( self )
876831 }
877832 TypeVariableValue :: Unknown { universe } => {
878833 if self . for_universe . can_name ( universe) {
@@ -892,16 +847,15 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
892847 }
893848 }
894849 ty:: Infer ( ty:: IntVar ( _) | ty:: FloatVar ( _) ) => Ok ( t) ,
895- _ => relate :: super_relate_tys ( self , t , t ) ,
850+ _ => t . try_super_fold_with ( self ) ,
896851 }
897852 }
898853
899- fn regions (
854+ #[ instrument( level = "debug" , skip( self ) , ret) ]
855+ fn try_fold_region (
900856 & mut self ,
901857 r : ty:: Region < ' tcx > ,
902- _r : ty:: Region < ' tcx > ,
903- ) -> RelateResult < ' tcx , ty:: Region < ' tcx > > {
904- debug_assert_eq ! ( r, _r) ;
858+ ) -> Result < ty:: Region < ' tcx > , TypeError < ' tcx > > {
905859 debug ! ( "ConstInferUnifier: r={:?}" , r) ;
906860
907861 match * r {
@@ -930,14 +884,8 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
930884 }
931885 }
932886
933- #[ instrument( level = "debug" , skip( self ) ) ]
934- fn consts (
935- & mut self ,
936- c : ty:: Const < ' tcx > ,
937- _c : ty:: Const < ' tcx > ,
938- ) -> RelateResult < ' tcx , ty:: Const < ' tcx > > {
939- debug_assert_eq ! ( c, _c) ;
940-
887+ #[ instrument( level = "debug" , skip( self ) , ret) ]
888+ fn try_fold_const ( & mut self , c : ty:: Const < ' tcx > ) -> Result < ty:: Const < ' tcx > , TypeError < ' tcx > > {
941889 match c. kind ( ) {
942890 ty:: ConstKind :: Infer ( InferConst :: Var ( vid) ) => {
943891 // Check if the current unification would end up
@@ -958,7 +906,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
958906 let var_value =
959907 self . infcx . inner . borrow_mut ( ) . const_unification_table ( ) . probe_value ( vid) ;
960908 match var_value. val {
961- ConstVariableValue :: Known { value : u } => self . consts ( u , u ) ,
909+ ConstVariableValue :: Known { value : u } => u . try_fold_with ( self ) ,
962910 ConstVariableValue :: Unknown { universe } => {
963911 if self . for_universe . can_name ( universe) {
964912 Ok ( c)
@@ -977,17 +925,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
977925 }
978926 }
979927 }
980- ty:: ConstKind :: Unevaluated ( ty:: UnevaluatedConst { def, substs } ) => {
981- let substs = self . relate_with_variance (
982- ty:: Variance :: Invariant ,
983- ty:: VarianceDiagInfo :: default ( ) ,
984- substs,
985- substs,
986- ) ?;
987-
988- Ok ( self . tcx ( ) . mk_const ( ty:: UnevaluatedConst { def, substs } , c. ty ( ) ) )
989- }
990- _ => relate:: super_relate_consts ( self , c, c) ,
928+ _ => c. try_super_fold_with ( self ) ,
991929 }
992930 }
993931}
0 commit comments