Skip to content

Commit f01f486

Browse files
authored
Add lifetime bounds to nested references (#38)
* Add failing test * Use visitor to collect references * Include nested reference expansion output test * Fix up a few comments
1 parent 0b38976 commit f01f486

File tree

5 files changed

+108
-84
lines changed

5 files changed

+108
-84
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ proc-macro = true
1515
[dependencies]
1616
proc-macro2 = { version = "1.0", default-features = false }
1717
quote = { version = "1.0", default-features = false }
18-
syn = { version = "2.0", features = ["full", "parsing", "printing", "proc-macro", "clone-impls"], default-features = false }
18+
syn = { version = "2.0", features = ["full", "parsing", "printing", "proc-macro", "clone-impls", "visit-mut"], default-features = false }
1919

2020
[dev-dependencies]
2121
futures-executor = "0.3"

src/expand.rs

Lines changed: 72 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use proc_macro2::{Span, TokenStream};
22
use quote::{quote, ToTokens};
33
use syn::{
4-
parse_quote, punctuated::Punctuated, Block, FnArg, Lifetime, ReturnType, Signature, Type,
5-
WhereClause,
4+
parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Block, Lifetime, Receiver,
5+
ReturnType, Signature, TypeReference, WhereClause,
66
};
77

88
use crate::parse::{AsyncItem, RecursionArgs};
@@ -40,6 +40,63 @@ impl ArgLifetime {
4040
}
4141
}
4242

43+
#[derive(Default)]
44+
struct ReferenceVisitor {
45+
counter: usize,
46+
lifetimes: Vec<ArgLifetime>,
47+
self_receiver: bool,
48+
self_receiver_new_lifetime: bool,
49+
self_lifetime: Option<Lifetime>,
50+
}
51+
52+
impl VisitMut for ReferenceVisitor {
53+
fn visit_receiver_mut(&mut self, receiver: &mut Receiver) {
54+
self.self_lifetime = Some(if let Some((_, lt)) = &mut receiver.reference {
55+
self.self_receiver = true;
56+
57+
if let Some(lt) = lt {
58+
lt.clone()
59+
} else {
60+
// Use 'life_self to avoid collisions with 'life<count> lifetimes.
61+
let new_lifetime: Lifetime = parse_quote!('life_self);
62+
lt.replace(new_lifetime.clone());
63+
64+
self.self_receiver_new_lifetime = true;
65+
66+
new_lifetime
67+
}
68+
} else {
69+
return;
70+
});
71+
}
72+
73+
fn visit_type_reference_mut(&mut self, argument: &mut TypeReference) {
74+
if argument.lifetime.is_none() {
75+
// If this reference doesn't have a lifetime (e.g. &T), then give it one.
76+
let lt = Lifetime::new(&format!("'life{}", self.counter), Span::call_site());
77+
self.lifetimes.push(ArgLifetime::New(parse_quote!(#lt)));
78+
argument.lifetime = Some(lt);
79+
self.counter += 1;
80+
} else {
81+
// If it does (e.g. &'life T), then keep track of it.
82+
let lt = argument.lifetime.as_ref().cloned().unwrap();
83+
84+
// Check that this lifetime isn't already in our vector
85+
let ident_matches = |x: &ArgLifetime| {
86+
if let ArgLifetime::Existing(elt) = x {
87+
elt.ident == lt.ident
88+
} else {
89+
false
90+
}
91+
};
92+
93+
if !self.lifetimes.iter().any(ident_matches) {
94+
self.lifetimes.push(ArgLifetime::Existing(lt));
95+
}
96+
}
97+
}
98+
}
99+
43100
// Input:
44101
// async fn f<S, T>(x : S, y : &T) -> Ret;
45102
//
@@ -55,67 +112,13 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
55112
// Remove the asyncness of this function
56113
sig.asyncness = None;
57114

58-
// Find all reference arguments
59-
let mut ref_arguments = Vec::new();
60-
let mut self_lifetime = None;
61-
62-
for arg in &mut sig.inputs {
63-
if let FnArg::Typed(pt) = arg {
64-
match pt.ty.as_mut() {
65-
// rustc can give us a None-delimited group if this type comes from
66-
// a macro_rules macro. I don't think this can happen for code the user has written.
67-
Type::Group(tg) => {
68-
if let Type::Reference(tr) = &mut *tg.elem {
69-
ref_arguments.push(tr);
70-
}
71-
}
72-
Type::Reference(tr) => {
73-
ref_arguments.push(tr);
74-
}
75-
_ => {}
76-
}
77-
} else if let FnArg::Receiver(recv) = arg {
78-
if let Some((_, slt)) = &mut recv.reference {
79-
self_lifetime = Some(slt);
80-
}
81-
}
82-
}
83-
84-
let mut counter = 0;
85-
let mut lifetimes = Vec::new();
86-
87-
if !ref_arguments.is_empty() {
88-
for ra in &mut ref_arguments {
89-
// If this reference arg doesn't have a lifetime, give it an explicit one
90-
if ra.lifetime.is_none() {
91-
let lt = Lifetime::new(&format!("'life{counter}"), Span::call_site());
92-
93-
lifetimes.push(ArgLifetime::New(parse_quote!(#lt)));
94-
95-
ra.lifetime = Some(lt);
96-
counter += 1;
97-
} else {
98-
let lt = ra.lifetime.as_ref().cloned().unwrap();
99-
100-
// Check that this lifetime isn't already in our vector
101-
let ident_matches = |x: &ArgLifetime| {
102-
if let ArgLifetime::Existing(elt) = x {
103-
elt.ident == lt.ident
104-
} else {
105-
false
106-
}
107-
};
108-
109-
if !lifetimes.iter().any(ident_matches) {
110-
lifetimes.push(ArgLifetime::Existing(
111-
ra.lifetime.as_ref().cloned().unwrap(),
112-
));
113-
}
114-
}
115-
}
115+
// Find and update any references in the input arguments
116+
let mut v = ReferenceVisitor::default();
117+
for input in &mut sig.inputs {
118+
v.visit_fn_arg_mut(input);
116119
}
117120

118-
// Does this expansion require `async_recursion to be added to the output
121+
// Does this expansion require `async_recursion to be added to the output?
119122
let mut requires_lifetime = false;
120123
let mut where_clause_lifetimes = vec![];
121124
let mut where_clause_generics = vec![];
@@ -127,13 +130,13 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
127130
for param in sig.generics.type_params() {
128131
let ident = param.ident.clone();
129132
where_clause_generics.push(ident);
130-
131133
requires_lifetime = true;
132134
}
133135

134136
// Add an 'a : 'async_recursion bound to any lifetimes 'a appearing in the function
135-
if !lifetimes.is_empty() {
136-
for alt in lifetimes {
137+
if !v.lifetimes.is_empty() {
138+
requires_lifetime = true;
139+
for alt in v.lifetimes {
137140
if let ArgLifetime::New(lt) = &alt {
138141
// If this is a new argument,
139142
sig.generics.params.push(parse_quote!(#lt));
@@ -143,29 +146,15 @@ fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
143146
let lt = alt.lifetime();
144147
where_clause_lifetimes.push(lt);
145148
}
146-
147-
requires_lifetime = true;
148149
}
149150

150151
// If our function accepts &self, then we modify this to the explicit lifetime &'life_self,
151152
// and add the bound &'life_self : 'async_recursion
152-
if let Some(slt) = self_lifetime {
153-
let lt = {
154-
if let Some(lt) = slt.as_mut() {
155-
lt.clone()
156-
} else {
157-
// We use `life_self here to avoid any collisions with `life0, `life1 from above
158-
let lt: Lifetime = parse_quote!('life_self);
159-
sig.generics.params.push(parse_quote!(#lt));
160-
161-
// add lt to the lifetime of self
162-
*slt = Some(lt.clone());
163-
164-
lt
165-
}
166-
};
167-
168-
where_clause_lifetimes.push(lt);
153+
if v.self_receiver {
154+
if v.self_receiver_new_lifetime {
155+
sig.generics.params.push(parse_quote!('life_self));
156+
}
157+
where_clause_lifetimes.extend(v.self_lifetime);
169158
requires_lifetime = true;
170159
}
171160

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
use async_recursion::async_recursion;
2+
#[must_use]
3+
fn count_down<'life0, 'async_recursion>(
4+
foo: Option<&'life0 str>,
5+
) -> ::core::pin::Pin<
6+
Box<
7+
dyn ::core::future::Future<
8+
Output = i32,
9+
> + 'async_recursion + ::core::marker::Send,
10+
>,
11+
>
12+
where
13+
'life0: 'async_recursion,
14+
{
15+
Box::pin(async move {
16+
let _ = foo;
17+
0
18+
})
19+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
use async_recursion::async_recursion;
2+
3+
#[async_recursion]
4+
async fn count_down(foo: Option<&str>) -> i32 {
5+
let _ = foo;
6+
0
7+
}

tests/lifetimes.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ async fn contains_value_2<'a, 'b, T: PartialEq>(value: &'b T, node: &'b Node<'a,
3434
contains_value(value, node).await
3535
}
3636

37+
// The reference inside foo needs a `async_recursion bound
38+
#[async_recursion]
39+
async fn count_down(foo: Option<&str>) -> i32 {
40+
let _ = foo;
41+
0
42+
}
43+
3744
#[test]
3845
fn lifetime_expansion_works() {
3946
block_on(async move {
@@ -64,5 +71,7 @@ fn lifetime_expansion_works() {
6471
assert_eq!(contains_value_2(&17, &node).await, true);
6572
assert_eq!(contains_value_2(&13, &node).await, true);
6673
assert_eq!(contains_value_2(&12, &node).await, false);
74+
75+
count_down(None).await;
6776
});
6877
}

0 commit comments

Comments
 (0)