Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do project specializable RPITIT projection #108321

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub(super) fn compare_impl_method<'tcx>(
compare_generic_param_kinds(tcx, impl_m, trait_m, false)?;
compare_number_of_method_arguments(tcx, impl_m, trait_m)?;
compare_synthetic_generics(tcx, impl_m, trait_m)?;
compare_asyncness(tcx, impl_m, trait_m)?;
compare_asyncness(tcx, impl_m, trait_m, false)?;
compare_method_predicate_entailment(
tcx,
impl_m,
Expand Down Expand Up @@ -191,6 +191,11 @@ fn compare_method_predicate_entailment<'tcx>(
.map(|(predicate, _)| predicate),
);

// Additionally, we are allowed to assume that we can project RPITITs to their
// associated hidden types within method signatures. This is to allow us to support
// specialization with `impl Trait` in traits.
hybrid_preds.predicates.extend(tcx.additional_method_assumptions(impl_m_def_id));

// Construct trait parameter environment and then shift it into the placeholder viewpoint.
// The key step here is to update the caller_bounds's predicates to be
// the new hybrid bounds we computed.
Expand Down Expand Up @@ -526,6 +531,7 @@ fn compare_asyncness<'tcx>(
tcx: TyCtxt<'tcx>,
impl_m: ty::AssocItem,
trait_m: ty::AssocItem,
delay: bool,
) -> Result<(), ErrorGuaranteed> {
if tcx.asyncness(trait_m.def_id) == hir::IsAsync::Async {
match tcx.fn_sig(impl_m.def_id).skip_binder().skip_binder().output().kind() {
Expand All @@ -536,11 +542,14 @@ fn compare_asyncness<'tcx>(
// We don't know if it's ok, but at least it's already an error.
}
_ => {
return Err(tcx.sess.emit_err(crate::errors::AsyncTraitImplShouldBeAsync {
span: tcx.def_span(impl_m.def_id),
method_name: trait_m.name,
trait_item_span: tcx.hir().span_if_local(trait_m.def_id),
}));
return Err(tcx
.sess
.create_err(crate::errors::AsyncTraitImplShouldBeAsync {
span: tcx.def_span(impl_m.def_id),
method_name: trait_m.name,
trait_item_span: tcx.hir().span_if_local(trait_m.def_id),
})
.emit_unless(delay));
}
};
}
Expand Down Expand Up @@ -590,10 +599,15 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
let trait_m = tcx.opt_associated_item(impl_m.trait_item_def_id.unwrap()).unwrap();
let impl_trait_ref =
tcx.impl_trait_ref(impl_m.impl_container(tcx).unwrap()).unwrap().subst_identity();
let param_env = tcx.param_env(def_id);

// We use the RPITIT values computed in this method to construct the param-env,
// so to avoid cycles, we do computations in this function without assuming anything
// about RPITIT projection.
let param_env = tcx.param_env_no_assumptions(def_id);

// First, check a few of the same things as `compare_impl_method`,
// just so we don't ICE during substitution later.
compare_asyncness(tcx, impl_m, trait_m, true)?;
compare_number_of_generics(tcx, impl_m, trait_m, true)?;
compare_generic_param_kinds(tcx, impl_m, trait_m, true)?;
check_region_bounds_on_impl_item(tcx, impl_m, trait_m, true)?;
Expand Down Expand Up @@ -648,6 +662,13 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
tcx.fn_sig(trait_m.def_id).subst(tcx, trait_to_placeholder_substs),
)
.fold_with(&mut collector);

debug_assert_ne!(
collector.types.len(),
0,
"expect >1 RPITITs in call to `collect_return_position_impl_trait_in_trait_tys`"
);

let trait_sig = ocx.normalize(&norm_cause, param_env, unnormalized_trait_sig);
trait_sig.error_reported()?;
let trait_return_ty = trait_sig.output();
Expand Down
30 changes: 1 addition & 29 deletions compiler/rustc_metadata/src/rmeta/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1101,34 +1101,6 @@ fn should_encode_const(def_kind: DefKind) -> bool {
}
}

fn should_encode_trait_impl_trait_tys(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
if tcx.def_kind(def_id) != DefKind::AssocFn {
return false;
}

let Some(item) = tcx.opt_associated_item(def_id) else { return false; };
if item.container != ty::AssocItemContainer::ImplContainer {
return false;
}

let Some(trait_item_def_id) = item.trait_item_def_id else { return false; };

// FIXME(RPITIT): This does a somewhat manual walk through the signature
// of the trait fn to look for any RPITITs, but that's kinda doing a lot
// of work. We can probably remove this when we refactor RPITITs to be
// associated types.
tcx.fn_sig(trait_item_def_id).subst_identity().skip_binder().output().walk().any(|arg| {
if let ty::GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Alias(ty::Projection, data) = ty.kind()
&& tcx.def_kind(data.def_id) == DefKind::ImplTraitPlaceholder
{
true
} else {
false
}
})
}

// Return `false` to avoid encoding impl trait in trait, while we don't use the query.
fn should_encode_fn_impl_trait_in_trait<'tcx>(_tcx: TyCtxt<'tcx>, _def_id: DefId) -> bool {
false
Expand Down Expand Up @@ -1211,7 +1183,7 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
if let DefKind::Enum | DefKind::Struct | DefKind::Union = def_kind {
self.encode_info_for_adt(def_id);
}
if should_encode_trait_impl_trait_tys(tcx, def_id)
if tcx.impl_method_has_trait_impl_trait_tys(def_id)
&& let Ok(table) = self.tcx.collect_return_position_impl_trait_in_trait_tys(def_id)
{
record!(self.tables.trait_impl_trait_tys[def_id] <- table);
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,14 @@ rustc_queries! {
desc { |tcx| "computing normalized predicates of `{}`", tcx.def_path_str(def_id) }
}

query param_env_no_assumptions(def_id: DefId) -> ty::ParamEnv<'tcx> {
desc { |tcx| "computing normalized predicates of `{}`", tcx.def_path_str(def_id) }
}

query additional_method_assumptions(def_id: DefId) -> &'tcx ty::List<ty::Predicate<'tcx>> {
desc { |tcx| "computing additional predicate assumptions for the body of `{}`", tcx.def_path_str(def_id) }
}

/// Like `param_env`, but returns the `ParamEnv` in `Reveal::All` mode.
/// Prefer this over `tcx.param_env(def_id).with_reveal_all_normalized(tcx)`,
/// as this method is more efficient.
Expand Down
28 changes: 28 additions & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2544,6 +2544,34 @@ impl<'tcx> TyCtxt<'tcx> {
}
def_id
}

pub fn impl_method_has_trait_impl_trait_tys(self, def_id: DefId) -> bool {
if self.def_kind(def_id) != DefKind::AssocFn {
return false;
}

let Some(item) = self.opt_associated_item(def_id) else { return false; };
if item.container != ty::AssocItemContainer::ImplContainer {
return false;
}

let Some(trait_item_def_id) = item.trait_item_def_id else { return false; };

// FIXME(RPITIT): This does a somewhat manual walk through the signature
// of the trait fn to look for any RPITITs, but that's kinda doing a lot
// of work. We can probably remove this when we refactor RPITITs to be
// associated types.
self.fn_sig(trait_item_def_id).subst_identity().skip_binder().output().walk().any(|arg| {
if let ty::GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Alias(ty::Projection, data) = ty.kind()
&& self.def_kind(data.def_id) == DefKind::ImplTraitPlaceholder
{
true
} else {
false
}
})
}
}

/// Yields the parent function's `LocalDefId` if `def_id` is an `impl Trait` definition.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/subst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ impl<'tcx> InternalSubsts<'tcx> {
target_substs: SubstsRef<'tcx>,
) -> SubstsRef<'tcx> {
let defs = tcx.generics_of(source_ancestor);
tcx.mk_substs(target_substs.iter().chain(self.iter().skip(defs.params.len())))
tcx.mk_substs(target_substs.iter().chain(self.iter().skip(defs.count())))
}

pub fn truncate_to(&self, tcx: TyCtxt<'tcx>, generics: &ty::Generics) -> SubstsRef<'tcx> {
Expand Down
21 changes: 19 additions & 2 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1246,13 +1246,13 @@ fn project<'cx, 'tcx>(

let mut candidates = ProjectionCandidateSet::None;

assemble_candidate_for_impl_trait_in_trait(selcx, obligation, &mut candidates);

// Make sure that the following procedures are kept in order. ParamEnv
// needs to be first because it has highest priority, and Select checks
// the return value of push_candidate which assumes it's ran at last.
assemble_candidates_from_param_env(selcx, obligation, &mut candidates);

assemble_candidate_for_impl_trait_in_trait(selcx, obligation, &mut candidates);

assemble_candidates_from_trait_def(selcx, obligation, &mut candidates);

assemble_candidates_from_object_ty(selcx, obligation, &mut candidates);
Expand Down Expand Up @@ -1307,6 +1307,23 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>(
let _ = selcx.infcx.commit_if_ok(|_| {
match selcx.select(&obligation.with(tcx, trait_predicate)) {
Ok(Some(super::ImplSource::UserDefined(data))) => {
let Ok(leaf_def) = specialization_graph::assoc_def(tcx, data.impl_def_id, trait_fn_def_id) else {
return Err(());
};
// Only reveal a specializable default if we're past type-checking
// and the obligation is monomorphic, otherwise passes such as
// transmute checking and polymorphic MIR optimizations could
// get a result which isn't correct for all monomorphizations.
if !leaf_def.is_final()
&& obligation.param_env.reveal() != Reveal::All
&& selcx
.infcx
.resolve_vars_if_possible(obligation.predicate.trait_ref(tcx))
.still_further_specializable()
{
return Err(());
}

candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait(data));
Ok(())
}
Expand Down
89 changes: 73 additions & 16 deletions compiler/rustc_ty_utils/src/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ fn adt_sized_constraint(tcx: TyCtxt<'_>, def_id: DefId) -> &[Ty<'_>] {
}

/// See `ParamEnv` struct definition for details.
fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
fn param_env(tcx: TyCtxt<'_>, def_id: DefId, add_assumptions: bool) -> ty::ParamEnv<'_> {
// Compute the bounds on Self and the type parameters.
let ty::InstantiatedPredicates { mut predicates, .. } =
tcx.predicates_of(def_id).instantiate_identity(tcx);
Expand All @@ -138,17 +138,8 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
predicates.extend(environment);
}

if tcx.def_kind(def_id) == DefKind::AssocFn
&& tcx.associated_item(def_id).container == ty::AssocItemContainer::TraitContainer
{
let sig = tcx.fn_sig(def_id).subst_identity();
sig.visit_with(&mut ImplTraitInTraitFinder {
tcx,
fn_def_id: def_id,
bound_vars: sig.bound_vars(),
predicates: &mut predicates,
seen: FxHashSet::default(),
});
if add_assumptions && tcx.def_kind(def_id) == DefKind::AssocFn {
predicates.extend(tcx.additional_method_assumptions(def_id))
}

let local_did = def_id.as_local();
Expand Down Expand Up @@ -237,19 +228,83 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
traits::normalize_param_env_or_error(tcx, unnormalized_env, cause)
}

fn additional_method_assumptions<'tcx>(
tcx: TyCtxt<'tcx>,
def_id: DefId,
) -> &'tcx ty::List<Predicate<'tcx>> {
let assoc_item = tcx.associated_item(def_id);
let mut predicates = vec![];

match assoc_item.container {
ty::AssocItemContainer::TraitContainer => {
let sig = tcx.fn_sig(def_id).subst_identity();
sig.visit_with(&mut ImplTraitInTraitFinder {
tcx,
fn_def_id: def_id,
bound_vars: sig.bound_vars(),
predicates: &mut predicates,
seen: FxHashSet::default(),
hidden_ty: |alias_ty| tcx.mk_alias(ty::Opaque, alias_ty),
});
}
ty::AssocItemContainer::ImplContainer => {
if tcx.impl_method_has_trait_impl_trait_tys(def_id)
&& let Ok(table)
= tcx.collect_return_position_impl_trait_in_trait_tys(def_id)
{
let impl_def_id = assoc_item.container_id(tcx);
let trait_to_impl_substs =
tcx.impl_trait_ref(impl_def_id).unwrap().subst_identity().substs;
// Create mapping from impl to placeholder.
let impl_to_placeholder_substs = ty::InternalSubsts::identity_for_item(tcx, def_id);
// Create mapping from trait to placeholder.
let trait_to_placeholder_substs =
impl_to_placeholder_substs.rebase_onto(tcx, impl_def_id, trait_to_impl_substs);

let trait_fn_def_id = assoc_item.trait_item_def_id.unwrap();
let trait_fn_sig =
tcx.fn_sig(trait_fn_def_id).subst(tcx, trait_to_placeholder_substs);
trait_fn_sig.visit_with(&mut ImplTraitInTraitFinder {
tcx,
fn_def_id: trait_fn_def_id,
bound_vars: trait_fn_sig.bound_vars(),
predicates: &mut predicates,
seen: FxHashSet::default(),
hidden_ty: |alias_ty| {
EarlyBinder(*table.get(&alias_ty.def_id).unwrap()).subst(
tcx,
alias_ty.substs.rebase_onto(
tcx,
trait_fn_def_id,
impl_to_placeholder_substs,
),
)
},
});
}
}
}

tcx.intern_predicates(&predicates)
}

/// Walk through a function type, gathering all RPITITs and installing a
/// `NormalizesTo(Projection(RPITIT) -> Opaque(RPITIT))` predicate into the
/// predicates list. This allows us to observe that an RPITIT projects to
/// its corresponding opaque within the body of a default-body trait method.
struct ImplTraitInTraitFinder<'a, 'tcx> {
struct ImplTraitInTraitFinder<'a, 'tcx, F: Fn(ty::AliasTy<'tcx>) -> Ty<'tcx>> {
tcx: TyCtxt<'tcx>,
predicates: &'a mut Vec<Predicate<'tcx>>,
fn_def_id: DefId,
bound_vars: &'tcx ty::List<ty::BoundVariableKind>,
seen: FxHashSet<DefId>,
hidden_ty: F,
}

impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx> {
impl<'tcx, F> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx, F>
where
F: Fn(ty::AliasTy<'tcx>) -> Ty<'tcx>,
{
fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
if let ty::Alias(ty::Projection, alias_ty) = *ty.kind()
&& self.tcx.def_kind(alias_ty.def_id) == DefKind::ImplTraitPlaceholder
Expand All @@ -260,7 +315,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx> {
ty::Binder::bind_with_vars(
ty::ProjectionPredicate {
projection_ty: alias_ty,
term: self.tcx.mk_alias(ty::Opaque, alias_ty).into(),
term: (self.hidden_ty)(alias_ty).into(),
},
self.bound_vars,
)
Expand Down Expand Up @@ -514,7 +569,9 @@ pub fn provide(providers: &mut ty::query::Providers) {
*providers = ty::query::Providers {
asyncness,
adt_sized_constraint,
param_env,
param_env: |tcx, def_id| param_env(tcx, def_id, true),
param_env_no_assumptions: |tcx, def_id| param_env(tcx, def_id, false),
additional_method_assumptions,
param_env_reveal_all_normalized,
instance_def_size_estimate,
issue33140_self_ty,
Expand Down
Loading