@@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> {
224224
225225struct TransformVisitor < ' tcx > {
226226 tcx : TyCtxt < ' tcx > ,
227- is_async_kind : bool ,
227+ coroutine_kind : hir :: CoroutineKind ,
228228 state_adt_ref : AdtDef < ' tcx > ,
229229 state_args : GenericArgsRef < ' tcx > ,
230230
@@ -261,31 +261,53 @@ impl<'tcx> TransformVisitor<'tcx> {
261261 is_return : bool ,
262262 statements : & mut Vec < Statement < ' tcx > > ,
263263 ) {
264- let idx = VariantIdx :: new ( match ( is_return, self . is_async_kind ) {
265- ( true , false ) => 1 , // CoroutineState::Complete
266- ( false , false ) => 0 , // CoroutineState::Yielded
267- ( true , true ) => 0 , // Poll::Ready
268- ( false , true ) => 1 , // Poll::Pending
264+ let idx = VariantIdx :: new ( match ( is_return, self . coroutine_kind ) {
265+ ( true , hir:: CoroutineKind :: Coroutine ) => 1 , // CoroutineState::Complete
266+ ( false , hir:: CoroutineKind :: Coroutine ) => 0 , // CoroutineState::Yielded
267+ ( true , hir:: CoroutineKind :: Async ( _) ) => 0 , // Poll::Ready
268+ ( false , hir:: CoroutineKind :: Async ( _) ) => 1 , // Poll::Pending
269+ ( true , hir:: CoroutineKind :: Gen ( _) ) => 0 , // Option::None
270+ ( false , hir:: CoroutineKind :: Gen ( _) ) => 1 , // Option::Some
269271 } ) ;
270272
271273 let kind = AggregateKind :: Adt ( self . state_adt_ref . did ( ) , idx, self . state_args , None , None ) ;
272274
273- // `Poll::Pending`
274- if self . is_async_kind && idx == VariantIdx :: new ( 1 ) {
275- assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
275+ match self . coroutine_kind {
276+ // `Poll::Pending`
277+ CoroutineKind :: Async ( _) => {
278+ if !is_return {
279+ assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
276280
277- // FIXME(swatinem): assert that `val` is indeed unit?
278- statements. push ( Statement {
279- kind : StatementKind :: Assign ( Box :: new ( (
280- Place :: return_place ( ) ,
281- Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
282- ) ) ) ,
283- source_info,
284- } ) ;
285- return ;
281+ // FIXME(swatinem): assert that `val` is indeed unit?
282+ statements. push ( Statement {
283+ kind : StatementKind :: Assign ( Box :: new ( (
284+ Place :: return_place ( ) ,
285+ Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
286+ ) ) ) ,
287+ source_info,
288+ } ) ;
289+ return ;
290+ }
291+ }
292+ // `Option::None`
293+ CoroutineKind :: Gen ( _) => {
294+ if is_return {
295+ assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 0 ) ;
296+
297+ statements. push ( Statement {
298+ kind : StatementKind :: Assign ( Box :: new ( (
299+ Place :: return_place ( ) ,
300+ Rvalue :: Aggregate ( Box :: new ( kind) , IndexVec :: new ( ) ) ,
301+ ) ) ) ,
302+ source_info,
303+ } ) ;
304+ return ;
305+ }
306+ }
307+ CoroutineKind :: Coroutine => { }
286308 }
287309
288- // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)`
310+ // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some (x)`
289311 assert_eq ! ( self . state_adt_ref. variant( idx) . fields. len( ) , 1 ) ;
290312
291313 statements. push ( Statement {
@@ -1439,18 +1461,28 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
14391461 } ;
14401462
14411463 let is_async_kind = matches ! ( body. coroutine_kind( ) , Some ( CoroutineKind :: Async ( _) ) ) ;
1442- let ( state_adt_ref, state_args) = if is_async_kind {
1443- // Compute Poll<return_ty>
1444- let poll_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
1445- let poll_adt_ref = tcx. adt_def ( poll_did) ;
1446- let poll_args = tcx. mk_args ( & [ body. return_ty ( ) . into ( ) ] ) ;
1447- ( poll_adt_ref, poll_args)
1448- } else {
1449- // Compute CoroutineState<yield_ty, return_ty>
1450- let state_did = tcx. require_lang_item ( LangItem :: CoroutineState , None ) ;
1451- let state_adt_ref = tcx. adt_def ( state_did) ;
1452- let state_args = tcx. mk_args ( & [ yield_ty. into ( ) , body. return_ty ( ) . into ( ) ] ) ;
1453- ( state_adt_ref, state_args)
1464+ let ( state_adt_ref, state_args) = match body. coroutine_kind ( ) . unwrap ( ) {
1465+ CoroutineKind :: Async ( _) => {
1466+ // Compute Poll<return_ty>
1467+ let poll_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
1468+ let poll_adt_ref = tcx. adt_def ( poll_did) ;
1469+ let poll_args = tcx. mk_args ( & [ body. return_ty ( ) . into ( ) ] ) ;
1470+ ( poll_adt_ref, poll_args)
1471+ }
1472+ CoroutineKind :: Gen ( _) => {
1473+ // Compute Option<yield_ty>
1474+ let option_did = tcx. require_lang_item ( LangItem :: Option , None ) ;
1475+ let option_adt_ref = tcx. adt_def ( option_did) ;
1476+ let option_args = tcx. mk_args ( & [ body. yield_ty ( ) . unwrap ( ) . into ( ) ] ) ;
1477+ ( option_adt_ref, option_args)
1478+ }
1479+ CoroutineKind :: Coroutine => {
1480+ // Compute CoroutineState<yield_ty, return_ty>
1481+ let state_did = tcx. require_lang_item ( LangItem :: CoroutineState , None ) ;
1482+ let state_adt_ref = tcx. adt_def ( state_did) ;
1483+ let state_args = tcx. mk_args ( & [ yield_ty. into ( ) , body. return_ty ( ) . into ( ) ] ) ;
1484+ ( state_adt_ref, state_args)
1485+ }
14541486 } ;
14551487 let ret_ty = Ty :: new_adt ( tcx, state_adt_ref, state_args) ;
14561488
@@ -1518,7 +1550,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
15181550 // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
15191551 let mut transform = TransformVisitor {
15201552 tcx,
1521- is_async_kind ,
1553+ coroutine_kind : body . coroutine_kind ( ) . unwrap ( ) ,
15221554 state_adt_ref,
15231555 state_args,
15241556 remap,
0 commit comments