Skip to content

Commit 649bdfd

Browse files
authored
Support racy initialization of an Executor's state
Fixes smol-rs#89. Uses @notgull's suggestion of using a `AtomicPtr` with a racy initialization instead of a `OnceCell`. For the addition of more `unsafe`, I added the `clippy::undocumented_unsafe_blocks` lint at a warn, and fixed a few of the remaining open clippy issues (i.e. `Waker::clone_from` already handling the case where they're equal). Removing `async_lock` as a dependency shouldn't be a SemVer breaking change.
1 parent 4b37c61 commit 649bdfd

File tree

2 files changed

+88
-43
lines changed

2 files changed

+88
-43
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ categories = ["asynchronous", "concurrency"]
1515
exclude = ["/.*"]
1616

1717
[dependencies]
18-
async-lock = "3.0.0"
1918
async-task = "4.4.0"
2019
concurrent-queue = "2.0.0"
2120
fastrand = "2.0.0"

src/lib.rs

Lines changed: 88 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
//! future::block_on(ex.run(task));
2626
//! ```
2727
28-
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
28+
#![warn(
29+
missing_docs,
30+
missing_debug_implementations,
31+
rust_2018_idioms,
32+
clippy::undocumented_unsafe_blocks
33+
)]
2934
#![doc(
3035
html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
3136
)]
@@ -37,11 +42,10 @@ use std::fmt;
3742
use std::marker::PhantomData;
3843
use std::panic::{RefUnwindSafe, UnwindSafe};
3944
use std::rc::Rc;
40-
use std::sync::atomic::{AtomicBool, Ordering};
45+
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
4146
use std::sync::{Arc, Mutex, RwLock, TryLockError};
4247
use std::task::{Poll, Waker};
4348

44-
use async_lock::OnceCell;
4549
use async_task::{Builder, Runnable};
4650
use concurrent_queue::ConcurrentQueue;
4751
use futures_lite::{future, prelude::*};
@@ -76,13 +80,15 @@ pub use async_task::Task;
7680
/// ```
7781
pub struct Executor<'a> {
7882
/// The executor state.
79-
state: OnceCell<Arc<State>>,
83+
state: AtomicPtr<State>,
8084

8185
/// Makes the `'a` lifetime invariant.
8286
_marker: PhantomData<std::cell::UnsafeCell<&'a ()>>,
8387
}
8488

89+
// SAFETY: Executor stores no thread local state that can be accessed via other thread.
8590
unsafe impl Send for Executor<'_> {}
91+
// SAFETY: Executor internally synchronizes all of it's operations internally.
8692
unsafe impl Sync for Executor<'_> {}
8793

8894
impl UnwindSafe for Executor<'_> {}
@@ -106,7 +112,7 @@ impl<'a> Executor<'a> {
106112
/// ```
107113
pub const fn new() -> Executor<'a> {
108114
Executor {
109-
state: OnceCell::new(),
115+
state: AtomicPtr::new(std::ptr::null_mut()),
110116
_marker: PhantomData,
111117
}
112118
}
@@ -231,7 +237,7 @@ impl<'a> Executor<'a> {
231237
// Remove the task from the set of active tasks when the future finishes.
232238
let entry = active.vacant_entry();
233239
let index = entry.key();
234-
let state = self.state().clone();
240+
let state = self.state_as_arc();
235241
let future = async move {
236242
let _guard = CallOnDrop(move || drop(state.active.lock().unwrap().try_remove(index)));
237243
future.await
@@ -361,7 +367,7 @@ impl<'a> Executor<'a> {
361367

362368
/// Returns a function that schedules a runnable task when it gets woken up.
363369
fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
364-
let state = self.state().clone();
370+
let state = self.state_as_arc();
365371

366372
// TODO: If possible, push into the current local queue and notify the ticker.
367373
move |runnable| {
@@ -370,34 +376,73 @@ impl<'a> Executor<'a> {
370376
}
371377
}
372378

373-
/// Returns a reference to the inner state.
374-
fn state(&self) -> &Arc<State> {
375-
#[cfg(not(target_family = "wasm"))]
376-
{
377-
return self.state.get_or_init_blocking(|| Arc::new(State::new()));
379+
/// Returns a pointer to the inner state.
380+
#[inline]
381+
fn state_ptr(&self) -> *const State {
382+
#[cold]
383+
fn alloc_state(atomic_ptr: &AtomicPtr<State>) -> *mut State {
384+
let state = Arc::new(State::new());
385+
// TODO: Switch this to use cast_mut once the MSRV can be bumped past 1.65
386+
let ptr = Arc::into_raw(state) as *mut State;
387+
if let Err(actual) = atomic_ptr.compare_exchange(
388+
std::ptr::null_mut(),
389+
ptr,
390+
Ordering::AcqRel,
391+
Ordering::Acquire,
392+
) {
393+
// SAFETY: This was just created from Arc::into_raw.
394+
drop(unsafe { Arc::from_raw(ptr) });
395+
actual
396+
} else {
397+
ptr
398+
}
378399
}
379400

380-
// Some projects use this on WASM for some reason. In this case get_or_init_blocking
381-
// doesn't work. Just poll the future once and panic if there is contention.
382-
#[cfg(target_family = "wasm")]
383-
future::block_on(future::poll_once(
384-
self.state.get_or_init(|| async { Arc::new(State::new()) }),
385-
))
386-
.expect("encountered contention on WASM")
401+
let mut ptr = self.state.load(Ordering::Acquire);
402+
if ptr.is_null() {
403+
ptr = alloc_state(&self.state);
404+
}
405+
ptr
406+
}
407+
408+
/// Returns a reference to the inner state.
409+
#[inline]
410+
fn state(&self) -> &State {
411+
// SAFETY: So long as an Executor lives, it's state pointer will always be valid
412+
// when accessed through state_ptr.
413+
unsafe { &*self.state_ptr() }
414+
}
415+
416+
// Clones the inner state Arc
417+
#[inline]
418+
fn state_as_arc(&self) -> Arc<State> {
419+
// SAFETY: So long as an Executor lives, it's state pointer will always be a valid
420+
// Arc when accessed through state_ptr.
421+
let arc = unsafe { Arc::from_raw(self.state_ptr()) };
422+
let clone = arc.clone();
423+
std::mem::forget(arc);
424+
clone
387425
}
388426
}
389427

390428
impl Drop for Executor<'_> {
391429
fn drop(&mut self) {
392-
if let Some(state) = self.state.get() {
393-
let mut active = state.active.lock().unwrap_or_else(|e| e.into_inner());
394-
for w in active.drain() {
395-
w.wake();
396-
}
397-
drop(active);
430+
let ptr = *self.state.get_mut();
431+
if ptr.is_null() {
432+
return;
433+
}
434+
435+
// SAFETY: As ptr is not null, it was allocated via Arc::new and converted
436+
// via Arc::into_raw in state_ptr.
437+
let state = unsafe { Arc::from_raw(ptr) };
398438

399-
while state.queue.pop().is_ok() {}
439+
let mut active = state.active.lock().unwrap_or_else(|e| e.into_inner());
440+
for w in active.drain() {
441+
w.wake();
400442
}
443+
drop(active);
444+
445+
while state.queue.pop().is_ok() {}
401446
}
402447
}
403448

@@ -718,9 +763,7 @@ impl Sleepers {
718763
fn update(&mut self, id: usize, waker: &Waker) -> bool {
719764
for item in &mut self.wakers {
720765
if item.0 == id {
721-
if !item.1.will_wake(waker) {
722-
item.1.clone_from(waker);
723-
}
766+
item.1.clone_from(waker);
724767
return false;
725768
}
726769
}
@@ -1006,21 +1049,24 @@ fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
10061049
/// Debug implementation for `Executor` and `LocalExecutor`.
10071050
fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
10081051
// Get a reference to the state.
1009-
let state = match executor.state.get() {
1010-
Some(state) => state,
1011-
None => {
1012-
// The executor has not been initialized.
1013-
struct Uninitialized;
1014-
1015-
impl fmt::Debug for Uninitialized {
1016-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1017-
f.write_str("<uninitialized>")
1018-
}
1052+
let ptr = executor.state.load(Ordering::Acquire);
1053+
if ptr.is_null() {
1054+
// The executor has not been initialized.
1055+
struct Uninitialized;
1056+
1057+
impl fmt::Debug for Uninitialized {
1058+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1059+
f.write_str("<uninitialized>")
10191060
}
1020-
1021-
return f.debug_tuple(name).field(&Uninitialized).finish();
10221061
}
1023-
};
1062+
1063+
return f.debug_tuple(name).field(&Uninitialized).finish();
1064+
}
1065+
1066+
// SAFETY: If the state pointer is not null, it must have been
1067+
// allocated properly by Arc::new and converted via Arc::into_raw
1068+
// in state_ptr.
1069+
let state = unsafe { &*ptr };
10241070

10251071
/// Debug wrapper for the number of active tasks.
10261072
struct ActiveTasks<'a>(&'a Mutex<Slab<Waker>>);

0 commit comments

Comments
 (0)