Skip to content

Commit abe4268

Browse files
committed
util: add JoinQueue data structure
1 parent 9f59c69 commit abe4268

File tree

2 files changed

+349
-0
lines changed

2 files changed

+349
-0
lines changed

tokio-util/src/task/join_queue.rs

Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
use super::AbortOnDropHandle;
2+
use std::{
3+
collections::VecDeque,
4+
future::Future,
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
8+
use tokio::{
9+
runtime::Handle,
10+
task::{AbortHandle, Id, JoinError, JoinHandle},
11+
};
12+
13+
/// A collection of tasks spawned on a Tokio runtime.
14+
///
15+
/// A `JoinQueue` can be used to await the completion of the tasks in FIFO
16+
/// order. That is, if tasks are spawned in the order A, B, C, then
17+
/// awaiting the next completed task will always return A first, then B,
18+
/// then C, regardless of the order in which the tasks actually complete.
19+
///
20+
/// All of the tasks must have the same return type `T`.
21+
///
22+
/// When the `JoinQueue` is dropped, all tasks in the `JoinQueue` are
23+
/// immediately aborted.
24+
#[derive(Debug)]
25+
pub struct JoinQueue<T>(VecDeque<AbortOnDropHandle<T>>);
26+
27+
impl<T> JoinQueue<T> {
28+
/// Create a new empty `JoinQueue`.
29+
pub const fn new() -> Self {
30+
Self(VecDeque::new())
31+
}
32+
33+
/// Creates an empty `JoinQueue` with space for at least `capacity` tasks.
34+
pub fn with_capacity(capacity: usize) -> Self {
35+
Self(VecDeque::with_capacity(capacity))
36+
}
37+
38+
/// Returns the number of tasks currently in the `JoinQueue`.
39+
///
40+
/// This includes both tasks that are currently running and tasks that have
41+
/// completed but not yet been removed from the queue because outputting of
42+
/// them waits for FIFO order.
43+
pub fn len(&self) -> usize {
44+
self.0.len()
45+
}
46+
47+
/// Returns whether the `JoinQueue` is empty.
48+
pub fn is_empty(&self) -> bool {
49+
self.0.is_empty()
50+
}
51+
52+
/// Spawn the provided task on the `JoinQueue`, returning an [`AbortHandle`]
53+
/// that can be used to remotely cancel the task.
54+
///
55+
/// The provided future will start running in the background immediately
56+
/// when this method is called, even if you don't await anything on this
57+
/// `JoinQueue`.
58+
///
59+
/// # Panics
60+
///
61+
/// This method panics if called outside of a Tokio runtime.
62+
///
63+
/// [`AbortHandle`]: tokio::task::AbortHandle
64+
#[track_caller]
65+
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
66+
where
67+
F: Future<Output = T> + Send + 'static,
68+
T: Send + 'static,
69+
{
70+
self.push_back(tokio::spawn(task))
71+
}
72+
73+
/// Spawn the provided task on the provided runtime and store it in this
74+
/// `JoinQueue` returning an [`AbortHandle`] that can be used to remotely
75+
/// cancel the task.
76+
///
77+
/// The provided future will start running in the background immediately
78+
/// when this method is called, even if you don't await anything on this
79+
/// `JoinQueue`.
80+
///
81+
/// [`AbortHandle`]: tokio::task::AbortHandle
82+
#[track_caller]
83+
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
84+
where
85+
F: Future<Output = T> + Send + 'static,
86+
T: Send + 'static,
87+
{
88+
self.push_back(handle.spawn(task))
89+
}
90+
91+
/// Spawn the provided task on the current [`LocalSet`] and store it in this
92+
/// `JoinQueue`, returning an [`AbortHandle`] that can be used to remotely
93+
/// cancel the task.
94+
///
95+
/// The provided future will start running in the background immediately
96+
/// when this method is called, even if you don't await anything on this
97+
/// `JoinQueue`.
98+
///
99+
/// # Panics
100+
///
101+
/// This method panics if it is called outside of a `LocalSet`.
102+
///
103+
/// [`LocalSet`]: tokio::task::LocalSet
104+
/// [`AbortHandle`]: tokio::task::AbortHandle
105+
#[track_caller]
106+
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
107+
where
108+
F: Future<Output = T> + 'static,
109+
T: 'static,
110+
{
111+
self.push_back(tokio::task::spawn_local(task))
112+
}
113+
114+
/// Spawn the blocking code on the blocking threadpool and store
115+
/// it in this `JoinQueue`, returning an [`AbortHandle`] that can be
116+
/// used to remotely cancel the task.
117+
///
118+
/// # Panics
119+
///
120+
/// This method panics if called outside of a Tokio runtime.
121+
///
122+
/// [`AbortHandle`]: tokio::task::AbortHandle
123+
#[track_caller]
124+
pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
125+
where
126+
F: FnOnce() -> T + Send + 'static,
127+
T: Send + 'static,
128+
{
129+
self.push_back(tokio::task::spawn_blocking(f))
130+
}
131+
132+
/// Spawn the blocking code on the blocking threadpool of the
133+
/// provided runtime and store it in this `JoinQueue`, returning an
134+
/// [`AbortHandle`] that can be used to remotely cancel the task.
135+
///
136+
/// [`AbortHandle`]: tokio::task::AbortHandle
137+
#[track_caller]
138+
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
139+
where
140+
F: FnOnce() -> T + Send + 'static,
141+
T: Send + 'static,
142+
{
143+
self.push_back(handle.spawn_blocking(f))
144+
}
145+
146+
fn push_back(&mut self, jh: JoinHandle<T>) -> AbortHandle {
147+
let join_handle = AbortOnDropHandle::new(jh);
148+
let abort_handle = join_handle.abort_handle();
149+
self.0.push_back(join_handle);
150+
abort_handle
151+
}
152+
153+
/// Waits until the next task in FIFO order completes and returns its output.
154+
///
155+
/// Returns `None` if the queue is empty.
156+
///
157+
/// # Cancel Safety
158+
///
159+
/// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
160+
/// statement and some other branch completes first, it is guaranteed that no tasks were
161+
/// removed from this `JoinQueue`.
162+
pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
163+
std::future::poll_fn(|cx| self.poll_join_next(cx)).await
164+
}
165+
166+
/// Waits until the next task in FIFO order completes and returns its output,
167+
/// along with the [task ID] of the completed task.
168+
///
169+
/// Returns `None` if the queue is empty.
170+
///
171+
/// When this method returns an error, then the id of the task that failed can be accessed
172+
/// using the [`JoinError::id`] method.
173+
///
174+
/// # Cancel Safety
175+
///
176+
/// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!`
177+
/// statement and some other branch completes first, it is guaranteed that no tasks were
178+
/// removed from this `JoinQueue`.
179+
///
180+
/// [task ID]: tokio::task::Id
181+
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
182+
pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
183+
std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
184+
}
185+
186+
/// Aborts all tasks and waits for them to finish shutting down.
187+
///
188+
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
189+
/// a loop until it returns `None`.
190+
///
191+
/// This method ignores any panics in the tasks shutting down. When this call returns, the
192+
/// `JoinQueue` will be empty.
193+
///
194+
/// [`abort_all`]: fn@Self::abort_all
195+
/// [`join_next`]: fn@Self::join_next
196+
pub async fn shutdown(&mut self) {
197+
self.abort_all();
198+
while self.join_next().await.is_some() {}
199+
}
200+
201+
/// Awaits the completion of all tasks in this `JoinQueue`, returning a vector of their results.
202+
///
203+
/// The results will be stored in the order they were spawned, not the order they completed.
204+
/// This is a convenience method that is equivalent to calling [`join_next`] in
205+
/// a loop. If any tasks on the `JoinQueue` fail with an [`JoinError`], then this call
206+
/// to `join_all` will panic and all remaining tasks on the `JoinQueue` are
207+
/// cancelled. To handle errors in any other way, manually call [`join_next`]
208+
/// in a loop.
209+
///
210+
/// [`join_next`]: fn@Self::join_next
211+
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
212+
pub async fn join_all(mut self) -> Vec<T> {
213+
let mut output = Vec::with_capacity(self.len());
214+
215+
while let Some(res) = self.join_next().await {
216+
match res {
217+
Ok(t) => output.push(t),
218+
Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
219+
Err(err) => panic!("{err}"),
220+
}
221+
}
222+
output
223+
}
224+
225+
/// Aborts all tasks on this `JoinQueue`.
226+
///
227+
/// This does not remove the tasks from the `JoinQueue`. To wait for the tasks to complete
228+
/// cancellation, you should call `join_next` in a loop until the `JoinQueue` is empty.
229+
pub fn abort_all(&mut self) {
230+
self.0.iter().for_each(|jh| jh.abort());
231+
}
232+
233+
/// Removes all tasks from this `JoinQueue` without aborting them.
234+
///
235+
/// The tasks removed by this call will continue to run in the background even if the `JoinQueue`
236+
/// is dropped.
237+
pub fn detach_all(&mut self) {
238+
self.0.drain(..).for_each(|jh| drop(jh.detach()));
239+
}
240+
241+
/// Polls for the next task in `JoinQueue` to complete.
242+
///
243+
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
244+
///
245+
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
246+
/// to receive a wakeup when a task in the `JoinQueue` completes. Note that on multiple calls to
247+
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
248+
/// scheduled to receive a wakeup.
249+
///
250+
/// # Returns
251+
///
252+
/// This function returns:
253+
///
254+
/// * `Poll::Pending` if the `JoinQueue` is not empty but there is no task whose output is
255+
/// available right now.
256+
/// * `Poll::Ready(Some(Ok(value)))` if the next task in this `JoinQueue` has completed.
257+
/// The `value` is the return value that task.
258+
/// * `Poll::Ready(Some(Err(err)))` if the next task in this `JoinQueue` has panicked or been
259+
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
260+
/// * `Poll::Ready(None)` if the `JoinQueue` is empty.
261+
pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
262+
let jh = match self.0.front_mut() {
263+
None => return Poll::Ready(None),
264+
Some(jh) => jh,
265+
};
266+
if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
267+
drop(self.0.pop_front().unwrap().detach());
268+
Poll::Ready(Some(res))
269+
} else {
270+
// A JoinHandle generally won't emit a wakeup without being ready unless
271+
// the coop limit has been reached. We yield to the executor in this
272+
// case.
273+
cx.waker().wake_by_ref();
274+
Poll::Pending
275+
}
276+
}
277+
278+
/// Polls for the next task in `JoinQueue` to complete.
279+
///
280+
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
281+
///
282+
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
283+
/// to receive a wakeup when a task in the `JoinQueue` completes. Note that on multiple calls to
284+
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
285+
/// scheduled to receive a wakeup.
286+
///
287+
/// # Returns
288+
///
289+
/// This function returns:
290+
///
291+
/// * `Poll::Pending` if the `JoinQueue` is not empty but there is no task whose output is
292+
/// available right now.
293+
/// * `Poll::Ready(Some(Ok((id, value))))` if the next task in this `JoinQueue` has completed.
294+
/// The `value` is the return value that task, and `id` is its [task ID].
295+
/// * `Poll::Ready(Some(Err(err)))` if the next task in this `JoinQueue` has panicked or been
296+
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
297+
/// * `Poll::Ready(None)` if the `JoinQueue` is empty.
298+
///
299+
/// [task ID]: tokio::task::Id
300+
pub fn poll_join_next_with_id(
301+
&mut self,
302+
cx: &mut Context<'_>,
303+
) -> Poll<Option<Result<(Id, T), JoinError>>> {
304+
let jh = match self.0.front_mut() {
305+
None => return Poll::Ready(None),
306+
Some(jh) => jh,
307+
};
308+
if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
309+
let jh = self.0.pop_front().unwrap().detach();
310+
let id = jh.id();
311+
drop(jh);
312+
// If the task succeeded, add the task ID to the output. Otherwise, the
313+
// `JoinError` will already have the task's ID.
314+
Poll::Ready(Some(res.map(|output| (id, output))))
315+
} else {
316+
// A JoinHandle generally won't emit a wakeup without being ready unless
317+
// the coop limit has been reached. We yield to the executor in this
318+
// case.
319+
cx.waker().wake_by_ref();
320+
Poll::Pending
321+
}
322+
}
323+
}
324+
325+
impl<T> Default for JoinQueue<T> {
326+
fn default() -> Self {
327+
Self::new()
328+
}
329+
}
330+
331+
/// Collect an iterator of futures into a [`JoinQueue`].
332+
///
333+
/// This is equivalent to calling [`JoinQueue::spawn`] on each element of the iterator.
334+
impl<T, F> std::iter::FromIterator<F> for JoinQueue<T>
335+
where
336+
F: Future<Output = T> + Send + 'static,
337+
T: Send + 'static,
338+
{
339+
fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
340+
let mut set = Self::new();
341+
iter.into_iter().for_each(|task| {
342+
set.spawn(task);
343+
});
344+
set
345+
}
346+
}

tokio-util/src/task/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ cfg_rt! {
1313

1414
mod abort_on_drop;
1515
pub use abort_on_drop::AbortOnDropHandle;
16+
17+
mod join_queue;
18+
pub use join_queue::JoinQueue;
1619
}
1720

1821
#[cfg(feature = "join-map")]

0 commit comments

Comments
 (0)