Skip to content

Commit 795754a

Browse files
authored
sync: make notify_waiters calls atomic (#5458)
1 parent 0f17d69 commit 795754a

File tree

4 files changed

+474
-44
lines changed

4 files changed

+474
-44
lines changed

tokio/src/sync/notify.rs

Lines changed: 148 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88
use crate::loom::sync::atomic::AtomicUsize;
99
use crate::loom::sync::Mutex;
10-
use crate::util::linked_list::{self, LinkedList};
10+
use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
1111
use crate::util::WakeList;
1212

1313
use std::cell::UnsafeCell;
@@ -20,6 +20,7 @@ use std::sync::atomic::Ordering::SeqCst;
2020
use std::task::{Context, Poll, Waker};
2121

2222
type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
23+
type GuardedWaitList = GuardedLinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
2324

2425
/// Notifies a single task to wake up.
2526
///
@@ -198,10 +199,16 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
198199
/// [`Semaphore`]: crate::sync::Semaphore
199200
#[derive(Debug)]
200201
pub struct Notify {
201-
// This uses 2 bits to store one of `EMPTY`,
202+
// `state` uses 2 bits to store one of `EMPTY`,
202203
// `WAITING` or `NOTIFIED`. The rest of the bits
203204
// are used to store the number of times `notify_waiters`
204205
// was called.
206+
//
207+
// Throughout the code there are two assumptions:
208+
// - state can be transitioned *from* `WAITING` only if
209+
// `waiters` lock is held
210+
// - number of times `notify_waiters` was called can
211+
// be modified only if `waiters` lock is held
205212
state: AtomicUsize,
206213
waiters: Mutex<WaitList>,
207214
}
@@ -229,6 +236,17 @@ struct Waiter {
229236
_p: PhantomPinned,
230237
}
231238

239+
impl Waiter {
240+
fn new() -> Waiter {
241+
Waiter {
242+
pointers: linked_list::Pointers::new(),
243+
waker: None,
244+
notified: None,
245+
_p: PhantomPinned,
246+
}
247+
}
248+
}
249+
232250
generate_addr_of_methods! {
233251
impl<> Waiter {
234252
unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
@@ -237,6 +255,59 @@ generate_addr_of_methods! {
237255
}
238256
}
239257

258+
/// List used in `Notify::notify_waiters`. It wraps a guarded linked list
259+
/// and gates the access to it on `notify.waiters` mutex. It also empties
260+
/// the list on drop.
261+
struct NotifyWaitersList<'a> {
262+
list: GuardedWaitList,
263+
is_empty: bool,
264+
notify: &'a Notify,
265+
}
266+
267+
impl<'a> NotifyWaitersList<'a> {
268+
fn new(
269+
unguarded_list: WaitList,
270+
guard: Pin<&'a mut UnsafeCell<Waiter>>,
271+
notify: &'a Notify,
272+
) -> NotifyWaitersList<'a> {
273+
// Safety: pointer to the guarding waiter is not null.
274+
let guard_ptr = unsafe { NonNull::new_unchecked(guard.get()) };
275+
let list = unguarded_list.into_guarded(guard_ptr);
276+
NotifyWaitersList {
277+
list,
278+
is_empty: false,
279+
notify,
280+
}
281+
}
282+
283+
/// Removes the last element from the guarded list. Modifying this list
284+
/// requires exclusive access to the main list in `Notify`.
285+
fn pop_back_locked(&mut self, _waiters: &mut WaitList) -> Option<NonNull<Waiter>> {
286+
let result = self.list.pop_back();
287+
if result.is_none() {
288+
// Save information about emptiness to avoid waiting for lock
289+
// in the destructor.
290+
self.is_empty = true;
291+
}
292+
result
293+
}
294+
}
295+
296+
impl Drop for NotifyWaitersList<'_> {
297+
fn drop(&mut self) {
298+
// If the list is not empty, we unlink all waiters from it.
299+
// We do not wake the waiters to avoid double panics.
300+
if !self.is_empty {
301+
let _lock_guard = self.notify.waiters.lock();
302+
while let Some(mut waiter) = self.list.pop_back() {
303+
// Safety: we hold the lock.
304+
let waiter = unsafe { waiter.as_mut() };
305+
waiter.notified = Some(NotificationType::AllWaiters);
306+
}
307+
}
308+
}
309+
}
310+
240311
/// Future returned from [`Notify::notified()`].
241312
///
242313
/// This future is fused, so once it has completed, any future calls to poll
@@ -249,6 +320,9 @@ pub struct Notified<'a> {
249320
/// The current state of the receiving process.
250321
state: State,
251322

323+
/// Number of calls to `notify_waiters` at the time of creation.
324+
notify_waiters_calls: usize,
325+
252326
/// Entry in the waiter `LinkedList`.
253327
waiter: UnsafeCell<Waiter>,
254328
}
@@ -258,7 +332,7 @@ unsafe impl<'a> Sync for Notified<'a> {}
258332

259333
#[derive(Debug)]
260334
enum State {
261-
Init(usize),
335+
Init,
262336
Waiting,
263337
Done,
264338
}
@@ -383,17 +457,13 @@ impl Notify {
383457
/// ```
384458
pub fn notified(&self) -> Notified<'_> {
385459
// we load the number of times notify_waiters
386-
// was called and store that in our initial state
460+
// was called and store that in the future.
387461
let state = self.state.load(SeqCst);
388462
Notified {
389463
notify: self,
390-
state: State::Init(state >> NOTIFY_WAITERS_SHIFT),
391-
waiter: UnsafeCell::new(Waiter {
392-
pointers: linked_list::Pointers::new(),
393-
waker: None,
394-
notified: None,
395-
_p: PhantomPinned,
396-
}),
464+
state: State::Init,
465+
notify_waiters_calls: get_num_notify_waiters_calls(state),
466+
waiter: UnsafeCell::new(Waiter::new()),
397467
}
398468
}
399469

@@ -500,12 +570,9 @@ impl Notify {
500570
/// }
501571
/// ```
502572
pub fn notify_waiters(&self) {
503-
let mut wakers = WakeList::new();
504-
505-
// There are waiters, the lock must be acquired to notify.
506573
let mut waiters = self.waiters.lock();
507574

508-
// The state must be reloaded while the lock is held. The state may only
575+
// The state must be loaded while the lock is held. The state may only
509576
// transition out of WAITING while the lock is held.
510577
let curr = self.state.load(SeqCst);
511578

@@ -516,12 +583,30 @@ impl Notify {
516583
return;
517584
}
518585

519-
// At this point, it is guaranteed that the state will not
520-
// concurrently change, as holding the lock is required to
521-
// transition **out** of `WAITING`.
586+
// Increment the number of times this method was called
587+
// and transition to empty.
588+
let new_state = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
589+
self.state.store(new_state, SeqCst);
590+
591+
// It is critical for `GuardedLinkedList` safety that the guard node is
592+
// pinned in memory and is not dropped until the guarded list is dropped.
593+
let guard = UnsafeCell::new(Waiter::new());
594+
pin!(guard);
595+
596+
// We move all waiters to a secondary list. It uses a `GuardedLinkedList`
597+
// underneath to allow every waiter to safely remove itself from it.
598+
//
599+
// * This list will still be guarded by the `waiters` lock.
600+
// `NotifyWaitersList` wrapper makes sure we hold the lock to modify it.
601+
// * This wrapper will empty the list on drop. It is critical for safety
602+
// that we will not leave any list entry with a pointer to the local
603+
// guard node after this function returns / panics.
604+
let mut list = NotifyWaitersList::new(std::mem::take(&mut *waiters), guard, self);
605+
606+
let mut wakers = WakeList::new();
522607
'outer: loop {
523608
while wakers.can_push() {
524-
match waiters.pop_back() {
609+
match list.pop_back_locked(&mut waiters) {
525610
Some(mut waiter) => {
526611
// Safety: `waiters` lock is still held.
527612
let waiter = unsafe { waiter.as_mut() };
@@ -540,20 +625,17 @@ impl Notify {
540625
}
541626
}
542627

628+
// Release the lock before notifying.
543629
drop(waiters);
544630

631+
// One of the wakers may panic, but the remaining waiters will still
632+
// be unlinked from the list in `NotifyWaitersList` destructor.
545633
wakers.wake_all();
546634

547635
// Acquire the lock again.
548636
waiters = self.waiters.lock();
549637
}
550638

551-
// All waiters will be notified, the state must be transitioned to
552-
// `EMPTY`. As transitioning **from** `WAITING` requires the lock to be
553-
// held, a `store` is sufficient.
554-
let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
555-
self.state.store(new, SeqCst);
556-
557639
// Release the lock before notifying
558640
drop(waiters);
559641

@@ -730,26 +812,32 @@ impl Notified<'_> {
730812

731813
/// A custom `project` implementation is used in place of `pin-project-lite`
732814
/// as a custom drop implementation is needed.
733-
fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) {
815+
fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &usize, &UnsafeCell<Waiter>) {
734816
unsafe {
735-
// Safety: both `notify` and `state` are `Unpin`.
817+
// Safety: `notify`, `state` and `notify_waiters_calls` are `Unpin`.
736818

737819
is_unpin::<&Notify>();
738820
is_unpin::<AtomicUsize>();
821+
is_unpin::<usize>();
739822

740823
let me = self.get_unchecked_mut();
741-
(me.notify, &mut me.state, &me.waiter)
824+
(
825+
me.notify,
826+
&mut me.state,
827+
&me.notify_waiters_calls,
828+
&me.waiter,
829+
)
742830
}
743831
}
744832

745833
fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> {
746834
use State::*;
747835

748-
let (notify, state, waiter) = self.project();
836+
let (notify, state, notify_waiters_calls, waiter) = self.project();
749837

750838
loop {
751839
match *state {
752-
Init(initial_notify_waiters_calls) => {
840+
Init => {
753841
let curr = notify.state.load(SeqCst);
754842

755843
// Optimistically try acquiring a pending notification
@@ -779,7 +867,7 @@ impl Notified<'_> {
779867

780868
// if notify_waiters has been called after the future
781869
// was created, then we are done
782-
if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls {
870+
if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
783871
*state = Done;
784872
return Poll::Ready(());
785873
}
@@ -846,21 +934,37 @@ impl Notified<'_> {
846934
return Poll::Pending;
847935
}
848936
Waiting => {
849-
// Currently in the "Waiting" state, implying the caller has
850-
// a waiter stored in the waiter list (guarded by
851-
// `notify.waiters`). In order to access the waker fields,
852-
// we must hold the lock.
937+
// Currently in the "Waiting" state, implying the caller has a waiter stored in
938+
// a waiter list (guarded by `notify.waiters`). In order to access the waker
939+
// fields, we must acquire the lock.
853940

854-
let waiters = notify.waiters.lock();
941+
let mut waiters = notify.waiters.lock();
942+
943+
// Load the state with the lock held.
944+
let curr = notify.state.load(SeqCst);
855945

856946
// Safety: called while locked
857947
let w = unsafe { &mut *waiter.get() };
858948

859949
if w.notified.is_some() {
860-
// Our waker has been notified. Reset the fields and
861-
// remove it from the list.
862-
w.waker = None;
950+
// Our waker has been notified and our waiter is already removed from
951+
// the list. Reset the notification and convert to `Done`.
863952
w.notified = None;
953+
w.waker = None;
954+
*state = Done;
955+
} else if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
956+
// Before we add a waiter to the list we check if these numbers are
957+
// different while holding the lock. If these numbers are different now,
958+
// it means that there is a call to `notify_waiters` in progress and this
959+
// waiter must be contained by a guarded list used in `notify_waiters`.
960+
// We can treat the waiter as notified and remove it from the list, as
961+
// it would have been notified in the `notify_waiters` call anyway.
962+
963+
w.waker = None;
964+
965+
// Safety: we hold the lock, so we have exclusive access to the list.
966+
// The list is used in `notify_waiters`, so it must be guarded.
967+
unsafe { waiters.remove(NonNull::new_unchecked(w)) };
864968

865969
*state = Done;
866970
} else {
@@ -906,7 +1010,7 @@ impl Drop for Notified<'_> {
9061010
use State::*;
9071011

9081012
// Safety: The type only transitions to a "Waiting" state when pinned.
909-
let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() };
1013+
let (notify, state, _, waiter) = unsafe { Pin::new_unchecked(self).project() };
9101014

9111015
// This is where we ensure safety. The `Notified` value is being
9121016
// dropped, which means we must ensure that the waiter entry is no
@@ -917,8 +1021,10 @@ impl Drop for Notified<'_> {
9171021

9181022
// remove the entry from the list (if not already removed)
9191023
//
920-
// safety: the waiter is only added to `waiters` by virtue of it
921-
// being the only `LinkedList` available to the type.
1024+
// Safety: we hold the lock, so we have exclusive access to every list the
1025+
// waiter may be contained in. If the node is not contained in the `waiters`
1026+
// list, then it is contained by a guarded list used by `notify_waiters` and
1027+
// in such case it must be a middle node.
9221028
unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) };
9231029

9241030
if waiters.is_empty() && get_state(notify_state) == WAITING {

0 commit comments

Comments
 (0)