Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions tokio/src/macros/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ doc! {macro_rules! join {
(@ {
// Type of rotator that controls which inner future to start with
// when polling our output future.
rotator=$rotator:ty;
rotator_select=$rotator_select:ty;

// One `_` for each branch in the `join!` macro. This is not used once
// normalization is complete.
Expand All @@ -126,7 +126,7 @@ doc! {macro_rules! join {
$( ( $($skip:tt)* ) $e:expr, )*

}) => {{
use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
use $crate::macros::support::{maybe_done, poll_fn, Future, Pin, RotatorSelect};
use $crate::macros::support::Poll::{Ready, Pending};

// Safety: nothing must be moved out of `futures`. This is to satisfy
Expand All @@ -143,14 +143,14 @@ doc! {macro_rules! join {
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
let mut futures = &mut futures;

const COUNT: u32 = $($total)*;

// Each time the future created by poll_fn is polled, if not using biased mode,
// a different future is polled first to ensure every future passed to join!
// can make progress even if one of the futures consumes the whole budget.
let mut rotator = <$rotator>::default();
let mut rotator = <$rotator_select as RotatorSelect>::Rotator::<{$($total)*}>::default();

poll_fn(move |cx| {
const COUNT: u32 = $($total)*;

let mut is_pending = false;
let mut to_run = COUNT;

Expand Down Expand Up @@ -205,24 +205,48 @@ doc! {macro_rules! join {

// ===== Normalize =====

(@ { rotator=$rotator:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
$crate::join!(@{ rotator=$rotator; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
(@ { rotator_select=$rotator_select:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
$crate::join!(@{ rotator_select=$rotator_select; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
};

// ===== Entry point =====
( biased; $($e:expr),+ $(,)?) => {
$crate::join!(@{ rotator=$crate::macros::support::BiasedRotator; () (0) } $($e,)*)
$crate::join!(@{ rotator_select=$crate::macros::support::SelectBiased; () (0) } $($e,)*)
};

( $($e:expr),+ $(,)?) => {
$crate::join!(@{ rotator=$crate::macros::support::Rotator<COUNT>; () (0) } $($e,)*)
$crate::join!(@{ rotator_select=$crate::macros::support::SelectNormal; () (0) } $($e,)*)
};

(biased;) => { async {}.await };

() => { async {}.await }
}}

/// Helper trait to select which type of `Rotator` to use.
// We need this to allow specifying a const generic without
// colliding with caller const names due to macro hygiene.
pub trait RotatorSelect {
type Rotator<const COUNT: u32>: Default;
}

/// Marker type indicating that the starting branch should
/// rotate each poll.
#[derive(Debug)]
pub struct SelectNormal;
/// Marker type indicating that the starting branch should
/// be the first declared branch each poll.
#[derive(Debug)]
pub struct SelectBiased;

impl RotatorSelect for SelectNormal {
type Rotator<const COUNT: u32> = Rotator<COUNT>;
}

impl RotatorSelect for SelectBiased {
type Rotator<const COUNT: u32> = BiasedRotator;
}

/// Rotates by one each [`Self::num_skip`] call up to COUNT - 1.
#[derive(Default, Debug)]
pub struct Rotator<const COUNT: u32> {
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/macros/support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ cfg_macros! {

pub use std::future::poll_fn;

pub use crate::macros::join::{BiasedRotator, Rotator};
pub use crate::macros::join::{BiasedRotator, Rotator, RotatorSelect, SelectNormal, SelectBiased};

#[doc(hidden)]
pub fn thread_rng_n(n: u32) -> u32 {
Expand Down
18 changes: 9 additions & 9 deletions tokio/src/macros/try_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ doc! {macro_rules! try_join {
(@ {
// Type of rotator that controls which inner future to start with
// when polling our output future.
rotator=$rotator:ty;
rotator_select=$rotator_select:ty;

// One `_` for each branch in the `try_join!` macro. This is not used once
// normalization is complete.
Expand All @@ -179,7 +179,7 @@ doc! {macro_rules! try_join {
$( ( $($skip:tt)* ) $e:expr, )*

}) => {{
use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
use $crate::macros::support::{maybe_done, poll_fn, Future, Pin, RotatorSelect};
use $crate::macros::support::Poll::{Ready, Pending};

// Safety: nothing must be moved out of `futures`. This is to satisfy
Expand All @@ -196,14 +196,14 @@ doc! {macro_rules! try_join {
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
let mut futures = &mut futures;

const COUNT: u32 = $($total)*;

// Each time the future created by poll_fn is polled, if not using biased mode,
// a different future is polled first to ensure every future passed to try_join!
// can make progress even if one of the futures consumes the whole budget.
let mut rotator = <$rotator>::default();
let mut rotator = <$rotator_select as RotatorSelect>::Rotator::<{$($total)*}>::default();

poll_fn(move |cx| {
const COUNT: u32 = $($total)*;

let mut is_pending = false;
let mut to_run = COUNT;

Expand Down Expand Up @@ -264,17 +264,17 @@ doc! {macro_rules! try_join {

// ===== Normalize =====

(@ { rotator=$rotator:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
$crate::try_join!(@{ rotator=$rotator; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
(@ { rotator_select=$rotator_select:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
$crate::try_join!(@{ rotator_select=$rotator_select; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
};

// ===== Entry point =====
( biased; $($e:expr),+ $(,)?) => {
$crate::try_join!(@{ rotator=$crate::macros::support::BiasedRotator; () (0) } $($e,)*)
$crate::try_join!(@{ rotator_select=$crate::macros::support::SelectBiased; () (0) } $($e,)*)
};

( $($e:expr),+ $(,)?) => {
$crate::try_join!(@{ rotator=$crate::macros::support::Rotator<COUNT>; () (0) } $($e,)*)
$crate::try_join!(@{ rotator_select=$crate::macros::support::SelectNormal; () (0) } $($e,)*)
};

(biased;) => { async { Ok(()) }.await };
Expand Down
20 changes: 20 additions & 0 deletions tokio/tests/macros_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,23 @@ async fn join_into_future() {

tokio::join!(NotAFuture);
}

// Regression test for: https://github.com/tokio-rs/tokio/issues/7637
// We want to make sure that the `const COUNT: u32` declaration
// inside the macro body doesn't leak to the caller to cause compiler failures
// or variable shadowing.
#[tokio::test]
async fn caller_names_const_count() {
let (tx, rx) = oneshot::channel::<u32>();

const COUNT: u32 = 2;

let mut join = task::spawn(async { tokio::join!(async { tx.send(COUNT).unwrap() }) });
assert_ready!(join.poll());

let res = rx.await.unwrap();

// This passing demonstrates that the const in the macro is
// not shadowing the caller-specified COUNT value
assert_eq!(2, res);
}
20 changes: 20 additions & 0 deletions tokio/tests/macros_try_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,23 @@ async fn empty_try_join() {
assert_eq!(tokio::try_join!() as Result<_, ()>, Ok(()));
assert_eq!(tokio::try_join!(biased;) as Result<_, ()>, Ok(()));
}

// Regression test for: https://github.com/tokio-rs/tokio/issues/7637
// We want to make sure that the `const COUNT: u32` declaration
// inside the macro body doesn't leak to the caller to cause compiler failures
// or variable shadowing.
#[tokio::test]
async fn caller_names_const_count() {
let (tx, rx) = oneshot::channel::<u32>();

const COUNT: u32 = 2;

let mut try_join = task::spawn(async { tokio::try_join!(async { tx.send(COUNT) }) });
assert_ready!(try_join.poll()).unwrap();

let res = rx.await.unwrap();

// This passing demonstrates that the const in the macro is
// not shadowing the caller-specified COUNT value
assert_eq!(2, res);
}
Loading