Skip to content

Commit

Permalink
Properly handle mutability for awaited futures (#239)
Browse files Browse the repository at this point in the history
* Properly handle mutability for awaited futures

* Simplify and add tests

* Implement review
  • Loading branch information
borchero authored Apr 9, 2024
1 parent 61a7007 commit 3c2fb9c
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
[`Sync`](https://doc.rust-lang.org/std/marker/trait.Sync.html) to prevent UB
when tests are executed in parallel. (see [#235](https://github.com/la10736/rstest/issues/235)
for more details)
- `#[future(awt)]` and `#[awt]` now properly handle mutable (`mut`) parameters by treating futures as immutable and
treating the awaited rebinding as mutable.

## [0.18.2] 2023/8/13

Expand Down
24 changes: 24 additions & 0 deletions rstest/tests/resources/rstest/cases/async_awt_mut.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use rstest::*;

#[rstest]
#[case::pass(async { 3 })]
#[awt]
async fn my_mut_test_global_awt(
#[future]
#[case]
mut a: i32,
) {
a = 4;
assert_eq!(a, 4);
}

#[rstest]
#[case::pass(async { 3 })]
async fn my_mut_test_local_awt(
#[future(awt)]
#[case]
mut a: i32,
) {
a = 4;
assert_eq!(a, 4);
}
13 changes: 13 additions & 0 deletions rstest/tests/rstest/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,19 @@ mod cases {
.assert(output);
}

#[rstest]
fn should_run_async_mut() {
let prj = prj(res("async_awt_mut.rs"));
prj.add_dependency("async-std", r#"{version="*", features=["attributes"]}"#);

let output = prj.run_tests().unwrap();

TestResults::new()
.ok("my_mut_test_global_awt::case_1_pass")
.ok("my_mut_test_local_awt::case_1_pass")
.assert(output);
}

#[test]
fn should_use_injected_test_attr() {
let prj = prj(res("inject.rs"));
Expand Down
32 changes: 32 additions & 0 deletions rstest_macros/src/refident.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,35 @@ impl MaybeIdent for crate::parse::Attribute {
}
}
}

pub trait MaybePatIdent {
fn maybe_patident(&self) -> Option<&syn::PatIdent>;
}

impl MaybePatIdent for FnArg {
fn maybe_patident(&self) -> Option<&syn::PatIdent> {
match self {
FnArg::Typed(PatType { pat, .. }) => match pat.as_ref() {
Pat::Ident(ident) => Some(ident),
_ => None,
},
_ => None,
}
}
}

pub trait RemoveMutability {
fn remove_mutability(&mut self);
}

impl RemoveMutability for FnArg {
fn remove_mutability(&mut self) {
match self {
FnArg::Typed(PatType { pat, .. }) => match pat.as_mut() {
Pat::Ident(ident) => ident.mutability = None,
_ => {}
},
_ => {}
};
}
}
33 changes: 27 additions & 6 deletions rstest_macros/src/render/apply_argumets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use syn::{parse_quote, FnArg, Generics, Ident, ItemFn, Lifetime, Signature, Type

use crate::{
parse::{arguments::ArgumentsInfo, future::MaybeFutureImplType},
refident::MaybeIdent,
refident::{MaybeIdent, MaybePatIdent, RemoveMutability},
};

pub(crate) trait ApplyArgumets<R: Sized = ()> {
Expand Down Expand Up @@ -59,17 +59,20 @@ impl ApplyArgumets for Signature {

impl ApplyArgumets for ItemFn {
fn apply_argumets(&mut self, arguments: &ArgumentsInfo) {
let awaited_args = self
let rebound_awaited_args = self
.sig
.inputs
.iter()
.filter_map(|a| a.maybe_ident())
.filter(|&a| arguments.is_future_await(a))
.cloned();
.filter_map(|a| a.maybe_patident())
.filter(|p| arguments.is_future_await(&p.ident))
.map(|p| {
let a = &p.ident;
quote::quote! { let #p = #a.await; }
});
let orig_block_impl = self.block.clone();
self.block = parse_quote! {
{
#(let #awaited_args = #awaited_args.await;)*
#(#rebound_awaited_args)*
#orig_block_impl
}
};
Expand All @@ -90,6 +93,7 @@ impl ImplFutureArg for FnArg {
*ty = parse_quote! {
impl std::future::Future<Output = #ty>
};
self.remove_mutability();
lifetime
}
None => None,
Expand Down Expand Up @@ -154,6 +158,11 @@ mod should {
&["a"],
"fn f<S: AsRef<str>>(a: impl std::future::Future<Output = S>) {}"
)]
#[case::remove_mut(
"fn f(mut a: u32) {}",
&["a"],
r#"fn f(a: impl std::future::Future<Output = u32>) {}"#
)]
fn replace_future_basic_type(
#[case] item_fn: &str,
#[case] futures: &[&str],
Expand Down Expand Up @@ -245,5 +254,17 @@ mod should {
assert_in!(code, await_argument_code_string("b"));
assert_not_in!(code, await_argument_code_string("c"));
}

#[test]
fn with_mut_await() {
let mut item_fn: ItemFn = r#"fn test(mut a: i32) {} "#.ast();
let mut arguments: ArgumentsInfo = Default::default();
arguments.set_future(ident("a"), FutureArg::Await);

item_fn.apply_argumets(&arguments);

let code = item_fn.block.display_code();
assert_in!(code, mut_await_argument_code_string("a"));
}
}
}
8 changes: 8 additions & 0 deletions rstest_macros/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,11 @@ pub(crate) fn await_argument_code_string(arg_name: &str) -> String {
};
statment.display_code()
}

pub(crate) fn mut_await_argument_code_string(arg_name: &str) -> String {
let arg_name = ident(arg_name);
let statement: Stmt = parse_quote! {
let mut #arg_name = #arg_name.await;
};
statement.display_code()
}

0 comments on commit 3c2fb9c

Please sign in to comment.