Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement syntax for impl Trait to specify its captures explicitly (feature(precise_capturing)) #123468

Merged
merged 11 commits into from
Apr 16, 2024
12 changes: 10 additions & 2 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl fmt::Debug for Label {

/// A "Lifetime" is an annotation of the scope in which variable
/// can be used, e.g. `'a` in `&'a i32`.
#[derive(Clone, Encodable, Decodable, Copy, PartialEq, Eq)]
#[derive(Clone, Encodable, Decodable, Copy, PartialEq, Eq, Hash)]
pub struct Lifetime {
pub id: NodeId,
pub ident: Ident,
Expand Down Expand Up @@ -2132,7 +2132,7 @@ pub enum TyKind {
/// The `NodeId` exists to prevent lowering from having to
/// generate `NodeId`s on the fly, which would complicate
/// the generation of opaque `type Foo = impl Trait` items significantly.
ImplTrait(NodeId, GenericBounds),
ImplTrait(NodeId, GenericBounds, Option<P<(ThinVec<PreciseCapturingArg>, Span)>>),
/// No-op; kept solely so that we can pretty-print faithfully.
Paren(P<Ty>),
/// Unused for now.
Expand Down Expand Up @@ -2188,6 +2188,14 @@ pub enum TraitObjectSyntax {
None,
}

#[derive(Clone, Encodable, Decodable, Debug)]
pub enum PreciseCapturingArg {
/// Lifetime parameter
Lifetime(Lifetime),
/// Type or const parameter
Arg(Path, NodeId),
}

/// Inline assembly operand explicit register or register class.
///
/// E.g., `"eax"` as in `asm!("mov eax, 2", out("eax") result)`.
Expand Down
23 changes: 22 additions & 1 deletion compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ pub trait MutVisitor: Sized {
noop_visit_param_bound(tpb, self);
}

fn visit_precise_capturing_arg(&mut self, arg: &mut PreciseCapturingArg) {
noop_visit_precise_capturing_arg(arg, self);
}

fn visit_mt(&mut self, mt: &mut MutTy) {
noop_visit_mt(mt, self);
}
Expand Down Expand Up @@ -518,9 +522,14 @@ pub fn noop_visit_ty<T: MutVisitor>(ty: &mut P<Ty>, vis: &mut T) {
TyKind::TraitObject(bounds, _syntax) => {
visit_vec(bounds, |bound| vis.visit_param_bound(bound))
}
TyKind::ImplTrait(id, bounds) => {
TyKind::ImplTrait(id, bounds, precise_capturing) => {
vis.visit_id(id);
visit_vec(bounds, |bound| vis.visit_param_bound(bound));
if let Some((precise_capturing, _span)) = precise_capturing.as_deref_mut() {
for arg in precise_capturing {
vis.visit_precise_capturing_arg(arg);
}
}
}
TyKind::MacCall(mac) => vis.visit_mac_call(mac),
TyKind::AnonStruct(id, fields) | TyKind::AnonUnion(id, fields) => {
Expand Down Expand Up @@ -914,6 +923,18 @@ pub fn noop_visit_param_bound<T: MutVisitor>(pb: &mut GenericBound, vis: &mut T)
}
}

pub fn noop_visit_precise_capturing_arg<T: MutVisitor>(arg: &mut PreciseCapturingArg, vis: &mut T) {
match arg {
PreciseCapturingArg::Lifetime(lt) => {
vis.visit_lifetime(lt);
}
PreciseCapturingArg::Arg(path, id) => {
vis.visit_path(path);
vis.visit_id(id);
}
}
}

pub fn noop_flat_map_generic_param<T: MutVisitor>(
mut param: GenericParam,
vis: &mut T,
Expand Down
24 changes: 23 additions & 1 deletion compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ pub trait Visitor<'ast>: Sized {
fn visit_param_bound(&mut self, bounds: &'ast GenericBound, _ctxt: BoundKind) -> Self::Result {
walk_param_bound(self, bounds)
}
fn visit_precise_capturing_arg(&mut self, arg: &'ast PreciseCapturingArg) {
walk_precise_capturing_arg(self, arg);
}
fn visit_poly_trait_ref(&mut self, t: &'ast PolyTraitRef) -> Self::Result {
walk_poly_trait_ref(self, t)
}
Expand Down Expand Up @@ -457,8 +460,13 @@ pub fn walk_ty<'a, V: Visitor<'a>>(visitor: &mut V, typ: &'a Ty) -> V::Result {
TyKind::TraitObject(bounds, ..) => {
walk_list!(visitor, visit_param_bound, bounds, BoundKind::TraitObject);
}
TyKind::ImplTrait(_, bounds) => {
TyKind::ImplTrait(_, bounds, precise_capturing) => {
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Impl);
if let Some((precise_capturing, _span)) = precise_capturing.as_deref() {
for arg in precise_capturing {
try_visit!(visitor.visit_precise_capturing_arg(arg));
}
}
}
TyKind::Typeof(expression) => try_visit!(visitor.visit_anon_const(expression)),
TyKind::Infer | TyKind::ImplicitSelf | TyKind::Dummy | TyKind::Err(_) => {}
Expand Down Expand Up @@ -637,6 +645,20 @@ pub fn walk_param_bound<'a, V: Visitor<'a>>(visitor: &mut V, bound: &'a GenericB
}
}

pub fn walk_precise_capturing_arg<'a, V: Visitor<'a>>(
visitor: &mut V,
arg: &'a PreciseCapturingArg,
) {
match arg {
PreciseCapturingArg::Lifetime(lt) => {
visitor.visit_lifetime(lt, LifetimeCtxt::GenericArg);
}
PreciseCapturingArg::Arg(path, id) => {
visitor.visit_path(path, *id);
}
}
}

pub fn walk_generic_param<'a, V: Visitor<'a>>(
visitor: &mut V,
param: &'a GenericParam,
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_ast_lowering/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ ast_lowering_never_pattern_with_guard =
a guard on a never pattern will never be run
.suggestion = remove this guard

ast_lowering_no_precise_captures_on_apit = `use<...>` precise capturing syntax not allowed on argument-position `impl Trait`
compiler-errors marked this conversation as resolved.
Show resolved Hide resolved

ast_lowering_previously_used_here = previously used here

ast_lowering_register1 = register `{$reg1_name}`
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_ast_lowering/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,3 +414,10 @@ pub(crate) struct AsyncBoundOnlyForFnTraits {
#[primary_span]
pub span: Span,
}

#[derive(Diagnostic)]
#[diag(ast_lowering_no_precise_captures_on_apit)]
pub(crate) struct NoPreciseCapturesOnApit {
#[primary_span]
pub span: Span,
}
17 changes: 17 additions & 0 deletions compiler/rustc_ast_lowering/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,4 +385,21 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
fn visit_pattern_type_pattern(&mut self, p: &'hir hir::Pat<'hir>) {
self.visit_pat(p)
}

fn visit_precise_capturing_arg(
&mut self,
arg: &'hir PreciseCapturingArg<'hir>,
) -> Self::Result {
match arg {
PreciseCapturingArg::Lifetime(_) => {
// This is represented as a `Node::Lifetime`, intravisit will get to it below.
}
PreciseCapturingArg::Param(param) => self.insert(
param.ident.span,
param.hir_id,
Node::PreciseCapturingNonLifetimeArg(param),
),
}
intravisit::walk_precise_capturing_arg(self, arg);
}
}
136 changes: 97 additions & 39 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use rustc_ast::{self as ast, *};
use rustc_ast_pretty::pprust;
use rustc_data_structures::captures::Captures;
use rustc_data_structures::fingerprint::Fingerprint;
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::sorted_map::SortedMap;
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
use rustc_data_structures::sync::Lrc;
Expand Down Expand Up @@ -1398,7 +1399,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
});
hir::TyKind::TraitObject(bounds, lifetime_bound, *kind)
}
TyKind::ImplTrait(def_node_id, bounds) => {
TyKind::ImplTrait(def_node_id, bounds, precise_capturing) => {
let span = t.span;
match itctx {
ImplTraitContext::OpaqueTy { origin, fn_kind } => self.lower_opaque_impl_trait(
Expand All @@ -1408,8 +1409,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
bounds,
fn_kind,
itctx,
precise_capturing.as_deref().map(|(args, _)| args.as_slice()),
),
ImplTraitContext::Universal => {
if let Some(&(_, span)) = precise_capturing.as_deref() {
self.tcx.dcx().emit_err(errors::NoPreciseCapturesOnApit { span });
};
let span = t.span;

// HACK: pprust breaks strings with newlines when the type
Expand Down Expand Up @@ -1520,6 +1525,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
bounds: &GenericBounds,
fn_kind: Option<FnDeclKind>,
itctx: ImplTraitContext,
precise_capturing_args: Option<&[PreciseCapturingArg]>,
) -> hir::TyKind<'hir> {
// Make sure we know that some funky desugaring has been going on here.
// This is a first: there is code in other places like for loop
Expand All @@ -1528,42 +1534,59 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// frequently opened issues show.
let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::OpaqueTy, span, None);

let captured_lifetimes_to_duplicate = match origin {
hir::OpaqueTyOrigin::TyAlias { .. } => {
// type alias impl trait and associated type position impl trait were
// decided to capture all in-scope lifetimes, which we collect for
// all opaques during resolution.
self.resolver
.take_extra_lifetime_params(opaque_ty_node_id)
.into_iter()
.map(|(ident, id, _)| Lifetime { id, ident })
let captured_lifetimes_to_duplicate =
if let Some(precise_capturing) = precise_capturing_args {
// We'll actually validate these later on; all we need is the list of
// lifetimes to duplicate during this portion of lowering.
precise_capturing
.iter()
.filter_map(|arg| match arg {
PreciseCapturingArg::Lifetime(lt) => Some(*lt),
PreciseCapturingArg::Arg(..) => None,
})
// Add in all the lifetimes mentioned in the bounds. We will error
// them out later, but capturing them here is important to make sure
// they actually get resolved in resolve_bound_vars.
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
.chain(lifetime_collector::lifetimes_in_bounds(self.resolver, bounds))
.collect()
}
hir::OpaqueTyOrigin::FnReturn(..) => {
if matches!(
fn_kind.expect("expected RPITs to be lowered with a FnKind"),
FnDeclKind::Impl | FnDeclKind::Trait
) || self.tcx.features().lifetime_capture_rules_2024
|| span.at_least_rust_2024()
{
// return-position impl trait in trait was decided to capture all
// in-scope lifetimes, which we collect for all opaques during resolution.
self.resolver
.take_extra_lifetime_params(opaque_ty_node_id)
.into_iter()
.map(|(ident, id, _)| Lifetime { id, ident })
.collect()
} else {
// in fn return position, like the `fn test<'a>() -> impl Debug + 'a`
// example, we only need to duplicate lifetimes that appear in the
// bounds, since those are the only ones that are captured by the opaque.
lifetime_collector::lifetimes_in_bounds(self.resolver, bounds)
} else {
match origin {
hir::OpaqueTyOrigin::TyAlias { .. } => {
// type alias impl trait and associated type position impl trait were
// decided to capture all in-scope lifetimes, which we collect for
// all opaques during resolution.
self.resolver
.take_extra_lifetime_params(opaque_ty_node_id)
.into_iter()
.map(|(ident, id, _)| Lifetime { id, ident })
.collect()
}
hir::OpaqueTyOrigin::FnReturn(..) => {
if matches!(
fn_kind.expect("expected RPITs to be lowered with a FnKind"),
FnDeclKind::Impl | FnDeclKind::Trait
) || self.tcx.features().lifetime_capture_rules_2024
|| span.at_least_rust_2024()
{
// return-position impl trait in trait was decided to capture all
// in-scope lifetimes, which we collect for all opaques during resolution.
self.resolver
.take_extra_lifetime_params(opaque_ty_node_id)
.into_iter()
.map(|(ident, id, _)| Lifetime { id, ident })
.collect()
} else {
// in fn return position, like the `fn test<'a>() -> impl Debug + 'a`
// example, we only need to duplicate lifetimes that appear in the
// bounds, since those are the only ones that are captured by the opaque.
lifetime_collector::lifetimes_in_bounds(self.resolver, bounds)
}
}
hir::OpaqueTyOrigin::AsyncFn(..) => {
unreachable!("should be using `lower_async_fn_ret_ty`")
}
}
}
hir::OpaqueTyOrigin::AsyncFn(..) => {
unreachable!("should be using `lower_async_fn_ret_ty`")
}
};
};
debug!(?captured_lifetimes_to_duplicate);

self.lower_opaque_inner(
Expand All @@ -1573,6 +1596,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
captured_lifetimes_to_duplicate,
span,
opaque_ty_span,
precise_capturing_args,
|this| this.lower_param_bounds(bounds, itctx),
)
}
Expand All @@ -1582,9 +1606,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
opaque_ty_node_id: NodeId,
origin: hir::OpaqueTyOrigin,
in_trait: bool,
captured_lifetimes_to_duplicate: Vec<Lifetime>,
captured_lifetimes_to_duplicate: FxIndexSet<Lifetime>,
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
span: Span,
opaque_ty_span: Span,
precise_capturing_args: Option<&[PreciseCapturingArg]>,
lower_item_bounds: impl FnOnce(&mut Self) -> &'hir [hir::GenericBound<'hir>],
) -> hir::TyKind<'hir> {
let opaque_ty_def_id = self.create_def(
Expand Down Expand Up @@ -1671,8 +1696,15 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
// Install the remapping from old to new (if any). This makes sure that
// any lifetimes that would have resolved to the def-id of captured
// lifetimes are remapped to the new *synthetic* lifetimes of the opaque.
let bounds = this
.with_remapping(captured_to_synthesized_mapping, |this| lower_item_bounds(this));
let (bounds, precise_capturing_args) =
this.with_remapping(captured_to_synthesized_mapping, |this| {
(
lower_item_bounds(this),
precise_capturing_args.map(|precise_capturing| {
this.lower_precise_capturing_args(precise_capturing)
}),
)
});

let generic_params =
this.arena.alloc_from_iter(synthesized_lifetime_definitions.iter().map(
Expand Down Expand Up @@ -1717,6 +1749,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
origin,
lifetime_mapping,
in_trait,
precise_capturing_args,
};

// Generate an `type Foo = impl Trait;` declaration.
Expand Down Expand Up @@ -1749,6 +1782,30 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
)
}

fn lower_precise_capturing_args(
&mut self,
precise_capturing_args: &[PreciseCapturingArg],
) -> &'hir [hir::PreciseCapturingArg<'hir>] {
self.arena.alloc_from_iter(precise_capturing_args.iter().map(|arg| match arg {
PreciseCapturingArg::Lifetime(lt) => {
hir::PreciseCapturingArg::Lifetime(self.lower_lifetime(lt))
}
PreciseCapturingArg::Arg(path, id) => {
let [segment] = path.segments.as_slice() else {
panic!();
};
let res = self.resolver.get_partial_res(*id).map_or(Res::Err, |partial_res| {
partial_res.full_res().expect("no partial res expected for precise capture arg")
});
hir::PreciseCapturingArg::Param(hir::PreciseCapturingNonLifetimeArg {
hir_id: self.lower_node_id(*id),
ident: self.lower_ident(segment.ident),
res: self.lower_res(res),
})
}
}))
}

fn lower_fn_params_to_names(&mut self, decl: &FnDecl) -> &'hir [Ident] {
self.arena.alloc_from_iter(decl.inputs.iter().map(|param| match param.pat.kind {
PatKind::Ident(_, ident, _) => self.lower_ident(ident),
Expand Down Expand Up @@ -1889,7 +1946,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
let opaque_ty_span =
self.mark_span_with_reason(DesugaringKind::Async, span, allowed_features);

let captured_lifetimes: Vec<_> = self
let captured_lifetimes = self
.resolver
.take_extra_lifetime_params(opaque_ty_node_id)
.into_iter()
Expand All @@ -1903,6 +1960,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
captured_lifetimes,
span,
opaque_ty_span,
None,
|this| {
let bound = this.lower_coroutine_fn_output_type_to_bound(
output,
Expand Down
Loading
Loading