Skip to content

Commit

Permalink
Auto merge of #72962 - lcnr:ObligationCause-lrc, r=ecstatic-morse
Browse files Browse the repository at this point in the history
store `ObligationCause` on the heap

Stores `ObligationCause` on the heap using an `Rc`.

This PR trades off some transient memory allocations to reduce the size of–and thus the number of instructions required to memcpy–a few widely used data structures in trait solving.
  • Loading branch information
bors committed Jun 16, 2020
2 parents f315c35 + ea668d9 commit c8a9c34
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/librustc_infer/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub type TraitObligation<'tcx> = Obligation<'tcx, ty::PolyTraitPredicate<'tcx>>;

// `PredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
#[cfg(target_arch = "x86_64")]
static_assert_size!(PredicateObligation<'_>, 88);
static_assert_size!(PredicateObligation<'_>, 48);

pub type Obligations<'tcx, O> = Vec<Obligation<'tcx, O>>;
pub type PredicateObligations<'tcx> = Vec<PredicateObligation<'tcx>>;
Expand Down
10 changes: 6 additions & 4 deletions src/librustc_infer/traits/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,12 @@ fn predicate_obligation<'tcx>(
predicate: ty::Predicate<'tcx>,
span: Option<Span>,
) -> PredicateObligation<'tcx> {
let mut cause = ObligationCause::dummy();
if let Some(span) = span {
cause.span = span;
}
let cause = if let Some(span) = span {
ObligationCause::dummy_with_span(span)
} else {
ObligationCause::dummy()
};

Obligation { cause, param_env: ty::ParamEnv::empty(), recursion_depth: 0, predicate }
}

Expand Down
51 changes: 46 additions & 5 deletions src/librustc_middle/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use rustc_span::{Span, DUMMY_SP};
use smallvec::SmallVec;

use std::borrow::Cow;
use std::fmt::Debug;
use std::fmt;
use std::ops::Deref;
use std::rc::Rc;

pub use self::select::{EvaluationCache, EvaluationResult, OverflowError, SelectionCache};
Expand Down Expand Up @@ -80,8 +81,39 @@ pub enum Reveal {
}

/// The reason why we incurred this obligation; used for error reporting.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
///
/// As the happy path does not care about this struct, storing this on the heap
/// ends up increasing performance.
///
/// We do not want to intern this as there are a lot of obligation causes which
/// only live for a short period of time.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ObligationCause<'tcx> {
/// `None` for `ObligationCause::dummy`, `Some` otherwise.
data: Option<Rc<ObligationCauseData<'tcx>>>,
}

const DUMMY_OBLIGATION_CAUSE_DATA: ObligationCauseData<'static> =
ObligationCauseData { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: MiscObligation };

// Correctly format `ObligationCause::dummy`.
impl<'tcx> fmt::Debug for ObligationCause<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
ObligationCauseData::fmt(self, f)
}
}

impl Deref for ObligationCause<'tcx> {
type Target = ObligationCauseData<'tcx>;

#[inline(always)]
fn deref(&self) -> &Self::Target {
self.data.as_deref().unwrap_or(&DUMMY_OBLIGATION_CAUSE_DATA)
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ObligationCauseData<'tcx> {
pub span: Span,

/// The ID of the fn body that triggered this obligation. This is
Expand All @@ -102,15 +134,24 @@ impl<'tcx> ObligationCause<'tcx> {
body_id: hir::HirId,
code: ObligationCauseCode<'tcx>,
) -> ObligationCause<'tcx> {
ObligationCause { span, body_id, code }
ObligationCause { data: Some(Rc::new(ObligationCauseData { span, body_id, code })) }
}

pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> {
ObligationCause { span, body_id, code: MiscObligation }
ObligationCause::new(span, body_id, MiscObligation)
}

pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
ObligationCause::new(span, hir::CRATE_HIR_ID, MiscObligation)
}

#[inline(always)]
pub fn dummy() -> ObligationCause<'tcx> {
ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: MiscObligation }
ObligationCause { data: None }
}

pub fn make_mut(&mut self) -> &mut ObligationCauseData<'tcx> {
Rc::make_mut(self.data.get_or_insert_with(|| Rc::new(DUMMY_OBLIGATION_CAUSE_DATA)))
}

pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
Expand Down
6 changes: 1 addition & 5 deletions src/librustc_middle/traits/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,7 @@ impl<'a, 'tcx> Lift<'tcx> for traits::DerivedObligationCause<'a> {
impl<'a, 'tcx> Lift<'tcx> for traits::ObligationCause<'a> {
type Lifted = traits::ObligationCause<'tcx>;
fn lift_to_tcx(&self, tcx: TyCtxt<'tcx>) -> Option<Self::Lifted> {
tcx.lift(&self.code).map(|code| traits::ObligationCause {
span: self.span,
body_id: self.body_id,
code,
})
tcx.lift(&self.code).map(|code| traits::ObligationCause::new(self.span, self.body_id, code))
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/librustc_mir/borrow_check/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
|infcx| {
let mut obligations = ObligationAccumulator::default();

let dummy_body_id = ObligationCause::dummy().body_id;
let dummy_body_id = hir::CRATE_HIR_ID;
let (output_ty, opaque_type_map) =
obligations.add(infcx.instantiate_opaque_types(
anon_owner_def_id,
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_trait_selection/traits/fulfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ pub struct PendingPredicateObligation<'tcx> {

// `PendingPredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
#[cfg(target_arch = "x86_64")]
static_assert_size!(PendingPredicateObligation<'_>, 112);
static_assert_size!(PendingPredicateObligation<'_>, 72);

impl<'a, 'tcx> FulfillmentContext<'tcx> {
/// Creates a new fulfillment context.
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_trait_selection/traits/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub fn can_type_implement_copy(
continue;
}
let span = tcx.def_span(field.did);
let cause = ObligationCause { span, ..ObligationCause::dummy() };
let cause = ObligationCause::dummy_with_span(span);
let ctx = traits::FulfillmentContext::new();
match traits::fully_normalize(&infcx, ctx, cause, param_env, &ty) {
Ok(ty) => {
Expand Down
7 changes: 4 additions & 3 deletions src/librustc_trait_selection/traits/wf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ fn extend_cause_with_original_assoc_item_obligation<'tcx>(
if let Some(impl_item_span) =
items.iter().find(|item| item.ident == trait_assoc_item.ident).map(fix_span)
{
cause.span = impl_item_span;
cause.make_mut().span = impl_item_span;
}
}
}
Expand All @@ -222,7 +222,7 @@ fn extend_cause_with_original_assoc_item_obligation<'tcx>(
items.iter().find(|i| i.ident == trait_assoc_item.ident).map(fix_span)
})
{
cause.span = impl_item_span;
cause.make_mut().span = impl_item_span;
}
}
}
Expand Down Expand Up @@ -273,7 +273,8 @@ impl<'a, 'tcx> WfPredicates<'a, 'tcx> {
parent_trait_ref,
parent_code: Rc::new(obligation.cause.code.clone()),
};
cause.code = traits::ObligationCauseCode::DerivedObligation(derived_cause);
cause.make_mut().code =
traits::ObligationCauseCode::DerivedObligation(derived_cause);
}
extend_cause_with_original_assoc_item_obligation(
tcx,
Expand Down
32 changes: 18 additions & 14 deletions src/librustc_typeck/check/compare_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,16 @@ fn compare_predicate_entailment<'tcx>(
// `regionck_item` expects.
let impl_m_hir_id = tcx.hir().as_local_hir_id(impl_m.def_id.expect_local());

let cause = ObligationCause {
span: impl_m_span,
body_id: impl_m_hir_id,
code: ObligationCauseCode::CompareImplMethodObligation {
// We sometimes modify the span further down.
let mut cause = ObligationCause::new(
impl_m_span,
impl_m_hir_id,
ObligationCauseCode::CompareImplMethodObligation {
item_name: impl_m.ident.name,
impl_item_def_id: impl_m.def_id,
trait_item_def_id: trait_m.def_id,
},
};
);

// This code is best explained by example. Consider a trait:
//
Expand Down Expand Up @@ -280,7 +281,7 @@ fn compare_predicate_entailment<'tcx>(
&infcx, param_env, &terr, &cause, impl_m, impl_sig, trait_m, trait_sig,
);

let cause = ObligationCause { span: impl_err_span, ..cause };
cause.make_mut().span = impl_err_span;

let mut diag = struct_span_err!(
tcx.sess,
Expand Down Expand Up @@ -965,8 +966,11 @@ crate fn compare_const_impl<'tcx>(
// Compute placeholder form of impl and trait const tys.
let impl_ty = tcx.type_of(impl_c.def_id);
let trait_ty = tcx.type_of(trait_c.def_id).subst(tcx, trait_to_impl_substs);
let mut cause = ObligationCause::misc(impl_c_span, impl_c_hir_id);
cause.code = ObligationCauseCode::CompareImplConstObligation;
let mut cause = ObligationCause::new(
impl_c_span,
impl_c_hir_id,
ObligationCauseCode::CompareImplConstObligation,
);

// There is no "body" here, so just pass dummy id.
let impl_ty =
Expand All @@ -992,7 +996,7 @@ crate fn compare_const_impl<'tcx>(

// Locate the Span containing just the type of the offending impl
match tcx.hir().expect_impl_item(impl_c_hir_id).kind {
ImplItemKind::Const(ref ty, _) => cause.span = ty.span,
ImplItemKind::Const(ref ty, _) => cause.make_mut().span = ty.span,
_ => bug!("{:?} is not a impl const", impl_c),
}

Expand Down Expand Up @@ -1095,15 +1099,15 @@ fn compare_type_predicate_entailment(
// `ObligationCause` (and the `FnCtxt`). This is what
// `regionck_item` expects.
let impl_ty_hir_id = tcx.hir().as_local_hir_id(impl_ty.def_id.expect_local());
let cause = ObligationCause {
span: impl_ty_span,
body_id: impl_ty_hir_id,
code: ObligationCauseCode::CompareImplTypeObligation {
let cause = ObligationCause::new(
impl_ty_span,
impl_ty_hir_id,
ObligationCauseCode::CompareImplTypeObligation {
item_name: impl_ty.ident.name,
impl_item_def_id: impl_ty.def_id,
trait_item_def_id: trait_ty.def_id,
},
};
);

debug!("compare_type_predicate_entailment: trait_to_impl_substs={:?}", trait_to_impl_substs);

Expand Down
6 changes: 3 additions & 3 deletions src/librustc_typeck/check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4218,7 +4218,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
if let (Some(ref_in), None) = (referenced_in.pop(), referenced_in.pop()) {
// We make sure that only *one* argument matches the obligation failure
// and we assign the obligation's span to its expression's.
error.obligation.cause.span = args[ref_in].span;
error.obligation.cause.make_mut().span = args[ref_in].span;
error.points_at_arg_span = true;
}
}
Expand Down Expand Up @@ -4261,7 +4261,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let ty = AstConv::ast_ty_to_ty(self, hir_ty);
let ty = self.resolve_vars_if_possible(&ty);
if ty == predicate.skip_binder().self_ty() {
error.obligation.cause.span = hir_ty.span;
error.obligation.cause.make_mut().span = hir_ty.span;
}
}
}
Expand Down Expand Up @@ -5689,7 +5689,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
{
// This makes the error point at the bound, but we want to point at the argument
if let Some(span) = spans.get(i) {
obligation.cause.code = traits::BindingObligation(def_id, *span);
obligation.cause.make_mut().code = traits::BindingObligation(def_id, *span);
}
self.register_predicate(obligation);
}
Expand Down

0 comments on commit c8a9c34

Please sign in to comment.