@@ -12,8 +12,18 @@ use futures_lite::{future, pin, FutureExt};
12
12
13
13
use crate :: Task ;
14
14
15
+ struct CallOnDrop ( Option < Arc < dyn Fn ( ) + Send + Sync + ' static > > ) ;
16
+
17
+ impl Drop for CallOnDrop {
18
+ fn drop ( & mut self ) {
19
+ if let Some ( call) = self . 0 . as_ref ( ) {
20
+ call ( ) ;
21
+ }
22
+ }
23
+ }
24
+
15
25
/// Used to create a [`TaskPool`]
16
- #[ derive( Debug , Default , Clone ) ]
26
+ #[ derive( Default ) ]
17
27
#[ must_use]
18
28
pub struct TaskPoolBuilder {
19
29
/// If set, we'll set up the thread pool to use at most `num_threads` threads.
@@ -24,6 +34,9 @@ pub struct TaskPoolBuilder {
24
34
/// Allows customizing the name of the threads - helpful for debugging. If set, threads will
25
35
/// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
26
36
thread_name : Option < String > ,
37
+
38
+ on_thread_spawn : Option < Arc < dyn Fn ( ) + Send + Sync + ' static > > ,
39
+ on_thread_destroy : Option < Arc < dyn Fn ( ) + Send + Sync + ' static > > ,
27
40
}
28
41
29
42
impl TaskPoolBuilder {
@@ -52,13 +65,27 @@ impl TaskPoolBuilder {
52
65
self
53
66
}
54
67
68
+ /// Sets a callback that is invoked once for every created thread as it starts.
69
+ ///
70
+ /// This is called on the thread itself and has access to all thread-local storage.
71
+ /// This will block running async tasks on the thread until the callback completes.
72
+ pub fn on_thread_spawn ( mut self , f : impl Fn ( ) + Send + Sync + ' static ) -> Self {
73
+ self . on_thread_spawn = Some ( Arc :: new ( f) ) ;
74
+ self
75
+ }
76
+
77
+ /// Sets a callback that is invoked once for every created thread as it terminates.
78
+ ///
79
+ /// This is called on the thread itself and has access to all thread-local storage.
80
+ /// This will block thread termination until the callback completes.
81
+ pub fn on_thread_destroy ( mut self , f : impl Fn ( ) + Send + Sync + ' static ) -> Self {
82
+ self . on_thread_destroy = Some ( Arc :: new ( f) ) ;
83
+ self
84
+ }
85
+
55
86
/// Creates a new [`TaskPool`] based on the current options.
56
87
pub fn build ( self ) -> TaskPool {
57
- TaskPool :: new_internal (
58
- self . num_threads ,
59
- self . stack_size ,
60
- self . thread_name . as_deref ( ) ,
61
- )
88
+ TaskPool :: new_internal ( self )
62
89
}
63
90
}
64
91
@@ -88,36 +115,42 @@ impl TaskPool {
88
115
TaskPoolBuilder :: new ( ) . build ( )
89
116
}
90
117
91
- fn new_internal (
92
- num_threads : Option < usize > ,
93
- stack_size : Option < usize > ,
94
- thread_name : Option < & str > ,
95
- ) -> Self {
118
+ fn new_internal ( builder : TaskPoolBuilder ) -> Self {
96
119
let ( shutdown_tx, shutdown_rx) = async_channel:: unbounded :: < ( ) > ( ) ;
97
120
98
121
let executor = Arc :: new ( async_executor:: Executor :: new ( ) ) ;
99
122
100
- let num_threads = num_threads. unwrap_or_else ( crate :: available_parallelism) ;
123
+ let num_threads = builder
124
+ . num_threads
125
+ . unwrap_or_else ( crate :: available_parallelism) ;
101
126
102
127
let threads = ( 0 ..num_threads)
103
128
. map ( |i| {
104
129
let ex = Arc :: clone ( & executor) ;
105
130
let shutdown_rx = shutdown_rx. clone ( ) ;
106
131
107
- let thread_name = if let Some ( thread_name) = thread_name {
132
+ let thread_name = if let Some ( thread_name) = builder . thread_name . as_deref ( ) {
108
133
format ! ( "{thread_name} ({i})" )
109
134
} else {
110
135
format ! ( "TaskPool ({i})" )
111
136
} ;
112
137
let mut thread_builder = thread:: Builder :: new ( ) . name ( thread_name) ;
113
138
114
- if let Some ( stack_size) = stack_size {
139
+ if let Some ( stack_size) = builder . stack_size {
115
140
thread_builder = thread_builder. stack_size ( stack_size) ;
116
141
}
117
142
143
+ let on_thread_spawn = builder. on_thread_spawn . clone ( ) ;
144
+ let on_thread_destroy = builder. on_thread_destroy . clone ( ) ;
145
+
118
146
thread_builder
119
147
. spawn ( move || {
120
148
TaskPool :: LOCAL_EXECUTOR . with ( |local_executor| {
149
+ if let Some ( on_thread_spawn) = on_thread_spawn {
150
+ on_thread_spawn ( ) ;
151
+ drop ( on_thread_spawn) ;
152
+ }
153
+ let _destructor = CallOnDrop ( on_thread_destroy) ;
121
154
loop {
122
155
let res = std:: panic:: catch_unwind ( || {
123
156
let tick_forever = async move {
@@ -452,6 +485,57 @@ mod tests {
452
485
assert_eq ! ( count. load( Ordering :: Relaxed ) , 100 ) ;
453
486
}
454
487
488
+ #[ test]
489
+ fn test_thread_callbacks ( ) {
490
+ let counter = Arc :: new ( AtomicI32 :: new ( 0 ) ) ;
491
+ let start_counter = counter. clone ( ) ;
492
+ {
493
+ let barrier = Arc :: new ( Barrier :: new ( 11 ) ) ;
494
+ let last_barrier = barrier. clone ( ) ;
495
+ // Build and immediately drop to terminate
496
+ let _pool = TaskPoolBuilder :: new ( )
497
+ . num_threads ( 10 )
498
+ . on_thread_spawn ( move || {
499
+ start_counter. fetch_add ( 1 , Ordering :: Relaxed ) ;
500
+ barrier. clone ( ) . wait ( ) ;
501
+ } )
502
+ . build ( ) ;
503
+ last_barrier. wait ( ) ;
504
+ assert_eq ! ( 10 , counter. load( Ordering :: Relaxed ) ) ;
505
+ }
506
+ assert_eq ! ( 10 , counter. load( Ordering :: Relaxed ) ) ;
507
+ let end_counter = counter. clone ( ) ;
508
+ {
509
+ let _pool = TaskPoolBuilder :: new ( )
510
+ . num_threads ( 20 )
511
+ . on_thread_destroy ( move || {
512
+ end_counter. fetch_sub ( 1 , Ordering :: Relaxed ) ;
513
+ } )
514
+ . build ( ) ;
515
+ assert_eq ! ( 10 , counter. load( Ordering :: Relaxed ) ) ;
516
+ }
517
+ assert_eq ! ( -10 , counter. load( Ordering :: Relaxed ) ) ;
518
+ let start_counter = counter. clone ( ) ;
519
+ let end_counter = counter. clone ( ) ;
520
+ {
521
+ let barrier = Arc :: new ( Barrier :: new ( 6 ) ) ;
522
+ let last_barrier = barrier. clone ( ) ;
523
+ let _pool = TaskPoolBuilder :: new ( )
524
+ . num_threads ( 5 )
525
+ . on_thread_spawn ( move || {
526
+ start_counter. fetch_add ( 1 , Ordering :: Relaxed ) ;
527
+ barrier. wait ( ) ;
528
+ } )
529
+ . on_thread_destroy ( move || {
530
+ end_counter. fetch_sub ( 1 , Ordering :: Relaxed ) ;
531
+ } )
532
+ . build ( ) ;
533
+ last_barrier. wait ( ) ;
534
+ assert_eq ! ( -5 , counter. load( Ordering :: Relaxed ) ) ;
535
+ }
536
+ assert_eq ! ( -10 , counter. load( Ordering :: Relaxed ) ) ;
537
+ }
538
+
455
539
#[ test]
456
540
fn test_mixed_spawn_on_scope_and_spawn ( ) {
457
541
let pool = TaskPool :: new ( ) ;
0 commit comments