|
20 | 20 |
|
21 | 21 | #![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
|
22 | 22 |
|
| 23 | +use std::cell::RefCell; |
23 | 24 | use std::fmt;
|
24 | 25 | use std::future::Future;
|
25 | 26 | use std::marker::PhantomData;
|
@@ -229,29 +230,56 @@ impl<'a> Executor<'a> {
|
229 | 230 | let runner = Runner::new(self.state());
|
230 | 231 | let mut rng = fastrand::Rng::new();
|
231 | 232 |
|
232 |
| - // A future that runs tasks forever. |
233 |
| - let run_forever = async { |
234 |
| - loop { |
235 |
| - for _ in 0..200 { |
236 |
| - let runnable = runner.runnable(&mut rng).await; |
237 |
| - runnable.run(); |
238 |
| - } |
239 |
| - future::yield_now().await; |
240 |
| - } |
241 |
| - }; |
| 233 | + // Set the local queue while we're running. |
| 234 | + LocalQueue::set(self.state(), &runner.local, { |
| 235 | + let runner = &runner; |
| 236 | + async move { |
| 237 | + // A future that runs tasks forever. |
| 238 | + let run_forever = async { |
| 239 | + loop { |
| 240 | + for _ in 0..200 { |
| 241 | + let runnable = runner.runnable(&mut rng).await; |
| 242 | + runnable.run(); |
| 243 | + } |
| 244 | + future::yield_now().await; |
| 245 | + } |
| 246 | + }; |
242 | 247 |
|
243 |
| - // Run `future` and `run_forever` concurrently until `future` completes. |
244 |
| - future.or(run_forever).await |
| 248 | + // Run `future` and `run_forever` concurrently until `future` completes. |
| 249 | + future.or(run_forever).await |
| 250 | + } |
| 251 | + }) |
| 252 | + .await |
245 | 253 | }
|
246 | 254 |
|
247 | 255 | /// Returns a function that schedules a runnable task when it gets woken up.
|
248 | 256 | fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
|
249 | 257 | let state = self.state().clone();
|
250 | 258 |
|
251 |
| - // TODO(stjepang): If possible, push into the current local queue and notify the ticker. |
| 259 | + // If possible, push into the current local queue and notify the ticker. |
252 | 260 | move |runnable| {
|
253 |
| - state.queue.push(runnable).unwrap(); |
254 |
| - state.notify(); |
| 261 | + let mut runnable = Some(runnable); |
| 262 | + |
| 263 | + // Try to push into the local queue. |
| 264 | + LocalQueue::with(|local_queue| { |
| 265 | + // Make sure that we don't accidentally push to an executor that isn't ours. |
| 266 | + if !std::ptr::eq(local_queue.state, &*state) { |
| 267 | + return; |
| 268 | + } |
| 269 | + |
| 270 | + if let Err(e) = local_queue.queue.push(runnable.take().unwrap()) { |
| 271 | + runnable = Some(e.into_inner()); |
| 272 | + return; |
| 273 | + } |
| 274 | + |
| 275 | + local_queue.waker.wake_by_ref(); |
| 276 | + }); |
| 277 | + |
| 278 | + // If the local queue push failed, just push to the global queue. |
| 279 | + if let Some(runnable) = runnable { |
| 280 | + state.queue.push(runnable).unwrap(); |
| 281 | + state.notify(); |
| 282 | + } |
255 | 283 | }
|
256 | 284 | }
|
257 | 285 |
|
@@ -819,6 +847,97 @@ impl Drop for Runner<'_> {
|
819 | 847 | }
|
820 | 848 | }
|
821 | 849 |
|
| 850 | +/// The state of the currently running local queue. |
| 851 | +struct LocalQueue { |
| 852 | + /// The pointer to the state of the executor. |
| 853 | + /// |
| 854 | + /// Used to make sure we don't push runnables to the wrong executor. |
| 855 | + state: *const State, |
| 856 | + |
| 857 | + /// The concurrent queue. |
| 858 | + queue: Arc<ConcurrentQueue<Runnable>>, |
| 859 | + |
| 860 | + /// The waker for the runnable. |
| 861 | + waker: Waker, |
| 862 | +} |
| 863 | + |
| 864 | +impl LocalQueue { |
| 865 | + /// Run a function with the current local queue. |
| 866 | + fn with<R>(f: impl FnOnce(&LocalQueue) -> R) -> Option<R> { |
| 867 | + std::thread_local! { |
| 868 | + /// The current local queue. |
| 869 | + static LOCAL_QUEUE: RefCell<Option<LocalQueue>> = RefCell::new(None); |
| 870 | + } |
| 871 | + |
| 872 | + impl LocalQueue { |
| 873 | + /// Run a function with a set local queue. |
| 874 | + async fn set<F>( |
| 875 | + state: &State, |
| 876 | + queue: &Arc<ConcurrentQueue<Runnable>>, |
| 877 | + fut: F, |
| 878 | + ) -> F::Output |
| 879 | + where |
| 880 | + F: Future, |
| 881 | + { |
| 882 | + // Store the local queue and the current waker. |
| 883 | + let mut old = with_waker(|waker| { |
| 884 | + LOCAL_QUEUE.with(move |slot| { |
| 885 | + slot.borrow_mut().replace(LocalQueue { |
| 886 | + state: state as *const State, |
| 887 | + queue: queue.clone(), |
| 888 | + waker: waker.clone(), |
| 889 | + }) |
| 890 | + }) |
| 891 | + }) |
| 892 | + .await; |
| 893 | + |
| 894 | + // Restore the old local queue on drop. |
| 895 | + let _guard = CallOnDrop(move || { |
| 896 | + let old = old.take(); |
| 897 | + let _ = LOCAL_QUEUE.try_with(move |slot| { |
| 898 | + *slot.borrow_mut() = old; |
| 899 | + }); |
| 900 | + }); |
| 901 | + |
| 902 | + // Pin the future. |
| 903 | + futures_lite::pin!(fut); |
| 904 | + |
| 905 | + // Run it such that the waker is updated every time it's polled. |
| 906 | + future::poll_fn(move |cx| { |
| 907 | + LOCAL_QUEUE |
| 908 | + .try_with({ |
| 909 | + let waker = cx.waker(); |
| 910 | + move |slot| { |
| 911 | + let mut slot = slot.borrow_mut(); |
| 912 | + let qaw = slot.as_mut().expect("missing local queue"); |
| 913 | + |
| 914 | + // If we've been replaced, just ignore the slot. |
| 915 | + if !Arc::ptr_eq(&qaw.queue, queue) { |
| 916 | + return; |
| 917 | + } |
| 918 | + |
| 919 | + // Update the waker, if it has changed. |
| 920 | + if !qaw.waker.will_wake(waker) { |
| 921 | + qaw.waker = waker.clone(); |
| 922 | + } |
| 923 | + } |
| 924 | + }) |
| 925 | + .ok(); |
| 926 | + |
| 927 | + // Poll the future. |
| 928 | + fut.as_mut().poll(cx) |
| 929 | + }) |
| 930 | + .await |
| 931 | + } |
| 932 | + } |
| 933 | + |
| 934 | + LOCAL_QUEUE |
| 935 | + .try_with(|local_queue| local_queue.borrow().as_ref().map(f)) |
| 936 | + .ok() |
| 937 | + .flatten() |
| 938 | + } |
| 939 | +} |
| 940 | + |
822 | 941 | /// Steals some items from one queue into another.
|
823 | 942 | fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
|
824 | 943 | // Half of `src`'s length rounded up.
|
@@ -911,10 +1030,19 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
|
911 | 1030 | }
|
912 | 1031 |
|
913 | 1032 | /// Runs a closure when dropped.
|
914 |
| -struct CallOnDrop<F: Fn()>(F); |
| 1033 | +struct CallOnDrop<F: FnMut()>(F); |
915 | 1034 |
|
916 |
| -impl<F: Fn()> Drop for CallOnDrop<F> { |
| 1035 | +impl<F: FnMut()> Drop for CallOnDrop<F> { |
917 | 1036 | fn drop(&mut self) {
|
918 | 1037 | (self.0)();
|
919 | 1038 | }
|
920 | 1039 | }
|
| 1040 | + |
| 1041 | +/// Run a closure with the current waker. |
| 1042 | +fn with_waker<F: FnOnce(&Waker) -> R, R>(f: F) -> impl Future<Output = R> { |
| 1043 | + let mut f = Some(f); |
| 1044 | + future::poll_fn(move |cx| { |
| 1045 | + let f = f.take().unwrap(); |
| 1046 | + Poll::Ready(f(cx.waker())) |
| 1047 | + }) |
| 1048 | +} |
0 commit comments