Skip to content

Commit 196babc

Browse files
committed
fix(macros/(try_)join): don't leak const COUNT in macro expansion
1 parent bbe7c26 commit 196babc

File tree

5 files changed

+47
-25
lines changed

5 files changed

+47
-25
lines changed

tokio/src/macros/join.rs

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ doc! {macro_rules! join {
113113
(@ {
114114
// Type of rotator that controls which inner future to start with
115115
// when polling our output future.
116-
rotator=$rotator:ty;
116+
rotator_select=$rotator_select:ty;
117117

118118
// One `_` for each branch in the `join!` macro. This is not used once
119119
// normalization is complete.
@@ -126,7 +126,7 @@ doc! {macro_rules! join {
126126
$( ( $($skip:tt)* ) $e:expr, )*
127127

128128
}) => {{
129-
use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
129+
use $crate::macros::support::{maybe_done, poll_fn, Future, Pin, RotatorSelect};
130130
use $crate::macros::support::Poll::{Ready, Pending};
131131

132132
// Safety: nothing must be moved out of `futures`. This is to satisfy
@@ -143,14 +143,14 @@ doc! {macro_rules! join {
143143
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
144144
let mut futures = &mut futures;
145145

146-
const COUNT: u32 = $($total)*;
147-
148146
// Each time the future created by poll_fn is polled, if not using biased mode,
149147
// a different future is polled first to ensure every future passed to join!
150148
// can make progress even if one of the futures consumes the whole budget.
151-
let mut rotator = <$rotator>::default();
149+
let mut rotator = <$rotator_select as RotatorSelect>::Rotator::<{$($total)*}>::default();
152150

153151
poll_fn(move |cx| {
152+
const COUNT: u32 = $($total)*;
153+
154154
let mut is_pending = false;
155155
let mut to_run = COUNT;
156156

@@ -205,24 +205,48 @@ doc! {macro_rules! join {
205205

206206
// ===== Normalize =====
207207

208-
(@ { rotator=$rotator:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
209-
$crate::join!(@{ rotator=$rotator; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
208+
(@ { rotator_select=$rotator_select:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
209+
$crate::join!(@{ rotator_select=$rotator_select; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
210210
};
211211

212212
// ===== Entry point =====
213213
( biased; $($e:expr),+ $(,)?) => {
214-
$crate::join!(@{ rotator=$crate::macros::support::BiasedRotator; () (0) } $($e,)*)
214+
$crate::join!(@{ rotator_select=$crate::macros::support::SelectBiased; () (0) } $($e,)*)
215215
};
216216

217217
( $($e:expr),+ $(,)?) => {
218-
$crate::join!(@{ rotator=$crate::macros::support::Rotator<COUNT>; () (0) } $($e,)*)
218+
$crate::join!(@{ rotator_select=$crate::macros::support::SelectNormal; () (0) } $($e,)*)
219219
};
220220

221221
(biased;) => { async {}.await };
222222

223223
() => { async {}.await }
224224
}}
225225

226+
/// Helper trait to select which type of `Rotator` to use.
227+
// We need this to allow specifying a const generic without
228+
// colliding with caller const names due to macro hygiene.
229+
pub trait RotatorSelect {
230+
type Rotator<const COUNT: u32>: Default;
231+
}
232+
233+
/// Marker type indicating that the starting branch should
234+
/// rotate each poll.
235+
#[derive(Debug)]
236+
pub struct SelectNormal;
237+
/// Marker type indicating that the starting branch should
238+
/// be the first declared branch each poll.
239+
#[derive(Debug)]
240+
pub struct SelectBiased;
241+
242+
impl RotatorSelect for SelectNormal {
243+
type Rotator<const COUNT: u32> = Rotator<COUNT>;
244+
}
245+
246+
impl RotatorSelect for SelectBiased {
247+
type Rotator<const COUNT: u32> = BiasedRotator;
248+
}
249+
226250
/// Rotates by one each [`Self::num_skip`] call up to COUNT - 1.
227251
#[derive(Default, Debug)]
228252
pub struct Rotator<const COUNT: u32> {

tokio/src/macros/support.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ cfg_macros! {
33

44
pub use std::future::poll_fn;
55

6-
pub use crate::macros::join::{BiasedRotator, Rotator};
6+
pub use crate::macros::join::{BiasedRotator, Rotator, RotatorSelect, SelectNormal, SelectBiased};
77

88
#[doc(hidden)]
99
pub fn thread_rng_n(n: u32) -> u32 {

tokio/src/macros/try_join.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ doc! {macro_rules! try_join {
166166
(@ {
167167
// Type of rotator that controls which inner future to start with
168168
// when polling our output future.
169-
rotator=$rotator:ty;
169+
rotator_select=$rotator_select:ty;
170170

171171
// One `_` for each branch in the `try_join!` macro. This is not used once
172172
// normalization is complete.
@@ -179,7 +179,7 @@ doc! {macro_rules! try_join {
179179
$( ( $($skip:tt)* ) $e:expr, )*
180180

181181
}) => {{
182-
use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
182+
use $crate::macros::support::{maybe_done, poll_fn, Future, Pin, RotatorSelect};
183183
use $crate::macros::support::Poll::{Ready, Pending};
184184

185185
// Safety: nothing must be moved out of `futures`. This is to satisfy
@@ -196,14 +196,14 @@ doc! {macro_rules! try_join {
196196
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
197197
let mut futures = &mut futures;
198198

199-
const COUNT: u32 = $($total)*;
200-
201199
// Each time the future created by poll_fn is polled, if not using biased mode,
202200
// a different future is polled first to ensure every future passed to try_join!
203201
// can make progress even if one of the futures consumes the whole budget.
204-
let mut rotator = <$rotator>::default();
202+
let mut rotator = <$rotator_select as RotatorSelect>::Rotator::<{$($total)*}>::default();
205203

206204
poll_fn(move |cx| {
205+
const COUNT: u32 = $($total)*;
206+
207207
let mut is_pending = false;
208208
let mut to_run = COUNT;
209209

@@ -264,17 +264,17 @@ doc! {macro_rules! try_join {
264264

265265
// ===== Normalize =====
266266

267-
(@ { rotator=$rotator:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
268-
$crate::try_join!(@{ rotator=$rotator; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
267+
(@ { rotator_select=$rotator_select:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
268+
$crate::try_join!(@{ rotator_select=$rotator_select; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
269269
};
270270

271271
// ===== Entry point =====
272272
( biased; $($e:expr),+ $(,)?) => {
273-
$crate::try_join!(@{ rotator=$crate::macros::support::BiasedRotator; () (0) } $($e,)*)
273+
$crate::try_join!(@{ rotator_select=$crate::macros::support::SelectBiased; () (0) } $($e,)*)
274274
};
275275

276276
( $($e:expr),+ $(,)?) => {
277-
$crate::try_join!(@{ rotator=$crate::macros::support::Rotator<COUNT>; () (0) } $($e,)*)
277+
$crate::try_join!(@{ rotator_select=$crate::macros::support::SelectNormal; () (0) } $($e,)*)
278278
};
279279

280280
(biased;) => { async { Ok(()) }.await };

tokio/tests/macros_join.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,6 @@ async fn join_into_future() {
243243
async fn caller_names_const_count() {
244244
let (tx, rx) = oneshot::channel::<u32>();
245245

246-
#[allow(unused)]
247246
const COUNT: u32 = 2;
248247

249248
let mut join = task::spawn(async { tokio::join!(async { tx.send(COUNT).unwrap() }) });
@@ -252,6 +251,6 @@ async fn caller_names_const_count() {
252251
let res = rx.await.unwrap();
253252

254253
// This passing demonstrates that the const in the macro is
255-
// shadowing the caller-specified COUNT value.
256-
assert_eq!(1, res);
254+
// not shadowing the caller-specified COUNT value
255+
assert_eq!(2, res);
257256
}

tokio/tests/macros_try_join.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,6 @@ async fn empty_try_join() {
256256
async fn caller_names_const_count() {
257257
let (tx, rx) = oneshot::channel::<u32>();
258258

259-
#[allow(unused)]
260259
const COUNT: u32 = 2;
261260

262261
let mut try_join = task::spawn(async { tokio::try_join!(async { tx.send(COUNT) }) });
@@ -265,6 +264,6 @@ async fn caller_names_const_count() {
265264
let res = rx.await.unwrap();
266265

267266
// This passing demonstrates that the const in the macro is
268-
// shadowing the caller-specified COUNT value.
269-
assert_eq!(1, res);
267+
// not shadowing the caller-specified COUNT value
268+
assert_eq!(2, res);
270269
}

0 commit comments

Comments
 (0)