From fbd39f9c0990740483de182696f028eb5cc18d9d Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 28 Nov 2023 18:18:19 +0000 Subject: [PATCH] Implement `async gen` blocks --- compiler/rustc_ast/src/ast.rs | 2 + compiler/rustc_ast_lowering/src/expr.rs | 138 ++++++++++++++++-- .../src/diagnostics/conflict_errors.rs | 5 + .../src/diagnostics/region_name.rs | 15 ++ .../src/debuginfo/type_names.rs | 3 + compiler/rustc_codegen_ssa/src/mir/locals.rs | 6 +- compiler/rustc_hir/src/hir.rs | 25 ++-- compiler/rustc_hir/src/lang_items.rs | 5 + compiler/rustc_hir_typeck/src/check.rs | 4 +- compiler/rustc_middle/src/mir/terminator.rs | 8 + compiler/rustc_middle/src/traits/select.rs | 8 +- compiler/rustc_middle/src/ty/context.rs | 7 +- compiler/rustc_middle/src/ty/util.rs | 2 + compiler/rustc_mir_transform/src/coroutine.rs | 30 +++- compiler/rustc_parse/src/parser/expr.rs | 23 +-- compiler/rustc_parse/src/parser/item.rs | 6 +- .../rustc_smir/src/rustc_smir/convert/mod.rs | 1 + compiler/rustc_span/src/symbol.rs | 5 + .../src/solve/assembly/mod.rs | 7 + .../src/solve/project_goals/mod.rs | 34 +++++ .../src/solve/trait_goals.rs | 24 +++ .../src/traits/error_reporting/suggestions.rs | 21 ++- .../error_reporting/type_err_ctxt_ext.rs | 3 + .../src/traits/project.rs | 67 ++++++++- .../src/traits/select/candidate_assembly.rs | 30 ++++ .../src/traits/select/confirmation.rs | 35 +++++ .../src/traits/select/mod.rs | 6 + .../rustc_trait_selection/src/traits/util.rs | 11 ++ compiler/rustc_ty_utils/src/abi.rs | 30 +++- compiler/rustc_ty_utils/src/instance.rs | 15 ++ library/core/src/async_iter/async_iter.rs | 23 +++ 31 files changed, 549 insertions(+), 50 deletions(-) diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index 10776f31c07bf..4108840274f37 100644 --- a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -1501,6 +1501,7 @@ pub enum ExprKind { pub enum GenBlockKind { Async, Gen, + AsyncGen, } impl fmt::Display for GenBlockKind { @@ -1514,6 +1515,7 @@ impl GenBlockKind { match self { GenBlockKind::Async => "async", GenBlockKind::Gen => "gen", + GenBlockKind::AsyncGen => "async gen", } } } diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index 43b345292c8b3..a54ccc87b65b6 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -325,6 +325,15 @@ impl<'hir> LoweringContext<'_, 'hir> { hir::CoroutineSource::Block, |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)), ), + ExprKind::Gen(capture_clause, block, GenBlockKind::AsyncGen) => self + .make_async_gen_expr( + *capture_clause, + e.id, + None, + e.span, + hir::CoroutineSource::Block, + |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)), + ), ExprKind::Yield(opt_expr) => self.lower_expr_yield(e.span, opt_expr.as_deref()), ExprKind::Err => hir::ExprKind::Err( self.tcx.sess.delay_span_bug(e.span, "lowered ExprKind::Err"), @@ -726,6 +735,84 @@ impl<'hir> LoweringContext<'_, 'hir> { })) } + /// Lower a `async gen` construct to a generator that implements `AsyncIterator`. + /// + /// This results in: + /// + /// ```text + /// static move? |_task_context| -> () { + /// + /// } + /// ``` + pub(super) fn make_async_gen_expr( + &mut self, + capture_clause: CaptureBy, + closure_node_id: NodeId, + _yield_ty: Option>, + span: Span, + async_coroutine_source: hir::CoroutineSource, + body: impl FnOnce(&mut Self) -> hir::Expr<'hir>, + ) -> hir::ExprKind<'hir> { + let output = hir::FnRetTy::DefaultReturn(self.lower_span(span)); + + // Resume argument type: `ResumeTy` + let unstable_span = + self.mark_span_with_reason(DesugaringKind::Async, span, self.allow_gen_future.clone()); + let resume_ty = hir::QPath::LangItem(hir::LangItem::ResumeTy, unstable_span); + let input_ty = hir::Ty { + hir_id: self.next_id(), + kind: hir::TyKind::Path(resume_ty), + span: unstable_span, + }; + + // The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`. + let fn_decl = self.arena.alloc(hir::FnDecl { + inputs: arena_vec![self; input_ty], + output, + c_variadic: false, + implicit_self: hir::ImplicitSelfKind::None, + lifetime_elision_allowed: false, + }); + + // Lower the argument pattern/ident. The ident is used again in the `.await` lowering. + let (pat, task_context_hid) = self.pat_ident_binding_mode( + span, + Ident::with_dummy_span(sym::_task_context), + hir::BindingAnnotation::MUT, + ); + let param = hir::Param { + hir_id: self.next_id(), + pat, + ty_span: self.lower_span(span), + span: self.lower_span(span), + }; + let params = arena_vec![self; param]; + + let body = self.lower_body(move |this| { + this.coroutine_kind = Some(hir::CoroutineKind::AsyncGen(async_coroutine_source)); + + let old_ctx = this.task_context; + this.task_context = Some(task_context_hid); + let res = body(this); + this.task_context = old_ctx; + (params, res) + }); + + // `static |_task_context| -> { body }`: + hir::ExprKind::Closure(self.arena.alloc(hir::Closure { + def_id: self.local_def_id(closure_node_id), + binder: hir::ClosureBinder::Default, + capture_clause, + bound_generic_params: &[], + fn_decl, + body, + fn_decl_span: self.lower_span(span), + fn_arg_span: None, + movability: Some(hir::Movability::Static), + constness: hir::Constness::NotConst, + })) + } + /// Forwards a possible `#[track_caller]` annotation from `outer_hir_id` to /// `inner_hir_id` in case the `async_fn_track_caller` feature is enabled. pub(super) fn maybe_forward_track_caller( @@ -775,15 +862,18 @@ impl<'hir> LoweringContext<'_, 'hir> { /// ``` fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> { let full_span = expr.span.to(await_kw_span); - match self.coroutine_kind { - Some(hir::CoroutineKind::Async(_)) => {} + + let is_async_gen = match self.coroutine_kind { + Some(hir::CoroutineKind::Async(_)) => false, + Some(hir::CoroutineKind::AsyncGen(_)) => true, Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => { return hir::ExprKind::Err(self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks { await_kw_span, item_span: self.current_item, })); } - } + }; + let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None); let gen_future_span = self.mark_span_with_reason( DesugaringKind::Await, @@ -872,12 +962,19 @@ impl<'hir> LoweringContext<'_, 'hir> { self.stmt_expr(span, match_expr) }; - // task_context = yield (); + // Depending on `async` of `async gen`: + // async - task_context = yield (); + // async gen - task_context = yield async_gen_pending(); let yield_stmt = { - let unit = self.expr_unit(span); + let yielded = if is_async_gen { + self.expr_call_lang_item_fn(span, hir::LangItem::AsyncGenPending, &[]) + } else { + self.expr_unit(span) + }; + let yield_expr = self.expr( span, - hir::ExprKind::Yield(unit, hir::YieldSource::Await { expr: Some(expr_hir_id) }), + hir::ExprKind::Yield(yielded, hir::YieldSource::Await { expr: Some(expr_hir_id) }), ); let yield_expr = self.arena.alloc(yield_expr); @@ -987,7 +1084,11 @@ impl<'hir> LoweringContext<'_, 'hir> { } Some(movability) } - Some(hir::CoroutineKind::Gen(_)) | Some(hir::CoroutineKind::Async(_)) => { + Some( + hir::CoroutineKind::Gen(_) + | hir::CoroutineKind::Async(_) + | hir::CoroutineKind::AsyncGen(_), + ) => { panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering"); } None => { @@ -1494,8 +1595,9 @@ impl<'hir> LoweringContext<'_, 'hir> { } fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> { - match self.coroutine_kind { - Some(hir::CoroutineKind::Gen(_)) => {} + let is_async_gen = match self.coroutine_kind { + Some(hir::CoroutineKind::Gen(_)) => false, + Some(hir::CoroutineKind::AsyncGen(_)) => true, Some(hir::CoroutineKind::Async(_)) => { return hir::ExprKind::Err( self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span }), @@ -1511,14 +1613,24 @@ impl<'hir> LoweringContext<'_, 'hir> { ) .emit(); } - self.coroutine_kind = Some(hir::CoroutineKind::Coroutine) + self.coroutine_kind = Some(hir::CoroutineKind::Coroutine); + false } - } + }; - let expr = + let mut yielded = opt_expr.as_ref().map(|x| self.lower_expr(x)).unwrap_or_else(|| self.expr_unit(span)); - hir::ExprKind::Yield(expr, hir::YieldSource::Yield) + if is_async_gen { + // yield async_gen_ready($expr); + yielded = self.expr_call_lang_item_fn( + span, + hir::LangItem::AsyncGenReady, + std::slice::from_ref(yielded), + ); + } + + hir::ExprKind::Yield(yielded, hir::YieldSource::Yield) } /// Desugar `ExprForLoop` from: `[opt_ident]: for in ` into: diff --git a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs index 9225f19876dee..53970d762ef49 100644 --- a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs +++ b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs @@ -2519,6 +2519,11 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> { CoroutineSource::Closure => "gen closure", _ => bug!("gen block/closure expected, but gen function found."), }, + CoroutineKind::AsyncGen(kind) => match kind { + CoroutineSource::Block => "async gen block", + CoroutineSource::Closure => "async gen closure", + _ => bug!("gen block/closure expected, but gen function found."), + }, CoroutineKind::Async(async_kind) => match async_kind { CoroutineSource::Block => "async block", CoroutineSource::Closure => "async closure", diff --git a/compiler/rustc_borrowck/src/diagnostics/region_name.rs b/compiler/rustc_borrowck/src/diagnostics/region_name.rs index 730b65898bc3a..155698fbbf043 100644 --- a/compiler/rustc_borrowck/src/diagnostics/region_name.rs +++ b/compiler/rustc_borrowck/src/diagnostics/region_name.rs @@ -715,6 +715,21 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> { " of gen function" } }, + + Some(hir::CoroutineKind::AsyncGen(gen)) => match gen { + hir::CoroutineSource::Block => " of async gen block", + hir::CoroutineSource::Closure => " of async gen closure", + hir::CoroutineSource::Fn => { + let parent_item = + hir.get_by_def_id(hir.get_parent_item(mir_hir_id).def_id); + let output = &parent_item + .fn_decl() + .expect("coroutine lowered from async gen fn should be in fn") + .output; + span = output.span(); + " of async gen function" + } + }, Some(hir::CoroutineKind::Coroutine) => " of coroutine", None => " of closure", }; diff --git a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs index 839cc4dabe3b7..3da38258a323b 100644 --- a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs +++ b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs @@ -566,6 +566,9 @@ fn coroutine_kind_label(coroutine_kind: Option) -> &'static str { Some(CoroutineKind::Async(CoroutineSource::Block)) => "async_block", Some(CoroutineKind::Async(CoroutineSource::Closure)) => "async_closure", Some(CoroutineKind::Async(CoroutineSource::Fn)) => "async_fn", + Some(CoroutineKind::AsyncGen(CoroutineSource::Block)) => "async_gen_block", + Some(CoroutineKind::AsyncGen(CoroutineSource::Closure)) => "async_gen_closure", + Some(CoroutineKind::AsyncGen(CoroutineSource::Fn)) => "async_gen_fn", Some(CoroutineKind::Coroutine) => "coroutine", None => "closure", } diff --git a/compiler/rustc_codegen_ssa/src/mir/locals.rs b/compiler/rustc_codegen_ssa/src/mir/locals.rs index 378c540132207..7db260c9f5bd8 100644 --- a/compiler/rustc_codegen_ssa/src/mir/locals.rs +++ b/compiler/rustc_codegen_ssa/src/mir/locals.rs @@ -43,7 +43,11 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let local = mir::Local::from_usize(local); let expected_ty = self.monomorphize(self.mir.local_decls[local].ty); if expected_ty != op.layout.ty { - warn!("Unexpected initial operand type. See the issues/114858"); + warn!( + "Unexpected initial operand type: expected {expected_ty:?}, found {:?}.\ + See .", + op.layout.ty + ); } } } diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs index d2b83d0eb00fa..6da1647fefaae 100644 --- a/compiler/rustc_hir/src/hir.rs +++ b/compiler/rustc_hir/src/hir.rs @@ -1484,12 +1484,16 @@ impl<'hir> Body<'hir> { /// The type of source expression that caused this coroutine to be created. #[derive(Clone, PartialEq, Eq, Debug, Copy, Hash, HashStable_Generic, Encodable, Decodable)] pub enum CoroutineKind { - /// An explicit `async` block or the body of an async function. + /// An explicit `async` block or the body of an `async` function. Async(CoroutineSource), /// An explicit `gen` block or the body of a `gen` function. Gen(CoroutineSource), + /// An explicit `async gen` block or the body of an `async gen` function, + /// which is able to both `yield` and `.await`. + AsyncGen(CoroutineSource), + /// A coroutine literal created via a `yield` inside a closure. Coroutine, } @@ -1514,6 +1518,14 @@ impl fmt::Display for CoroutineKind { } k.fmt(f) } + CoroutineKind::AsyncGen(k) => { + if f.alternate() { + f.write_str("`async gen` ")?; + } else { + f.write_str("async gen ")? + } + k.fmt(f) + } } } } @@ -2209,17 +2221,6 @@ impl fmt::Display for YieldSource { } } -impl From for YieldSource { - fn from(kind: CoroutineKind) -> Self { - match kind { - // Guess based on the kind of the current coroutine. - CoroutineKind::Coroutine => Self::Yield, - CoroutineKind::Async(_) => Self::Await { expr: None }, - CoroutineKind::Gen(_) => Self::Yield, - } - } -} - // N.B., if you change this, you'll probably want to change the corresponding // type structure in middle/ty.rs as well. #[derive(Debug, Clone, Copy, HashStable_Generic)] diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 60f1449c177cb..10ccdf0efa883 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -212,6 +212,7 @@ language_item_table! { Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0); Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0); + AsyncIterator, sym::async_iterator, async_iterator_trait, Target::Trait, GenericRequirement::Exact(0); CoroutineState, sym::coroutine_state, coroutine_state, Target::Enum, GenericRequirement::None; Coroutine, sym::coroutine, coroutine_trait, Target::Trait, GenericRequirement::Minimum(1); Unpin, sym::unpin, unpin_trait, Target::Trait, GenericRequirement::None; @@ -294,6 +295,10 @@ language_item_table! { PollReady, sym::Ready, poll_ready_variant, Target::Variant, GenericRequirement::None; PollPending, sym::Pending, poll_pending_variant, Target::Variant, GenericRequirement::None; + AsyncGenReady, sym::AsyncGenReady, async_gen_ready, Target::Method(MethodKind::Inherent), GenericRequirement::Exact(1); + AsyncGenPending, sym::AsyncGenPending, async_gen_pending, Target::Method(MethodKind::Inherent), GenericRequirement::Exact(1); + AsyncGenFinished, sym::AsyncGenFinished, async_gen_finished, Target::AssocConst, GenericRequirement::Exact(1); + // FIXME(swatinem): the following lang items are used for async lowering and // should become obsolete eventually. ResumeTy, sym::ResumeTy, resume_ty, Target::Struct, GenericRequirement::None; diff --git a/compiler/rustc_hir_typeck/src/check.rs b/compiler/rustc_hir_typeck/src/check.rs index 0cd1ae2dbfe5c..91c96d6f76fd8 100644 --- a/compiler/rustc_hir_typeck/src/check.rs +++ b/compiler/rustc_hir_typeck/src/check.rs @@ -59,7 +59,9 @@ pub(super) fn check_fn<'a, 'tcx>( && can_be_coroutine.is_some() { let yield_ty = match kind { - hir::CoroutineKind::Gen(..) | hir::CoroutineKind::Coroutine => { + hir::CoroutineKind::Gen(..) + | hir::CoroutineKind::AsyncGen(..) + | hir::CoroutineKind::Coroutine => { let yield_ty = fcx.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::TypeInference, span, diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index 9a6ac6ff57a4e..aa4cb36c5cef2 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -150,11 +150,17 @@ impl AssertKind { RemainderByZero(_) => "attempt to calculate the remainder with a divisor of zero", ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion", ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion", + ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => { + "`async gen fn` resumed after completion" + } ResumedAfterReturn(CoroutineKind::Gen(_)) => { "`gen fn` should just keep returning `None` after completion" } ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking", ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking", + ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => { + "`async gen fn` resumed after panicking" + } ResumedAfterPanic(CoroutineKind::Gen(_)) => { "`gen fn` should just keep returning `None` after panicking" } @@ -245,6 +251,7 @@ impl AssertKind { DivisionByZero(_) => middle_assert_divide_by_zero, RemainderByZero(_) => middle_assert_remainder_by_zero, ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return, + ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => todo!(), ResumedAfterReturn(CoroutineKind::Gen(_)) => { bug!("gen blocks can be resumed after they return and will keep returning `None`") } @@ -252,6 +259,7 @@ impl AssertKind { middle_assert_coroutine_resume_after_return } ResumedAfterPanic(CoroutineKind::Async(_)) => middle_assert_async_resume_after_panic, + ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => todo!(), ResumedAfterPanic(CoroutineKind::Gen(_)) => middle_assert_gen_resume_after_panic, ResumedAfterPanic(CoroutineKind::Coroutine) => { middle_assert_coroutine_resume_after_panic diff --git a/compiler/rustc_middle/src/traits/select.rs b/compiler/rustc_middle/src/traits/select.rs index f33421bbaa6c1..a39fffa7486d5 100644 --- a/compiler/rustc_middle/src/traits/select.rs +++ b/compiler/rustc_middle/src/traits/select.rs @@ -144,10 +144,14 @@ pub enum SelectionCandidate<'tcx> { /// generated for an async construct. FutureCandidate, - /// Implementation of an `Iterator` trait by one of the generator types - /// generated for a gen construct. + /// Implementation of an `Iterator` trait by one of the coroutine types + /// generated for a `gen` construct. IteratorCandidate, + /// Implementation of an `AsyncIterator` trait by one of the coroutine types + /// generated for a `async gen` construct. + AsyncIteratorCandidate, + /// Implementation of a `Fn`-family trait by one of the anonymous /// types generated for a fn pointer type (e.g., `fn(int) -> int`) FnPointerCandidate { diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index bb02c6e0c7856..6e6d08f1dc15e 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -822,11 +822,16 @@ impl<'tcx> TyCtxt<'tcx> { matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Coroutine)) } - /// Returns `true` if the node pointed to by `def_id` is a coroutine for a gen construct. + /// Returns `true` if the node pointed to by `def_id` is a coroutine for a `gen` construct. pub fn coroutine_is_gen(self, def_id: DefId) -> bool { matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Gen(_))) } + /// Returns `true` if the node pointed to by `def_id` is a coroutine for a `async gen` construct. + pub fn coroutine_is_async_gen(self, def_id: DefId) -> bool { + matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::AsyncGen(_))) + } + pub fn stability(self) -> &'tcx stability::Index { self.stability_index(()) } diff --git a/compiler/rustc_middle/src/ty/util.rs b/compiler/rustc_middle/src/ty/util.rs index cbf1a9900d9aa..c95813780affd 100644 --- a/compiler/rustc_middle/src/ty/util.rs +++ b/compiler/rustc_middle/src/ty/util.rs @@ -732,6 +732,7 @@ impl<'tcx> TyCtxt<'tcx> { DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => { match coroutine_kind { rustc_hir::CoroutineKind::Async(..) => "async closure", + rustc_hir::CoroutineKind::AsyncGen(..) => "async gen closure", rustc_hir::CoroutineKind::Coroutine => "coroutine", rustc_hir::CoroutineKind::Gen(..) => "gen closure", } @@ -752,6 +753,7 @@ impl<'tcx> TyCtxt<'tcx> { DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => { match coroutine_kind { rustc_hir::CoroutineKind::Async(..) => "an", + rustc_hir::CoroutineKind::AsyncGen(..) => "an", rustc_hir::CoroutineKind::Coroutine => "a", rustc_hir::CoroutineKind::Gen(..) => "a", } diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index 0c865c06ecc2b..57cad9b52e4e5 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -356,6 +356,26 @@ impl<'tcx> TransformVisitor<'tcx> { ) } } + CoroutineKind::AsyncGen(_) => { + if is_return { + let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() }; + let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() }; + let yield_ty = args.type_at(0); + Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + const_: Const::Unevaluated( + UnevaluatedConst::new( + self.tcx.require_lang_item(LangItem::AsyncGenFinished, None), + self.tcx.mk_args(&[yield_ty.into()]), + ), + self.old_yield_ty, + ), + user_ty: None, + }))) + } else { + Rvalue::Use(val) + } + } CoroutineKind::Coroutine => { let coroutine_state_def_id = self.tcx.require_lang_item(LangItem::CoroutineState, None); @@ -1383,7 +1403,8 @@ fn create_coroutine_resume_function<'tcx>( if can_return { let block = match coroutine_kind { - CoroutineKind::Async(_) | CoroutineKind::Coroutine => { + // FIXME(gen_blocks): Should `async gen` yield `None` when resumed once again? + CoroutineKind::Async(_) | CoroutineKind::AsyncGen(_) | CoroutineKind::Coroutine => { insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind)) } CoroutineKind::Gen(_) => transform.insert_none_ret_block(body), @@ -1570,6 +1591,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { }; let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_))); + let is_async_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::AsyncGen(_))); let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_))); let new_ret_ty = match body.coroutine_kind().unwrap() { CoroutineKind::Async(_) => { @@ -1586,6 +1608,10 @@ impl<'tcx> MirPass<'tcx> for StateTransform { let option_args = tcx.mk_args(&[old_yield_ty.into()]); Ty::new_adt(tcx, option_adt_ref, option_args) } + CoroutineKind::AsyncGen(_) => { + // The yield ty is already `Poll>` + old_yield_ty + } CoroutineKind::Coroutine => { // Compute CoroutineState let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); @@ -1600,7 +1626,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx); // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies. - if is_async_kind { + if is_async_kind || is_async_gen_kind { transform_async_context(tcx, body); } diff --git a/compiler/rustc_parse/src/parser/expr.rs b/compiler/rustc_parse/src/parser/expr.rs index 79567eb389868..877680e194860 100644 --- a/compiler/rustc_parse/src/parser/expr.rs +++ b/compiler/rustc_parse/src/parser/expr.rs @@ -1441,8 +1441,9 @@ impl<'a> Parser<'a> { } else if this.token.uninterpolated_span().at_least_rust_2018() { // `Span:.at_least_rust_2018()` is somewhat expensive; don't get it repeatedly. if this.check_keyword(kw::Async) { - if this.is_gen_block(kw::Async) { - // Check for `async {` and `async move {`. + if this.is_gen_block(kw::Async, 0) || this.is_gen_block(kw::Gen, 1) { + // Check for `async {` and `async move {`, + // or `async gen {` and `async gen move {`. this.parse_gen_block() } else { this.parse_expr_closure() @@ -1450,7 +1451,7 @@ impl<'a> Parser<'a> { } else if this.eat_keyword(kw::Await) { this.recover_incorrect_await_syntax(lo, this.prev_token.span) } else if this.token.uninterpolated_span().at_least_rust_2024() { - if this.is_gen_block(kw::Gen) { + if this.is_gen_block(kw::Gen, 0) { this.parse_gen_block() } else { this.parse_expr_lit() @@ -3087,7 +3088,7 @@ impl<'a> Parser<'a> { fn parse_gen_block(&mut self) -> PResult<'a, P> { let lo = self.token.span; let kind = if self.eat_keyword(kw::Async) { - GenBlockKind::Async + if self.eat_keyword(kw::Gen) { GenBlockKind::AsyncGen } else { GenBlockKind::Async } } else { assert!(self.eat_keyword(kw::Gen)); self.sess.gated_spans.gate(sym::gen_blocks, lo.to(self.token.span)); @@ -3099,22 +3100,26 @@ impl<'a> Parser<'a> { Ok(self.mk_expr_with_attrs(lo.to(self.prev_token.span), kind, attrs)) } - fn is_gen_block(&self, kw: Symbol) -> bool { - self.token.is_keyword(kw) + fn is_gen_block(&self, kw: Symbol, lookahead: usize) -> bool { + self.is_keyword_ahead(lookahead, &[kw]) && (( // `async move {` - self.is_keyword_ahead(1, &[kw::Move]) - && self.look_ahead(2, |t| { + self.is_keyword_ahead(lookahead + 1, &[kw::Move]) + && self.look_ahead(lookahead + 2, |t| { *t == token::OpenDelim(Delimiter::Brace) || t.is_whole_block() }) ) || ( // `async {` - self.look_ahead(1, |t| { + self.look_ahead(lookahead + 1, |t| { *t == token::OpenDelim(Delimiter::Brace) || t.is_whole_block() }) )) } + pub(super) fn is_async_gen_block(&self) -> bool { + self.token.is_keyword(kw::Async) && self.is_gen_block(kw::Gen, 1) + } + fn is_certainly_not_a_block(&self) -> bool { self.look_ahead(1, |t| t.is_ident()) && ( diff --git a/compiler/rustc_parse/src/parser/item.rs b/compiler/rustc_parse/src/parser/item.rs index 23886b1208bb4..148ec4c3087b9 100644 --- a/compiler/rustc_parse/src/parser/item.rs +++ b/compiler/rustc_parse/src/parser/item.rs @@ -2330,8 +2330,10 @@ impl<'a> Parser<'a> { || case == Case::Insensitive && t.is_non_raw_ident_where(|i| quals.iter().any(|qual| qual.as_str() == i.name.as_str().to_lowercase())) ) - // Rule out unsafe extern block. - && !self.is_unsafe_foreign_mod()) + // Rule out `unsafe extern {`. + && !self.is_unsafe_foreign_mod() + // Rule out `async gen {` and `async gen move {` + && !self.is_async_gen_block()) }) // `extern ABI fn` || self.check_keyword_case(kw::Extern, case) diff --git a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs index edb32df305c8e..af884f0c80790 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs @@ -63,6 +63,7 @@ impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineKind { stable_mir::mir::CoroutineKind::Gen(source.stable(tables)) } CoroutineKind::Coroutine => stable_mir::mir::CoroutineKind::Coroutine, + CoroutineKind::AsyncGen(_) => todo!(), } } } diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 40b0387424221..adbcd7f91b72a 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -141,6 +141,9 @@ symbols! { AssertParamIsClone, AssertParamIsCopy, AssertParamIsEq, + AsyncGenFinished, + AsyncGenPending, + AsyncGenReady, AtomicBool, AtomicI128, AtomicI16, @@ -436,6 +439,7 @@ symbols! { async_closure, async_fn_in_trait, async_fn_track_caller, + async_iterator, atomic, atomic_mod, atomics, @@ -1216,6 +1220,7 @@ symbols! { pointer, pointer_like, poll, + poll_next, position, post_dash_lto: "post-lto", powerpc_target_feature, diff --git a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs index eb30b4cd85ea4..cf6926af41896 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs @@ -207,6 +207,11 @@ pub(super) trait GoalKind<'tcx>: goal: Goal<'tcx, Self>, ) -> QueryResult<'tcx>; + fn consider_builtin_async_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx>; + /// A coroutine (that doesn't come from an `async` or `gen` desugaring) is known to /// implement `Coroutine`, given the resume, yield, /// and return types of the coroutine computed during type-checking. @@ -565,6 +570,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> { G::consider_builtin_future_candidate(self, goal) } else if lang_items.iterator_trait() == Some(trait_def_id) { G::consider_builtin_iterator_candidate(self, goal) + } else if lang_items.async_iterator_trait() == Some(trait_def_id) { + G::consider_builtin_async_iterator_candidate(self, goal) } else if lang_items.coroutine_trait() == Some(trait_def_id) { G::consider_builtin_coroutine_candidate(self, goal) } else if lang_items.discriminant_kind_trait() == Some(trait_def_id) { diff --git a/compiler/rustc_trait_selection/src/solve/project_goals/mod.rs b/compiler/rustc_trait_selection/src/solve/project_goals/mod.rs index 37bbf32c7682f..30d514c5e16d7 100644 --- a/compiler/rustc_trait_selection/src/solve/project_goals/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/project_goals/mod.rs @@ -516,6 +516,40 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> { ) } + fn consider_builtin_async_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + let self_ty = goal.predicate.self_ty(); + let ty::Coroutine(def_id, args, _) = *self_ty.kind() else { + return Err(NoSolution); + }; + + // Coroutines are not AsyncIterators unless they come from `gen` desugaring + let tcx = ecx.tcx(); + if !tcx.coroutine_is_async_gen(def_id) { + return Err(NoSolution); + } + + ecx.probe_misc_candidate("builtin AsyncIterator kind").enter(|ecx| { + // Take `AsyncIterator` and turn it into the corresponding + // coroutine yield ty `Poll>`. + let expected_ty = Ty::new_adt( + tcx, + tcx.adt_def(tcx.require_lang_item(LangItem::Poll, None)), + tcx.mk_args(&[Ty::new_adt( + tcx, + tcx.adt_def(tcx.require_lang_item(LangItem::Option, None)), + tcx.mk_args(&[goal.predicate.term.into()]), + ) + .into()]), + ); + let yield_ty = args.as_coroutine().yield_ty(); + ecx.eq(goal.param_env, expected_ty, yield_ty)?; + ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + }) + } + fn consider_builtin_coroutine_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs index 95712da3c5e82..5807f7c6153cf 100644 --- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs @@ -370,6 +370,30 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> { ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) } + fn consider_builtin_async_iterator_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + if goal.predicate.polarity != ty::ImplPolarity::Positive { + return Err(NoSolution); + } + + let ty::Coroutine(def_id, _, _) = *goal.predicate.self_ty().kind() else { + return Err(NoSolution); + }; + + // Coroutines are not iterators unless they come from `gen` desugaring + let tcx = ecx.tcx(); + if !tcx.coroutine_is_async_gen(def_id) { + return Err(NoSolution); + } + + // Gen coroutines unconditionally implement `Iterator` + // Technically, we need to check that the iterator output type is Sized, + // but that's already proven by the coroutines being WF. + ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + } + fn consider_builtin_coroutine_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs index e4cb99727fc30..5ca32c9a67df8 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs @@ -2425,6 +2425,23 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { CoroutineKind::Async(CoroutineSource::Closure) => { format!("future created by async closure is not {trait_name}") } + CoroutineKind::AsyncGen(CoroutineSource::Fn) => self + .tcx + .parent(coroutine_did) + .as_local() + .map(|parent_did| self.tcx.local_def_id_to_hir_id(parent_did)) + .and_then(|parent_hir_id| hir.opt_name(parent_hir_id)) + .map(|name| { + format!("async iterator returned by `{name}` is not {trait_name}") + })?, + CoroutineKind::AsyncGen(CoroutineSource::Block) => { + format!("async iterator created by async gen block is not {trait_name}") + } + CoroutineKind::AsyncGen(CoroutineSource::Closure) => { + format!( + "async iterator created by async gen closure is not {trait_name}" + ) + } CoroutineKind::Gen(CoroutineSource::Fn) => self .tcx .parent(coroutine_did) @@ -2965,7 +2982,9 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { let what = match self.tcx.coroutine_kind(coroutine_def_id) { None | Some(hir::CoroutineKind::Coroutine) - | Some(hir::CoroutineKind::Gen(_)) => "yield", + | Some(hir::CoroutineKind::Gen(_)) + // FIXME(gen_blocks): This could be yield or await... + | Some(hir::CoroutineKind::AsyncGen(_)) => "yield", Some(hir::CoroutineKind::Async(..)) => "await", }; err.note(format!( diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs index b28bff1ca2a7a..a675053e562e5 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs @@ -1687,6 +1687,9 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> { hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block", hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function", hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "an async closure", + hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Block) => "an async gen block", + hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Fn) => "an async gen function", + hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Closure) => "an async gen closure", hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "a gen block", hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "a gen function", hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "a gen closure", diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 3be149517034d..0f9b852969cfc 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -1823,11 +1823,18 @@ fn assemble_candidates_from_impls<'cx, 'tcx>( let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty()); let lang_items = selcx.tcx().lang_items(); - if [lang_items.coroutine_trait(), lang_items.future_trait(), lang_items.iterator_trait()].contains(&Some(trait_ref.def_id)) - || selcx.tcx().fn_trait_kind_from_def_id(trait_ref.def_id).is_some() + if [ + lang_items.coroutine_trait(), + lang_items.future_trait(), + lang_items.iterator_trait(), + lang_items.async_iterator_trait(), + lang_items.fn_trait(), + lang_items.fn_mut_trait(), + lang_items.fn_once_trait(), + ].contains(&Some(trait_ref.def_id)) { true - } else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) { + }else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) { match self_ty.kind() { ty::Bool | ty::Char @@ -2042,6 +2049,8 @@ fn confirm_select_candidate<'cx, 'tcx>( confirm_future_candidate(selcx, obligation, data) } else if lang_items.iterator_trait() == Some(trait_def_id) { confirm_iterator_candidate(selcx, obligation, data) + } else if lang_items.async_iterator_trait() == Some(trait_def_id) { + confirm_async_iterator_candidate(selcx, obligation, data) } else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() { if obligation.predicate.self_ty().is_closure() { confirm_closure_candidate(selcx, obligation, data) @@ -2206,6 +2215,58 @@ fn confirm_iterator_candidate<'cx, 'tcx>( .with_addl_obligations(obligations) } +fn confirm_async_iterator_candidate<'cx, 'tcx>( + selcx: &mut SelectionContext<'cx, 'tcx>, + obligation: &ProjectionTyObligation<'tcx>, + nested: Vec>, +) -> Progress<'tcx> { + let ty::Coroutine(_, args, _) = + selcx.infcx.shallow_resolve(obligation.predicate.self_ty()).kind() + else { + unreachable!() + }; + let gen_sig = args.as_coroutine().poly_sig(); + let Normalized { value: gen_sig, obligations } = normalize_with_depth( + selcx, + obligation.param_env, + obligation.cause.clone(), + obligation.recursion_depth + 1, + gen_sig, + ); + + debug!(?obligation, ?gen_sig, ?obligations, "confirm_async_iterator_candidate"); + + let tcx = selcx.tcx(); + let iter_def_id = tcx.require_lang_item(LangItem::AsyncIterator, None); + + let predicate = super::util::async_iterator_trait_ref_and_outputs( + tcx, + iter_def_id, + obligation.predicate.self_ty(), + gen_sig, + ) + .map_bound(|(trait_ref, yield_ty)| { + debug_assert_eq!(tcx.associated_item(obligation.predicate.def_id).name, sym::Item); + + let ty::Adt(_poll_adt, args) = *yield_ty.kind() else { + bug!(); + }; + let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { + bug!(); + }; + let item_ty = args.type_at(0); + + ty::ProjectionPredicate { + projection_ty: ty::AliasTy::new(tcx, obligation.predicate.def_id, trait_ref.args), + term: item_ty.into(), + } + }); + + confirm_param_env_candidate(selcx, obligation, predicate, false) + .with_addl_obligations(nested) + .with_addl_obligations(obligations) +} + fn confirm_builtin_candidate<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTyObligation<'tcx>, diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 5a559bb5fa0d4..36aba43c025d1 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -112,6 +112,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { self.assemble_future_candidates(obligation, &mut candidates); } else if lang_items.iterator_trait() == Some(def_id) { self.assemble_iterator_candidates(obligation, &mut candidates); + } else if lang_items.async_iterator_trait() == Some(def_id) { + self.assemble_async_iterator_candidates(obligation, &mut candidates); } self.assemble_closure_candidates(obligation, &mut candidates); @@ -258,6 +260,34 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { } } + fn assemble_async_iterator_candidates( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + candidates: &mut SelectionCandidateSet<'tcx>, + ) { + let self_ty = obligation.self_ty().skip_binder(); + if let ty::Coroutine(did, args, _) = *self_ty.kind() { + // gen constructs get lowered to a special kind of coroutine that + // should directly `impl AsyncIterator`. + if self.tcx().coroutine_is_async_gen(did) { + debug!(?self_ty, ?obligation, "assemble_iterator_candidates",); + + // Can only confirm this candidate if we have constrained + // the `Yield` type to at least `Poll>`.. + let ty::Adt(_poll_def, args) = *args.as_coroutine().yield_ty().kind() else { + candidates.ambiguous = true; + return; + }; + let ty::Adt(_option_def, _) = *args.type_at(0).kind() else { + candidates.ambiguous = true; + return; + }; + + candidates.vec.push(AsyncIteratorCandidate); + } + } + } + /// Checks for the artificial impl that the compiler will create for an obligation like `X : /// FnMut<..>` where `X` is a closure type. /// diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index 4763a90617202..fd61110d4147e 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -98,6 +98,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator) } + AsyncIteratorCandidate => { + let vtable_iterator = self.confirm_async_iterator_candidate(obligation)?; + ImplSource::Builtin(BuiltinImplSource::Misc, vtable_iterator) + } + FnPointerCandidate { is_const } => { let data = self.confirm_fn_pointer_candidate(obligation, is_const)?; ImplSource::Builtin(BuiltinImplSource::Misc, data) @@ -816,6 +821,36 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { Ok(nested) } + fn confirm_async_iterator_candidate( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + ) -> Result>, SelectionError<'tcx>> { + // Okay to skip binder because the args on coroutine types never + // touch bound regions, they just capture the in-scope + // type/region parameters. + let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); + let ty::Coroutine(coroutine_def_id, args, _) = *self_ty.kind() else { + bug!("closure candidate for non-closure {:?}", obligation); + }; + + debug!(?obligation, ?coroutine_def_id, ?args, "confirm_async_iterator_candidate"); + + let gen_sig = args.as_coroutine().poly_sig(); + + let trait_ref = super::util::async_iterator_trait_ref_and_outputs( + self.tcx(), + obligation.predicate.def_id(), + obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(), + gen_sig, + ) + .map_bound(|(trait_ref, ..)| trait_ref); + + let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?; + debug!(?trait_ref, ?nested, "iterator candidate obligations"); + + Ok(nested) + } + #[instrument(skip(self), level = "debug")] fn confirm_closure_candidate( &mut self, diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index 0c884b7c16a82..ce20359da4d07 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -1872,6 +1872,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -1901,6 +1902,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -1936,6 +1938,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -1951,6 +1954,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -2058,6 +2062,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate @@ -2069,6 +2074,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | CoroutineCandidate | FutureCandidate | IteratorCandidate + | AsyncIteratorCandidate | FnPointerCandidate { .. } | BuiltinObjectCandidate | BuiltinUnsizeCandidate diff --git a/compiler/rustc_trait_selection/src/traits/util.rs b/compiler/rustc_trait_selection/src/traits/util.rs index bbde0c82743b4..0b98d3ab0c4cf 100644 --- a/compiler/rustc_trait_selection/src/traits/util.rs +++ b/compiler/rustc_trait_selection/src/traits/util.rs @@ -308,6 +308,17 @@ pub fn iterator_trait_ref_and_outputs<'tcx>( sig.map_bound(|sig| (trait_ref, sig.yield_ty)) } +pub fn async_iterator_trait_ref_and_outputs<'tcx>( + tcx: TyCtxt<'tcx>, + async_iterator_def_id: DefId, + self_ty: Ty<'tcx>, + sig: ty::PolyGenSig<'tcx>, +) -> ty::Binder<'tcx, (ty::TraitRef<'tcx>, Ty<'tcx>)> { + assert!(!self_ty.has_escaping_bound_vars()); + let trait_ref = ty::TraitRef::new(tcx, async_iterator_def_id, [self_ty]); + sig.map_bound(|sig| (trait_ref, sig.yield_ty)) +} + pub fn impl_item_is_final(tcx: TyCtxt<'_>, assoc_item: &ty::AssocItem) -> bool { assoc_item.defaultness(tcx).is_final() && tcx.defaultness(assoc_item.container_id(tcx)).is_final() diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index c6ff7e2a9ef40..a61cf3072e59d 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -119,9 +119,9 @@ fn fn_sig_for_fn_abi<'tcx>( // unlike for all other coroutine kinds. env_ty } - hir::CoroutineKind::Async(_) | hir::CoroutineKind::Coroutine => { - Ty::new_adt(tcx, pin_adt_ref, pin_args) - } + hir::CoroutineKind::Async(_) + | hir::CoroutineKind::AsyncGen(_) + | hir::CoroutineKind::Coroutine => Ty::new_adt(tcx, pin_adt_ref, pin_args), }; let sig = sig.skip_binder(); @@ -169,6 +169,30 @@ fn fn_sig_for_fn_abi<'tcx>( (None, ret_ty) } + hir::CoroutineKind::AsyncGen(_) => { + // The signature should be + // `AsyncIterator::poll_next(_, &mut Context<'_>) -> Poll>` + assert_eq!(sig.return_ty, tcx.types.unit); + + // Yield type is already `Poll>` + let ret_ty = sig.yield_ty; + + // We have to replace the `ResumeTy` that is used for type and borrow checking + // with `&mut Context<'_>` which is used in codegen. + #[cfg(debug_assertions)] + { + if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() { + let expected_adt = + tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None)); + assert_eq!(*resume_ty_adt, expected_adt); + } else { + panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty); + }; + } + let context_mut_ref = Ty::new_task_context(tcx); + + (Some(context_mut_ref), ret_ty) + } hir::CoroutineKind::Coroutine => { // The signature should be `Coroutine::resume(_, Resume) -> CoroutineState` let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs index 31f6a56eaeb40..f98e51231c141 100644 --- a/compiler/rustc_ty_utils/src/instance.rs +++ b/compiler/rustc_ty_utils/src/instance.rs @@ -271,6 +271,21 @@ fn resolve_associated_item<'tcx>( debug_assert!(tcx.defaultness(trait_item_id).has_value()); Some(Instance::new(trait_item_id, rcvr_args)) } + } else if Some(trait_ref.def_id) == lang_items.async_iterator_trait() { + let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else { + bug!() + }; + + if cfg!(debug_assertions) && tcx.item_name(trait_item_id) != sym::poll_next { + span_bug!( + tcx.def_span(coroutine_def_id), + "no definition for `{trait_ref}::{}` for built-in coroutine type", + tcx.item_name(trait_item_id) + ) + } + + // `AsyncIterator::poll_next` is generated by the compiler. + Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args }) } else if Some(trait_ref.def_id) == lang_items.coroutine_trait() { let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else { bug!() diff --git a/library/core/src/async_iter/async_iter.rs b/library/core/src/async_iter/async_iter.rs index 12a47f9fc7626..15148b5fc9acd 100644 --- a/library/core/src/async_iter/async_iter.rs +++ b/library/core/src/async_iter/async_iter.rs @@ -13,6 +13,7 @@ use crate::task::{Context, Poll}; #[unstable(feature = "async_iterator", issue = "79024")] #[must_use = "async iterators do nothing unless polled"] #[doc(alias = "Stream")] +#[cfg_attr(not(bootstrap), lang = "async_iterator")] pub trait AsyncIterator { /// The type of items yielded by the async iterator. type Item; @@ -109,3 +110,25 @@ where (**self).size_hint() } } + +#[unstable(feature = "async_gen_internals", issue = "none")] +impl Poll> { + /// docs + #[unstable(feature = "async_gen_internals", issue = "none")] + #[cfg_attr(not(bootstrap), lang = "AsyncGenReady")] + pub fn async_gen_ready(t: T) -> Self { + Poll::Ready(Some(t)) + } + + /// docs + #[unstable(feature = "async_gen_internals", issue = "none")] + #[cfg_attr(not(bootstrap), lang = "AsyncGenPending")] + pub fn async_gen_pending() -> Self { + Poll::Pending + } + + /// docs + #[unstable(feature = "async_gen_internals", issue = "none")] + #[cfg_attr(not(bootstrap), lang = "AsyncGenFinished")] + pub const FINISHED: Self = Poll::Ready(None); +}