Skip to content

Commit 1bd3348

Browse files
committed
Add thread create/destroy callbacks to TaskPool (#6561)
# Objective Fix #1991. Allow users to have a bit more control over the creation and finalization of the threads in `TaskPool`. ## Solution Add new methods to `TaskPoolBuilder` that expose callbacks that are called to initialize and finalize each thread in the `TaskPool`. Unlike the proposed solution in #1991, the callback is argument-less. If an an identifier is needed, `std::thread::current` should provide that information easily. Added a unit test to ensure that they're being called correctly.
1 parent f8e4b75 commit 1bd3348

File tree

1 file changed

+98
-14
lines changed

1 file changed

+98
-14
lines changed

crates/bevy_tasks/src/task_pool.rs

Lines changed: 98 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,18 @@ use futures_lite::{future, pin, FutureExt};
1212

1313
use crate::Task;
1414

15+
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);
16+
17+
impl Drop for CallOnDrop {
18+
fn drop(&mut self) {
19+
if let Some(call) = self.0.as_ref() {
20+
call();
21+
}
22+
}
23+
}
24+
1525
/// Used to create a [`TaskPool`]
16-
#[derive(Debug, Default, Clone)]
26+
#[derive(Default)]
1727
#[must_use]
1828
pub struct TaskPoolBuilder {
1929
/// If set, we'll set up the thread pool to use at most `num_threads` threads.
@@ -24,6 +34,9 @@ pub struct TaskPoolBuilder {
2434
/// Allows customizing the name of the threads - helpful for debugging. If set, threads will
2535
/// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
2636
thread_name: Option<String>,
37+
38+
on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
39+
on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
2740
}
2841

2942
impl TaskPoolBuilder {
@@ -52,13 +65,27 @@ impl TaskPoolBuilder {
5265
self
5366
}
5467

68+
/// Sets a callback that is invoked once for every created thread as it starts.
69+
///
70+
/// This is called on the thread itself and has access to all thread-local storage.
71+
/// This will block running async tasks on the thread until the callback completes.
72+
pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
73+
self.on_thread_spawn = Some(Arc::new(f));
74+
self
75+
}
76+
77+
/// Sets a callback that is invoked once for every created thread as it terminates.
78+
///
79+
/// This is called on the thread itself and has access to all thread-local storage.
80+
/// This will block thread termination until the callback completes.
81+
pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
82+
self.on_thread_destroy = Some(Arc::new(f));
83+
self
84+
}
85+
5586
/// Creates a new [`TaskPool`] based on the current options.
5687
pub fn build(self) -> TaskPool {
57-
TaskPool::new_internal(
58-
self.num_threads,
59-
self.stack_size,
60-
self.thread_name.as_deref(),
61-
)
88+
TaskPool::new_internal(self)
6289
}
6390
}
6491

@@ -88,36 +115,42 @@ impl TaskPool {
88115
TaskPoolBuilder::new().build()
89116
}
90117

91-
fn new_internal(
92-
num_threads: Option<usize>,
93-
stack_size: Option<usize>,
94-
thread_name: Option<&str>,
95-
) -> Self {
118+
fn new_internal(builder: TaskPoolBuilder) -> Self {
96119
let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
97120

98121
let executor = Arc::new(async_executor::Executor::new());
99122

100-
let num_threads = num_threads.unwrap_or_else(crate::available_parallelism);
123+
let num_threads = builder
124+
.num_threads
125+
.unwrap_or_else(crate::available_parallelism);
101126

102127
let threads = (0..num_threads)
103128
.map(|i| {
104129
let ex = Arc::clone(&executor);
105130
let shutdown_rx = shutdown_rx.clone();
106131

107-
let thread_name = if let Some(thread_name) = thread_name {
132+
let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
108133
format!("{thread_name} ({i})")
109134
} else {
110135
format!("TaskPool ({i})")
111136
};
112137
let mut thread_builder = thread::Builder::new().name(thread_name);
113138

114-
if let Some(stack_size) = stack_size {
139+
if let Some(stack_size) = builder.stack_size {
115140
thread_builder = thread_builder.stack_size(stack_size);
116141
}
117142

143+
let on_thread_spawn = builder.on_thread_spawn.clone();
144+
let on_thread_destroy = builder.on_thread_destroy.clone();
145+
118146
thread_builder
119147
.spawn(move || {
120148
TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
149+
if let Some(on_thread_spawn) = on_thread_spawn {
150+
on_thread_spawn();
151+
drop(on_thread_spawn);
152+
}
153+
let _destructor = CallOnDrop(on_thread_destroy);
121154
loop {
122155
let res = std::panic::catch_unwind(|| {
123156
let tick_forever = async move {
@@ -452,6 +485,57 @@ mod tests {
452485
assert_eq!(count.load(Ordering::Relaxed), 100);
453486
}
454487

488+
#[test]
489+
fn test_thread_callbacks() {
490+
let counter = Arc::new(AtomicI32::new(0));
491+
let start_counter = counter.clone();
492+
{
493+
let barrier = Arc::new(Barrier::new(11));
494+
let last_barrier = barrier.clone();
495+
// Build and immediately drop to terminate
496+
let _pool = TaskPoolBuilder::new()
497+
.num_threads(10)
498+
.on_thread_spawn(move || {
499+
start_counter.fetch_add(1, Ordering::Relaxed);
500+
barrier.clone().wait();
501+
})
502+
.build();
503+
last_barrier.wait();
504+
assert_eq!(10, counter.load(Ordering::Relaxed));
505+
}
506+
assert_eq!(10, counter.load(Ordering::Relaxed));
507+
let end_counter = counter.clone();
508+
{
509+
let _pool = TaskPoolBuilder::new()
510+
.num_threads(20)
511+
.on_thread_destroy(move || {
512+
end_counter.fetch_sub(1, Ordering::Relaxed);
513+
})
514+
.build();
515+
assert_eq!(10, counter.load(Ordering::Relaxed));
516+
}
517+
assert_eq!(-10, counter.load(Ordering::Relaxed));
518+
let start_counter = counter.clone();
519+
let end_counter = counter.clone();
520+
{
521+
let barrier = Arc::new(Barrier::new(6));
522+
let last_barrier = barrier.clone();
523+
let _pool = TaskPoolBuilder::new()
524+
.num_threads(5)
525+
.on_thread_spawn(move || {
526+
start_counter.fetch_add(1, Ordering::Relaxed);
527+
barrier.wait();
528+
})
529+
.on_thread_destroy(move || {
530+
end_counter.fetch_sub(1, Ordering::Relaxed);
531+
})
532+
.build();
533+
last_barrier.wait();
534+
assert_eq!(-5, counter.load(Ordering::Relaxed));
535+
}
536+
assert_eq!(-10, counter.load(Ordering::Relaxed));
537+
}
538+
455539
#[test]
456540
fn test_mixed_spawn_on_scope_and_spawn() {
457541
let pool = TaskPool::new();

0 commit comments

Comments
 (0)