77
88use crate :: loom:: sync:: atomic:: AtomicUsize ;
99use crate :: loom:: sync:: Mutex ;
10- use crate :: util:: linked_list:: { self , LinkedList } ;
10+ use crate :: util:: linked_list:: { self , GuardedLinkedList , LinkedList } ;
1111use crate :: util:: WakeList ;
1212
1313use std:: cell:: UnsafeCell ;
@@ -20,6 +20,7 @@ use std::sync::atomic::Ordering::SeqCst;
2020use std:: task:: { Context , Poll , Waker } ;
2121
2222type WaitList = LinkedList < Waiter , <Waiter as linked_list:: Link >:: Target > ;
23+ type GuardedWaitList = GuardedLinkedList < Waiter , <Waiter as linked_list:: Link >:: Target > ;
2324
2425/// Notifies a single task to wake up.
2526///
@@ -198,10 +199,16 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
198199/// [`Semaphore`]: crate::sync::Semaphore
199200#[ derive( Debug ) ]
200201pub struct Notify {
201- // This uses 2 bits to store one of `EMPTY`,
202+ // `state` uses 2 bits to store one of `EMPTY`,
202203 // `WAITING` or `NOTIFIED`. The rest of the bits
203204 // are used to store the number of times `notify_waiters`
204205 // was called.
206+ //
207+ // Throughout the code there are two assumptions:
208+ // - state can be transitioned *from* `WAITING` only if
209+ // `waiters` lock is held
210+ // - number of times `notify_waiters` was called can
211+ // be modified only if `waiters` lock is held
205212 state : AtomicUsize ,
206213 waiters : Mutex < WaitList > ,
207214}
@@ -229,6 +236,17 @@ struct Waiter {
229236 _p : PhantomPinned ,
230237}
231238
239+ impl Waiter {
240+ fn new ( ) -> Waiter {
241+ Waiter {
242+ pointers : linked_list:: Pointers :: new ( ) ,
243+ waker : None ,
244+ notified : None ,
245+ _p : PhantomPinned ,
246+ }
247+ }
248+ }
249+
232250generate_addr_of_methods ! {
233251 impl <> Waiter {
234252 unsafe fn addr_of_pointers( self : NonNull <Self >) -> NonNull <linked_list:: Pointers <Waiter >> {
@@ -237,6 +255,59 @@ generate_addr_of_methods! {
237255 }
238256}
239257
258+ /// List used in `Notify::notify_waiters`. It wraps a guarded linked list
259+ /// and gates the access to it on `notify.waiters` mutex. It also empties
260+ /// the list on drop.
261+ struct NotifyWaitersList < ' a > {
262+ list : GuardedWaitList ,
263+ is_empty : bool ,
264+ notify : & ' a Notify ,
265+ }
266+
267+ impl < ' a > NotifyWaitersList < ' a > {
268+ fn new (
269+ unguarded_list : WaitList ,
270+ guard : Pin < & ' a mut UnsafeCell < Waiter > > ,
271+ notify : & ' a Notify ,
272+ ) -> NotifyWaitersList < ' a > {
273+ // Safety: pointer to the guarding waiter is not null.
274+ let guard_ptr = unsafe { NonNull :: new_unchecked ( guard. get ( ) ) } ;
275+ let list = unguarded_list. into_guarded ( guard_ptr) ;
276+ NotifyWaitersList {
277+ list,
278+ is_empty : false ,
279+ notify,
280+ }
281+ }
282+
283+ /// Removes the last element from the guarded list. Modifying this list
284+ /// requires an exclusive access to the main list in `Notify`.
285+ fn pop_back_locked ( & mut self , _waiters : & mut WaitList ) -> Option < NonNull < Waiter > > {
286+ let result = self . list . pop_back ( ) ;
287+ if result. is_none ( ) {
288+ // Save information about emptiness to avoid waiting for lock
289+ // in the destructor.
290+ self . is_empty = true ;
291+ }
292+ result
293+ }
294+ }
295+
296+ impl Drop for NotifyWaitersList < ' _ > {
297+ fn drop ( & mut self ) {
298+ // If the list is not empty, we unlink all waiters from it.
299+ // We do not wake the waiters to avoid double panics.
300+ if !self . is_empty {
301+ let _lock_guard = self . notify . waiters . lock ( ) ;
302+ while let Some ( mut waiter) = self . list . pop_back ( ) {
303+ // Safety: we hold the lock.
304+ let waiter = unsafe { waiter. as_mut ( ) } ;
305+ waiter. notified = Some ( NotificationType :: AllWaiters ) ;
306+ }
307+ }
308+ }
309+ }
310+
240311/// Future returned from [`Notify::notified()`].
241312///
242313/// This future is fused, so once it has completed, any future calls to poll
@@ -249,6 +320,9 @@ pub struct Notified<'a> {
249320 /// The current state of the receiving process.
250321 state : State ,
251322
323+ /// Number of calls to `notify_waiters` at the time of creation.
324+ notify_waiters_calls : usize ,
325+
252326 /// Entry in the waiter `LinkedList`.
253327 waiter : UnsafeCell < Waiter > ,
254328}
@@ -258,7 +332,7 @@ unsafe impl<'a> Sync for Notified<'a> {}
258332
259333#[ derive( Debug ) ]
260334enum State {
261- Init ( usize ) ,
335+ Init ,
262336 Waiting ,
263337 Done ,
264338}
@@ -383,17 +457,13 @@ impl Notify {
383457 /// ```
384458 pub fn notified ( & self ) -> Notified < ' _ > {
385459 // we load the number of times notify_waiters
386- // was called and store that in our initial state
460+ // was called and store that in the future.
387461 let state = self . state . load ( SeqCst ) ;
388462 Notified {
389463 notify : self ,
390- state : State :: Init ( state >> NOTIFY_WAITERS_SHIFT ) ,
391- waiter : UnsafeCell :: new ( Waiter {
392- pointers : linked_list:: Pointers :: new ( ) ,
393- waker : None ,
394- notified : None ,
395- _p : PhantomPinned ,
396- } ) ,
464+ state : State :: Init ,
465+ notify_waiters_calls : get_num_notify_waiters_calls ( state) ,
466+ waiter : UnsafeCell :: new ( Waiter :: new ( ) ) ,
397467 }
398468 }
399469
@@ -500,12 +570,9 @@ impl Notify {
500570 /// }
501571 /// ```
502572 pub fn notify_waiters ( & self ) {
503- let mut wakers = WakeList :: new ( ) ;
504-
505- // There are waiters, the lock must be acquired to notify.
506573 let mut waiters = self . waiters . lock ( ) ;
507574
508- // The state must be reloaded while the lock is held. The state may only
575+ // The state must be loaded while the lock is held. The state may only
509576 // transition out of WAITING while the lock is held.
510577 let curr = self . state . load ( SeqCst ) ;
511578
@@ -516,12 +583,30 @@ impl Notify {
516583 return ;
517584 }
518585
519- // At this point, it is guaranteed that the state will not
520- // concurrently change, as holding the lock is required to
521- // transition **out** of `WAITING`.
586+ // Increment the number of times this method was called
587+ // and transition to empty.
588+ let new_state = set_state ( inc_num_notify_waiters_calls ( curr) , EMPTY ) ;
589+ self . state . store ( new_state, SeqCst ) ;
590+
591+ // It is critical for `GuardedLinkedList` safety that the guard node is
592+ // pinned in memory and is not dropped until the guarded list is dropped.
593+ let guard = UnsafeCell :: new ( Waiter :: new ( ) ) ;
594+ pin ! ( guard) ;
595+
596+ // We move all waiters to a secondary list. It uses a `GuardedLinkedList`
597+ // underneath to allow every waiter to safely remove itself from it.
598+ //
599+ // * This list will be still guarded by the `waiters` lock.
600+ // `NotifyWaitersList` wrapper makes sure we hold the lock to modify it.
601+ // * This wrapper will empty the list on drop. It is critical for safety
602+ // that we will not leave any list entry with a pointer to the local
603+ // guard node after this function returns / panics.
604+ let mut list = NotifyWaitersList :: new ( std:: mem:: take ( & mut * waiters) , guard, self ) ;
605+
606+ let mut wakers = WakeList :: new ( ) ;
522607 ' outer: loop {
523608 while wakers. can_push ( ) {
524- match waiters . pop_back ( ) {
609+ match list . pop_back_locked ( & mut waiters ) {
525610 Some ( mut waiter) => {
526611 // Safety: `waiters` lock is still held.
527612 let waiter = unsafe { waiter. as_mut ( ) } ;
@@ -540,20 +625,17 @@ impl Notify {
540625 }
541626 }
542627
628+ // Release the lock before notifying.
543629 drop ( waiters) ;
544630
631+ // One of the wakers may panic, but the remaining waiters will still
632+ // be unlinked from the list in `NotifyWaitersList` destructor.
545633 wakers. wake_all ( ) ;
546634
547635 // Acquire the lock again.
548636 waiters = self . waiters . lock ( ) ;
549637 }
550638
551- // All waiters will be notified, the state must be transitioned to
552- // `EMPTY`. As transitioning **from** `WAITING` requires the lock to be
553- // held, a `store` is sufficient.
554- let new = set_state ( inc_num_notify_waiters_calls ( curr) , EMPTY ) ;
555- self . state . store ( new, SeqCst ) ;
556-
557639 // Release the lock before notifying
558640 drop ( waiters) ;
559641
@@ -730,26 +812,32 @@ impl Notified<'_> {
730812
731813 /// A custom `project` implementation is used in place of `pin-project-lite`
732814 /// as a custom drop implementation is needed.
733- fn project ( self : Pin < & mut Self > ) -> ( & Notify , & mut State , & UnsafeCell < Waiter > ) {
815+ fn project ( self : Pin < & mut Self > ) -> ( & Notify , & mut State , & usize , & UnsafeCell < Waiter > ) {
734816 unsafe {
735- // Safety: both `notify` and `state ` are `Unpin`.
817+ // Safety: `notify`, `state` and `notify_waiters_calls ` are `Unpin`.
736818
737819 is_unpin :: < & Notify > ( ) ;
738820 is_unpin :: < AtomicUsize > ( ) ;
821+ is_unpin :: < usize > ( ) ;
739822
740823 let me = self . get_unchecked_mut ( ) ;
741- ( me. notify , & mut me. state , & me. waiter )
824+ (
825+ me. notify ,
826+ & mut me. state ,
827+ & me. notify_waiters_calls ,
828+ & me. waiter ,
829+ )
742830 }
743831 }
744832
745833 fn poll_notified ( self : Pin < & mut Self > , waker : Option < & Waker > ) -> Poll < ( ) > {
746834 use State :: * ;
747835
748- let ( notify, state, waiter) = self . project ( ) ;
836+ let ( notify, state, notify_waiters_calls , waiter) = self . project ( ) ;
749837
750838 loop {
751839 match * state {
752- Init ( initial_notify_waiters_calls ) => {
840+ Init => {
753841 let curr = notify. state . load ( SeqCst ) ;
754842
755843 // Optimistically try acquiring a pending notification
@@ -779,7 +867,7 @@ impl Notified<'_> {
779867
780868 // if notify_waiters has been called after the future
781869 // was created, then we are done
782- if get_num_notify_waiters_calls ( curr) != initial_notify_waiters_calls {
870+ if get_num_notify_waiters_calls ( curr) != * notify_waiters_calls {
783871 * state = Done ;
784872 return Poll :: Ready ( ( ) ) ;
785873 }
@@ -846,21 +934,37 @@ impl Notified<'_> {
846934 return Poll :: Pending ;
847935 }
848936 Waiting => {
849- // Currently in the "Waiting" state, implying the caller has
850- // a waiter stored in the waiter list (guarded by
851- // `notify.waiters`). In order to access the waker fields,
852- // we must hold the lock.
937+ // Currently in the "Waiting" state, implying the caller has a waiter stored in
938+ // a waiter list (guarded by `notify.waiters`). In order to access the waker
939+ // fields, we must acquire the lock.
853940
854- let waiters = notify. waiters . lock ( ) ;
941+ let mut waiters = notify. waiters . lock ( ) ;
942+
943+ // Load the state with the lock held.
944+ let curr = notify. state . load ( SeqCst ) ;
855945
856946 // Safety: called while locked
857947 let w = unsafe { & mut * waiter. get ( ) } ;
858948
859949 if w. notified . is_some ( ) {
860- // Our waker has been notified. Reset the fields and
861- // remove it from the list.
862- w. waker = None ;
950+ // Our waker has been notified and our waiter is already removed from
951+ // the list. Reset the notification and convert to `Done`.
863952 w. notified = None ;
953+ w. waker = None ;
954+ * state = Done ;
955+ } else if get_num_notify_waiters_calls ( curr) != * notify_waiters_calls {
956+ // Before we add a waiter to the list we check if these numbers are
957+ // different while holding the lock. If these numbers are different now,
958+ // it means that there is a call to `notify_waiters` in progress and this
959+ // waiter must be contained by a guarded list used in `notify_waiters`.
960+ // We can treat the waiter as notified and remove it from the list, as
961+ // it would have been notified in the `notify_waiters` call anyways.
962+
963+ w. waker = None ;
964+
965+ // Safety: we hold the lock, so we have an exclusive access to the list.
966+ // The list is used in `notify_waiters`, so it must be guarded.
967+ unsafe { waiters. remove ( NonNull :: new_unchecked ( w) ) } ;
864968
865969 * state = Done ;
866970 } else {
@@ -906,7 +1010,7 @@ impl Drop for Notified<'_> {
9061010 use State :: * ;
9071011
9081012 // Safety: The type only transitions to a "Waiting" state when pinned.
909- let ( notify, state, waiter) = unsafe { Pin :: new_unchecked ( self ) . project ( ) } ;
1013+ let ( notify, state, _ , waiter) = unsafe { Pin :: new_unchecked ( self ) . project ( ) } ;
9101014
9111015 // This is where we ensure safety. The `Notified` value is being
9121016 // dropped, which means we must ensure that the waiter entry is no
@@ -917,8 +1021,10 @@ impl Drop for Notified<'_> {
9171021
9181022 // remove the entry from the list (if not already removed)
9191023 //
920- // safety: the waiter is only added to `waiters` by virtue of it
921- // being the only `LinkedList` available to the type.
1024+ // Safety: we hold the lock, so we have an exclusive access to every list the
1025+ // waiter may be contained in. If the node is not contained in the `waiters`
1026+ // list, then it is contained by a guarded list used by `notify_waiters` and
1027+ // in such case it must be a middle node.
9221028 unsafe { waiters. remove ( NonNull :: new_unchecked ( waiter. get ( ) ) ) } ;
9231029
9241030 if waiters. is_empty ( ) && get_state ( notify_state) == WAITING {
0 commit comments