@@ -437,43 +437,56 @@ fn param_names(pat: Pat, record_type: RecordType) -> Box<dyn Iterator<Item = (Id
437437 }
438438}
439439
440- enum AsyncTraitKind < ' a > {
441- // old construction. Contains the function
440+ /// The specific async code pattern that was detected
441+ enum AsyncKind < ' a > {
442+ /// Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
443+ /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
442444 Function ( & ' a ItemFn ) ,
443- // new construction. Contains a reference to the async block
444- Async ( & ' a ExprAsync ) ,
445+ /// A function returning an async (move) block, optionally `Box::pin`-ed,
446+ /// as generated by `async-trait >= 0.1.44`:
447+ /// `Box::pin(async move { ... })`
448+ Async {
449+ async_expr : & ' a ExprAsync ,
450+ pinned_box : bool ,
451+ } ,
445452}
446453
447- pub ( crate ) struct AsyncTraitInfo < ' block > {
454+ pub ( crate ) struct AsyncInfo < ' block > {
448455 // statement that must be patched
449456 source_stmt : & ' block Stmt ,
450- kind : AsyncTraitKind < ' block > ,
457+ kind : AsyncKind < ' block > ,
451458 self_type : Option < syn:: TypePath > ,
452459 input : & ' block ItemFn ,
453460}
454461
455- impl < ' block > AsyncTraitInfo < ' block > {
456- /// Get the AST of the inner function we need to hook, if it was generated
457- /// by async-trait .
462+ impl < ' block > AsyncInfo < ' block > {
463+ /// Get the AST of the inner function we need to hook, if it looks like a
464+ /// manual future implementation .
458465 ///
459- /// When we are given a function annotated by async-trait, that function
460- /// is only a placeholder that returns a pinned future containing the
461- /// user logic, and it is that pinned future that needs to be instrumented.
466+ /// When we are given a function that returns a (pinned) future containing the
467+ /// user logic, it is that (pinned) future that needs to be instrumented.
462468 /// Were we to instrument its parent, we would only collect information
463469 /// regarding the allocation of that future, and not its own span of execution.
464- /// Depending on the version of async-trait, we inspect the block of the function
465- /// to find if it matches the pattern
466470 ///
467- /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` (<=0.1.43), or if
468- /// it matches `Box::pin(async move { ... }) (>=0.1.44). We the return the
469- /// statement that must be instrumented, along with some other informations.
471+ /// We inspect the block of the function to find if it matches any of the
472+ /// following patterns:
473+ ///
474+ /// - Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
475+ /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
476+ ///
477+ /// - A function returning an async (move) block, optionally `Box::pin`-ed,
478+ /// as generated by `async-trait >= 0.1.44`:
479+ /// `Box::pin(async move { ... })`
480+ ///
481+ /// We the return the statement that must be instrumented, along with some
482+ /// other information.
470483 /// 'gen_body' will then be able to use that information to instrument the
471484 /// proper function/future.
472485 ///
473486 /// (this follows the approach suggested in
474487 /// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673)
475488 pub ( crate ) fn from_fn ( input : & ' block ItemFn ) -> Option < Self > {
476- // are we in an async context? If yes, this isn't a async_trait -like pattern
489+ // are we in an async context? If yes, this isn't a manual async -like pattern
477490 if input. sig . asyncness . is_some ( ) {
478491 return None ;
479492 }
@@ -491,10 +504,8 @@ impl<'block> AsyncTraitInfo<'block> {
491504 None
492505 } ) ;
493506
494- // last expression of the block (it determines the return value
495- // of the block, so that if we are working on a function whose
496- // `trait` or `impl` declaration is annotated by async_trait,
497- // this is quite likely the point where the future is pinned)
507+ // last expression of the block: it determines the return value of the
508+ // block, this is quite likely a `Box::pin` statement or an async block
498509 let ( last_expr_stmt, last_expr) = block. stmts . iter ( ) . rev ( ) . find_map ( |stmt| {
499510 if let Stmt :: Expr ( expr) = stmt {
500511 Some ( ( stmt, expr) )
@@ -503,6 +514,19 @@ impl<'block> AsyncTraitInfo<'block> {
503514 }
504515 } ) ?;
505516
517+ // is the last expression an async block?
518+ if let Expr :: Async ( async_expr) = last_expr {
519+ return Some ( AsyncInfo {
520+ source_stmt : last_expr_stmt,
521+ kind : AsyncKind :: Async {
522+ async_expr,
523+ pinned_box : false ,
524+ } ,
525+ self_type : None ,
526+ input,
527+ } ) ;
528+ }
529+
506530 // is the last expression a function call?
507531 let ( outside_func, outside_args) = match last_expr {
508532 Expr :: Call ( ExprCall { func, args, .. } ) => ( func, args) ,
@@ -528,12 +552,12 @@ impl<'block> AsyncTraitInfo<'block> {
528552 // Is the argument to Box::pin an async block that
529553 // captures its arguments?
530554 if let Expr :: Async ( async_expr) = & outside_args[ 0 ] {
531- // check that the move 'keyword' is present
532- async_expr. capture ?;
533-
534- return Some ( AsyncTraitInfo {
555+ return Some ( AsyncInfo {
535556 source_stmt : last_expr_stmt,
536- kind : AsyncTraitKind :: Async ( async_expr) ,
557+ kind : AsyncKind :: Async {
558+ async_expr,
559+ pinned_box : true ,
560+ } ,
537561 self_type : None ,
538562 input,
539563 } ) ;
@@ -579,15 +603,15 @@ impl<'block> AsyncTraitInfo<'block> {
579603 }
580604 }
581605
582- Some ( AsyncTraitInfo {
606+ Some ( AsyncInfo {
583607 source_stmt : stmt_func_declaration,
584- kind : AsyncTraitKind :: Function ( func) ,
608+ kind : AsyncKind :: Function ( func) ,
585609 self_type,
586610 input,
587611 } )
588612 }
589613
590- pub ( crate ) fn gen_async_trait (
614+ pub ( crate ) fn gen_async (
591615 self ,
592616 args : InstrumentArgs ,
593617 instrumented_function_name : & str ,
@@ -611,15 +635,18 @@ impl<'block> AsyncTraitInfo<'block> {
611635 {
612636 // instrument the future by rewriting the corresponding statement
613637 out_stmts[ iter] = match self . kind {
614- // async-trait <= 0.1.43
615- AsyncTraitKind :: Function ( fun) => gen_function (
638+ // `Box::pin(immediately_invoked_async_fn())`
639+ AsyncKind :: Function ( fun) => gen_function (
616640 fun. into ( ) ,
617641 args,
618642 instrumented_function_name,
619643 self . self_type . as_ref ( ) ,
620644 ) ,
621- // async-trait >= 0.1.44
622- AsyncTraitKind :: Async ( async_expr) => {
645+ // `async move { ... }`, optionally pinned
646+ AsyncKind :: Async {
647+ async_expr,
648+ pinned_box,
649+ } => {
623650 let instrumented_block = gen_block (
624651 & async_expr. block ,
625652 & self . input . sig . inputs ,
@@ -629,8 +656,14 @@ impl<'block> AsyncTraitInfo<'block> {
629656 None ,
630657 ) ;
631658 let async_attrs = & async_expr. attrs ;
632- quote ! {
633- Box :: pin( #( #async_attrs) * async move { #instrumented_block } )
659+ if pinned_box {
660+ quote ! {
661+ Box :: pin( #( #async_attrs) * async move { #instrumented_block } )
662+ }
663+ } else {
664+ quote ! {
665+ #( #async_attrs) * async move { #instrumented_block }
666+ }
634667 }
635668 }
636669 } ;
0 commit comments