Skip to content

Rework some predicates_of/{Generic,Instantiated}Predicates code #106395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 21 additions & 27 deletions compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,40 +673,34 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
let tcx = self.infcx.tcx;

// Find out if the predicates show that the type is a Fn or FnMut
let find_fn_kind_from_did = |predicates: ty::EarlyBinder<
&[(ty::Predicate<'tcx>, Span)],
>,
substs| {
predicates.0.iter().find_map(|(pred, _)| {
let pred = if let Some(substs) = substs {
predicates.rebind(*pred).subst(tcx, substs).kind().skip_binder()
} else {
pred.kind().skip_binder()
};
if let ty::PredicateKind::Clause(ty::Clause::Trait(pred)) = pred && pred.self_ty() == ty {
if Some(pred.def_id()) == tcx.lang_items().fn_trait() {
return Some(hir::Mutability::Not);
} else if Some(pred.def_id()) == tcx.lang_items().fn_mut_trait() {
return Some(hir::Mutability::Mut);
}
let find_fn_kind_from_did = |(pred, _): (ty::Predicate<'tcx>, _)| {
if let ty::PredicateKind::Clause(ty::Clause::Trait(pred)) = pred.kind().skip_binder()
&& pred.self_ty() == ty
{
if Some(pred.def_id()) == tcx.lang_items().fn_trait() {
return Some(hir::Mutability::Not);
} else if Some(pred.def_id()) == tcx.lang_items().fn_mut_trait() {
return Some(hir::Mutability::Mut);
}
None
})
}
None
};

// If the type is opaque/param/closure, and it is Fn or FnMut, let's suggest (mutably)
// borrowing the type, since `&mut F: FnMut` iff `F: FnMut` and similarly for `Fn`.
// These types seem reasonably opaque enough that they could be substituted with their
// borrowed variants in a function body when we see a move error.
let borrow_level = match ty.kind() {
ty::Param(_) => find_fn_kind_from_did(
tcx.bound_explicit_predicates_of(self.mir_def_id().to_def_id())
.map_bound(|p| p.predicates),
None,
),
ty::Alias(ty::Opaque, ty::AliasTy { def_id, substs, .. }) => {
find_fn_kind_from_did(tcx.bound_explicit_item_bounds(*def_id), Some(*substs))
}
let borrow_level = match *ty.kind() {
ty::Param(_) => tcx
.explicit_predicates_of(self.mir_def_id().to_def_id())
.predicates
.iter()
.copied()
.find_map(find_fn_kind_from_did),
ty::Alias(ty::Opaque, ty::AliasTy { def_id, substs, .. }) => tcx
.bound_explicit_item_bounds(def_id)
.subst_iter_copied(tcx, substs)
.find_map(find_fn_kind_from_did),
ty::Closure(_, substs) => match substs.as_closure().kind() {
ty::ClosureKind::Fn => Some(hir::Mutability::Not),
ty::ClosureKind::FnMut => Some(hir::Mutability::Mut),
Expand Down
6 changes: 1 addition & 5 deletions compiler/rustc_borrowck/src/type_check/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
instantiated_predicates: ty::InstantiatedPredicates<'tcx>,
locations: Locations,
) {
for (predicate, span) in instantiated_predicates
.predicates
.into_iter()
.zip(instantiated_predicates.spans.into_iter())
{
for (predicate, span) in instantiated_predicates {
debug!(?predicate);
let category = ConstraintCategory::Predicate(span);
let predicate = self.normalize_with_category(predicate, locations, category);
Expand Down
25 changes: 13 additions & 12 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,11 @@ fn compare_method_predicate_entailment<'tcx>(
//
// We then register the obligations from the impl_m and check to see
// if all constraints hold.
hybrid_preds
.predicates
.extend(trait_m_predicates.instantiate_own(tcx, trait_to_placeholder_substs).predicates);
hybrid_preds.predicates.extend(
trait_m_predicates
.instantiate_own(tcx, trait_to_placeholder_substs)
.map(|(predicate, _)| predicate),
);

// Construct trait parameter environment and then shift it into the placeholder viewpoint.
// The key step here is to update the caller_bounds's predicates to be
Expand All @@ -230,7 +232,7 @@ fn compare_method_predicate_entailment<'tcx>(
debug!("compare_impl_method: caller_bounds={:?}", param_env.caller_bounds());

let impl_m_own_bounds = impl_m_predicates.instantiate_own(tcx, impl_to_placeholder_substs);
for (predicate, span) in iter::zip(impl_m_own_bounds.predicates, impl_m_own_bounds.spans) {
for (predicate, span) in impl_m_own_bounds {
let normalize_cause = traits::ObligationCause::misc(span, impl_m_hir_id);
let predicate = ocx.normalize(&normalize_cause, param_env, predicate);

Expand Down Expand Up @@ -1828,8 +1830,7 @@ fn compare_type_predicate_entailment<'tcx>(
check_region_bounds_on_impl_item(tcx, impl_ty, trait_ty, false)?;

let impl_ty_own_bounds = impl_ty_predicates.instantiate_own(tcx, impl_substs);

if impl_ty_own_bounds.is_empty() {
if impl_ty_own_bounds.len() == 0 {
// Nothing to check.
return Ok(());
}
Expand All @@ -1844,9 +1845,11 @@ fn compare_type_predicate_entailment<'tcx>(
// associated type in the trait are assumed.
let impl_predicates = tcx.predicates_of(impl_ty_predicates.parent.unwrap());
let mut hybrid_preds = impl_predicates.instantiate_identity(tcx);
hybrid_preds
.predicates
.extend(trait_ty_predicates.instantiate_own(tcx, trait_to_impl_substs).predicates);
hybrid_preds.predicates.extend(
trait_ty_predicates
.instantiate_own(tcx, trait_to_impl_substs)
.map(|(predicate, _)| predicate),
);

debug!("compare_type_predicate_entailment: bounds={:?}", hybrid_preds);

Expand All @@ -1862,9 +1865,7 @@ fn compare_type_predicate_entailment<'tcx>(

debug!("compare_type_predicate_entailment: caller_bounds={:?}", param_env.caller_bounds());

assert_eq!(impl_ty_own_bounds.predicates.len(), impl_ty_own_bounds.spans.len());
for (span, predicate) in std::iter::zip(impl_ty_own_bounds.spans, impl_ty_own_bounds.predicates)
{
for (predicate, span) in impl_ty_own_bounds {
let cause = ObligationCause::misc(span, impl_ty_hir_id);
let predicate = ocx.normalize(&cause, param_env, predicate);

Expand Down
29 changes: 13 additions & 16 deletions compiler/rustc_hir_analysis/src/check/wfcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use rustc_trait_selection::traits::{
};

use std::cell::LazyCell;
use std::iter;
use std::ops::{ControlFlow, Deref};

pub(super) struct WfCheckingCtxt<'a, 'tcx> {
Expand Down Expand Up @@ -1310,7 +1309,7 @@ fn check_where_clauses<'tcx>(wfcx: &WfCheckingCtxt<'_, 'tcx>, span: Span, def_id
let infcx = wfcx.infcx;
let tcx = wfcx.tcx();

let predicates = tcx.bound_predicates_of(def_id.to_def_id());
let predicates = tcx.predicates_of(def_id.to_def_id());
let generics = tcx.generics_of(def_id);

let is_our_default = |def: &ty::GenericParamDef| match def.kind {
Expand Down Expand Up @@ -1411,7 +1410,6 @@ fn check_where_clauses<'tcx>(wfcx: &WfCheckingCtxt<'_, 'tcx>, span: Span, def_id

// Now we build the substituted predicates.
let default_obligations = predicates
.0
.predicates
.iter()
.flat_map(|&(pred, sp)| {
Expand Down Expand Up @@ -1442,13 +1440,13 @@ fn check_where_clauses<'tcx>(wfcx: &WfCheckingCtxt<'_, 'tcx>, span: Span, def_id
}
let mut param_count = CountParams::default();
let has_region = pred.visit_with(&mut param_count).is_break();
let substituted_pred = predicates.rebind(pred).subst(tcx, substs);
let substituted_pred = ty::EarlyBinder(pred).subst(tcx, substs);
// Don't check non-defaulted params, dependent defaults (including lifetimes)
// or preds with multiple params.
if substituted_pred.has_non_region_param() || param_count.params.len() > 1 || has_region
{
None
} else if predicates.0.predicates.iter().any(|&(p, _)| p == substituted_pred) {
} else if predicates.predicates.iter().any(|&(p, _)| p == substituted_pred) {
// Avoid duplication of predicates that contain no parameters, for example.
None
} else {
Expand All @@ -1474,22 +1472,21 @@ fn check_where_clauses<'tcx>(wfcx: &WfCheckingCtxt<'_, 'tcx>, span: Span, def_id
traits::Obligation::new(tcx, cause, wfcx.param_env, pred)
});

let predicates = predicates.0.instantiate_identity(tcx);
let predicates = predicates.instantiate_identity(tcx);

let predicates = wfcx.normalize(span, None, predicates);

debug!(?predicates.predicates);
assert_eq!(predicates.predicates.len(), predicates.spans.len());
let wf_obligations =
iter::zip(&predicates.predicates, &predicates.spans).flat_map(|(&p, &sp)| {
traits::wf::predicate_obligations(
infcx,
wfcx.param_env.without_const(),
wfcx.body_id,
p,
sp,
)
});
let wf_obligations = predicates.into_iter().flat_map(|(p, sp)| {
traits::wf::predicate_obligations(
infcx,
wfcx.param_env.without_const(),
wfcx.body_id,
p,
sp,
)
});

let obligations: Vec<_> = wf_obligations.chain(default_obligations).collect();
wfcx.register_obligations(obligations);
Expand Down
8 changes: 3 additions & 5 deletions compiler/rustc_hir_typeck/src/callee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
if self.tcx.has_attr(def_id, sym::rustc_evaluate_where_clauses) {
let predicates = self.tcx.predicates_of(def_id);
let predicates = predicates.instantiate(self.tcx, subst);
for (predicate, predicate_span) in
predicates.predicates.iter().zip(&predicates.spans)
{
for (predicate, predicate_span) in predicates {
let obligation = Obligation::new(
self.tcx,
ObligationCause::dummy_with_span(callee_expr.span),
self.param_env,
*predicate,
predicate,
);
let result = self.evaluate_obligation(&obligation);
self.tcx
Expand All @@ -391,7 +389,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
callee_expr.span,
&format!("evaluate({:?}) = {:?}", predicate, result),
)
.span_label(*predicate_span, "predicate")
.span_label(predicate_span, "predicate")
.emit();
}
}
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2140,8 +2140,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// FIXME(compiler-errors): This could be problematic if something has two
// fn-like predicates with different args, but callable types really never
// do that, so it's OK.
for (predicate, span) in
std::iter::zip(instantiated.predicates, instantiated.spans)
for (predicate, span) in instantiated
{
if let ty::PredicateKind::Clause(ty::Clause::Trait(pred)) = predicate.kind().skip_binder()
&& pred.self_ty().peel_refs() == callee_ty
Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_hir_typeck/src/method/confirm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use rustc_middle::ty::{InternalSubsts, UserSubsts, UserType};
use rustc_span::{Span, DUMMY_SP};
use rustc_trait_selection::traits;

use std::iter;
use std::ops::Deref;

struct ConfirmContext<'a, 'tcx> {
Expand Down Expand Up @@ -101,7 +100,7 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {
let filler_substs = rcvr_substs
.extend_to(self.tcx, pick.item.def_id, |def, _| self.tcx.mk_param_from_def(def));
let illegal_sized_bound = self.predicates_require_illegal_sized_bound(
&self.tcx.predicates_of(pick.item.def_id).instantiate(self.tcx, filler_substs),
self.tcx.predicates_of(pick.item.def_id).instantiate(self.tcx, filler_substs),
);

// Unify the (adjusted) self type with what the method expects.
Expand Down Expand Up @@ -565,7 +564,7 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {

fn predicates_require_illegal_sized_bound(
&self,
predicates: &ty::InstantiatedPredicates<'tcx>,
predicates: ty::InstantiatedPredicates<'tcx>,
) -> Option<Span> {
let sized_def_id = self.tcx.lang_items().sized_trait()?;

Expand All @@ -575,10 +574,11 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {
ty::PredicateKind::Clause(ty::Clause::Trait(trait_pred))
if trait_pred.def_id() == sized_def_id =>
{
let span = iter::zip(&predicates.predicates, &predicates.spans)
let span = predicates
.iter()
.find_map(
|(p, span)| {
if *p == obligation.predicate { Some(*span) } else { None }
if p == obligation.predicate { Some(span) } else { None }
},
)
.unwrap_or(rustc_span::DUMMY_SP);
Expand Down
5 changes: 2 additions & 3 deletions compiler/rustc_infer/src/infer/error_reporting/note.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,8 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {

let Ok(trait_predicates) = self
.tcx
.bound_explicit_predicates_of(trait_item_def_id)
.map_bound(|p| p.predicates)
.subst_iter_copied(self.tcx, trait_item_substs)
.explicit_predicates_of(trait_item_def_id)
.instantiate_own(self.tcx, trait_item_substs)
.map(|(pred, _)| {
if pred.is_suggestable(self.tcx, false) {
Ok(pred.to_string())
Expand Down
12 changes: 3 additions & 9 deletions compiler/rustc_middle/src/ty/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,9 @@ impl<'tcx> GenericPredicates<'tcx> {
&self,
tcx: TyCtxt<'tcx>,
substs: SubstsRef<'tcx>,
) -> InstantiatedPredicates<'tcx> {
InstantiatedPredicates {
predicates: self
.predicates
.iter()
.map(|(p, _)| EarlyBinder(*p).subst(tcx, substs))
.collect(),
spans: self.predicates.iter().map(|(_, sp)| *sp).collect(),
}
) -> impl Iterator<Item = (Predicate<'tcx>, Span)> + DoubleEndedIterator + ExactSizeIterator
{
EarlyBinder(self.predicates).subst_iter_copied(tcx, substs)
}

#[instrument(level = "debug", skip(self, tcx))]
Expand Down
29 changes: 29 additions & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,35 @@ impl<'tcx> InstantiatedPredicates<'tcx> {
pub fn is_empty(&self) -> bool {
self.predicates.is_empty()
}

pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
(&self).into_iter()
}
}

impl<'tcx> IntoIterator for InstantiatedPredicates<'tcx> {
type Item = (Predicate<'tcx>, Span);

type IntoIter = std::iter::Zip<std::vec::IntoIter<Predicate<'tcx>>, std::vec::IntoIter<Span>>;

fn into_iter(self) -> Self::IntoIter {
debug_assert_eq!(self.predicates.len(), self.spans.len());
std::iter::zip(self.predicates, self.spans)
}
}

impl<'a, 'tcx> IntoIterator for &'a InstantiatedPredicates<'tcx> {
type Item = (Predicate<'tcx>, Span);

type IntoIter = std::iter::Zip<
std::iter::Copied<std::slice::Iter<'a, Predicate<'tcx>>>,
std::iter::Copied<std::slice::Iter<'a, Span>>,
>;

fn into_iter(self) -> Self::IntoIter {
debug_assert_eq!(self.predicates.len(), self.spans.len());
std::iter::zip(self.predicates.iter().copied(), self.spans.iter().copied())
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HashStable, TyEncodable, TyDecodable, Lift)]
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_middle/src/ty/subst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,13 @@ where
}
}

impl<'tcx, I: IntoIterator> ExactSizeIterator for SubstIter<'_, 'tcx, I>
where
I::IntoIter: ExactSizeIterator,
I::Item: TypeFoldable<'tcx>,
{
}

impl<'tcx, 's, I: IntoIterator> EarlyBinder<I>
where
I::Item: Deref,
Expand Down Expand Up @@ -686,6 +693,14 @@ where
}
}

impl<'tcx, I: IntoIterator> ExactSizeIterator for SubstIterCopied<'_, 'tcx, I>
where
I::IntoIter: ExactSizeIterator,
I::Item: Deref,
<I::Item as Deref>::Target: Copy + TypeFoldable<'tcx>,
{
}

pub struct EarlyBinderIter<T> {
t: T,
}
Expand Down
Loading