diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs index 664eef7ca5691..4b67fc84d2ed8 100644 --- a/compiler/rustc_middle/src/traits/mod.rs +++ b/compiler/rustc_middle/src/traits/mod.rs @@ -97,9 +97,7 @@ pub struct ObligationCause<'tcx> { /// information. pub body_id: hir::HirId, - /// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of - /// the time). `Some` otherwise. - code: Option>>, + code: InternedObligationCauseCode<'tcx>, } // This custom hash function speeds up hashing for `Obligation` deduplication @@ -123,11 +121,7 @@ impl<'tcx> ObligationCause<'tcx> { body_id: hir::HirId, code: ObligationCauseCode<'tcx>, ) -> ObligationCause<'tcx> { - ObligationCause { - span, - body_id, - code: if code == MISC_OBLIGATION_CAUSE_CODE { None } else { Some(Lrc::new(code)) }, - } + ObligationCause { span, body_id, code: code.into() } } pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> { @@ -136,11 +130,11 @@ impl<'tcx> ObligationCause<'tcx> { #[inline(always)] pub fn dummy() -> ObligationCause<'tcx> { - ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: None } + ObligationCause::dummy_with_span(DUMMY_SP) } pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> { - ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: None } + ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: Default::default() } } pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span { @@ -160,14 +154,14 @@ impl<'tcx> ObligationCause<'tcx> { #[inline] pub fn code(&self) -> &ObligationCauseCode<'tcx> { - self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE) + &self.code } pub fn map_code( &mut self, - f: impl FnOnce(InternedObligationCauseCode<'tcx>) -> Lrc>, + f: impl FnOnce(InternedObligationCauseCode<'tcx>) -> ObligationCauseCode<'tcx>, ) { - self.code = Some(f(InternedObligationCauseCode { code: self.code.take() })); + self.code = f(std::mem::take(&mut self.code)).into(); } pub fn derived_cause( @@ -188,10 +182,8 @@ impl<'tcx> ObligationCause<'tcx> { // NOTE(flaper87): As of now, it keeps track of the whole error // chain. Ideally, we should have a way to configure this either // by using -Z verbose or just a CLI argument. - self.code = Some( - variant(DerivedObligationCause { parent_trait_pred, parent_code: self.code.take() }) - .into(), - ); + self.code = + variant(DerivedObligationCause { parent_trait_pred, parent_code: self.code }).into(); self } } @@ -203,11 +195,19 @@ pub struct UnifyReceiverContext<'tcx> { pub substs: SubstsRef<'tcx>, } -#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift, Default)] pub struct InternedObligationCauseCode<'tcx> { + /// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of + /// the time). `Some` otherwise. code: Option>>, } +impl<'tcx> From> for InternedObligationCauseCode<'tcx> { + fn from(code: ObligationCauseCode<'tcx>) -> Self { + Self { code: if code == MISC_OBLIGATION_CAUSE_CODE { None } else { Some(Lrc::new(code)) } } + } +} + impl<'tcx> std::ops::Deref for InternedObligationCauseCode<'tcx> { type Target = ObligationCauseCode<'tcx>; @@ -454,7 +454,7 @@ impl<'tcx> ObligationCauseCode<'tcx> { BuiltinDerivedObligation(derived) | DerivedObligation(derived) | ImplDerivedObligation(box ImplDerivedObligationCause { derived, .. }) => { - Some((derived.parent_code(), Some(derived.parent_trait_pred))) + Some((&derived.parent_code, Some(derived.parent_trait_pred))) } _ => None, } @@ -508,15 +508,7 @@ pub struct DerivedObligationCause<'tcx> { pub parent_trait_pred: ty::PolyTraitPredicate<'tcx>, /// The parent trait had this cause. - parent_code: Option>>, -} - -impl<'tcx> DerivedObligationCause<'tcx> { - /// Get a reference to the derived obligation cause's parent code. - #[must_use] - pub fn parent_code(&self) -> &ObligationCauseCode<'tcx> { - self.parent_code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE) - } + pub parent_code: InternedObligationCauseCode<'tcx>, } #[derive(Clone, Debug, TypeFoldable, Lift)] diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs index 81e62f6da06e9..6082d7529c32e 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs @@ -1868,7 +1868,7 @@ impl<'a, 'tcx> InferCtxtPrivExt<'a, 'tcx> for InferCtxt<'a, 'tcx> { match code { ObligationCauseCode::BuiltinDerivedObligation(data) => { let parent_trait_ref = self.resolve_vars_if_possible(data.parent_trait_pred); - match self.get_parent_trait_ref(data.parent_code()) { + match self.get_parent_trait_ref(&data.parent_code) { Some(t) => Some(t), None => { let ty = parent_trait_ref.skip_binder().self_ty(); diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs index ee3e9544b4d60..833e232e63665 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs @@ -1683,7 +1683,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { _ => {} } - next_code = Some(cause.derived.parent_code()); + next_code = Some(&cause.derived.parent_code); } ObligationCauseCode::DerivedObligation(derived_obligation) | ObligationCauseCode::BuiltinDerivedObligation(derived_obligation) => { @@ -1715,7 +1715,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { _ => {} } - next_code = Some(derived_obligation.parent_code()); + next_code = Some(&derived_obligation.parent_code); } _ => break, } @@ -2365,7 +2365,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { let is_upvar_tys_infer_tuple = if !matches!(ty.kind(), ty::Tuple(..)) { false } else { - if let ObligationCauseCode::BuiltinDerivedObligation(data) = data.parent_code() + if let ObligationCauseCode::BuiltinDerivedObligation(data) = &*data.parent_code { let parent_trait_ref = self.resolve_vars_if_possible(data.parent_trait_pred); @@ -2392,14 +2392,14 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { obligated_types.push(ty); let parent_predicate = parent_trait_ref.to_predicate(tcx); - if !self.is_recursive_obligation(obligated_types, data.parent_code()) { + if !self.is_recursive_obligation(obligated_types, &data.parent_code) { // #74711: avoid a stack overflow ensure_sufficient_stack(|| { self.note_obligation_cause_code( err, &parent_predicate, param_env, - data.parent_code(), + &data.parent_code, obligated_types, seen_requirements, ) @@ -2410,7 +2410,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { err, &parent_predicate, param_env, - &cause_code.peel_derives(), + cause_code.peel_derives(), obligated_types, seen_requirements, ) @@ -2461,7 +2461,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { // We don't want to point at the ADT saying "required because it appears within // the type `X`", like we would otherwise do in test `supertrait-auto-trait.rs`. while let ObligationCauseCode::BuiltinDerivedObligation(derived) = - data.parent_code() + &*data.parent_code { let child_trait_ref = self.resolve_vars_if_possible(derived.parent_trait_pred); @@ -2474,7 +2474,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { parent_trait_pred = child_trait_ref; } } - while let ObligationCauseCode::ImplDerivedObligation(child) = data.parent_code() { + while let ObligationCauseCode::ImplDerivedObligation(child) = &*data.parent_code { // Skip redundant recursive obligation notes. See `ui/issue-20413.rs`. let child_trait_pred = self.resolve_vars_if_possible(child.derived.parent_trait_pred); @@ -2505,7 +2505,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { err, &parent_predicate, param_env, - data.parent_code(), + &data.parent_code, obligated_types, seen_requirements, ) @@ -2520,7 +2520,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { err, &parent_predicate, param_env, - data.parent_code(), + &data.parent_code, obligated_types, seen_requirements, ) diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs index 300b87aa46575..7c180bd164322 100644 --- a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs +++ b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs @@ -1606,9 +1606,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let mut result_code = code; loop { let parent = match code { - ObligationCauseCode::ImplDerivedObligation(c) => c.derived.parent_code(), + ObligationCauseCode::ImplDerivedObligation(c) => &c.derived.parent_code, ObligationCauseCode::BuiltinDerivedObligation(c) - | ObligationCauseCode::DerivedObligation(c) => c.parent_code(), + | ObligationCauseCode::DerivedObligation(c) => &c.parent_code, _ => break result_code, }; (result_code, code) = (code, parent); @@ -1670,7 +1670,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { call_hir_id: expr.hir_id, parent_code, } - .into() }); } else if error.obligation.cause.span == call_sp { // Make function calls point at the callee, not the whole thing.