diff --git a/tokio/src/loom/std/mod.rs b/tokio/src/loom/std/mod.rs
index 0c611af162a..7d01c05a4e4 100644
--- a/tokio/src/loom/std/mod.rs
+++ b/tokio/src/loom/std/mod.rs
@@ -59,12 +59,14 @@ pub(crate) mod sync {
     #[cfg(all(feature = "parking_lot", not(miri)))]
     #[allow(unused_imports)]
     pub(crate) use crate::loom::std::parking_lot::{
-        Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult,
+        Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult,
     };

     #[cfg(not(all(feature = "parking_lot", not(miri))))]
     #[allow(unused_imports)]
-    pub(crate) use std::sync::{Condvar, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult};
+    pub(crate) use std::sync::{
+        Condvar, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult,
+    };

     #[cfg(not(all(feature = "parking_lot", not(miri))))]
     pub(crate) use crate::loom::std::mutex::Mutex;
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs
index 568a50bd59b..d7c6d09507d 100644
--- a/tokio/src/sync/broadcast.rs
+++ b/tokio/src/sync/broadcast.rs
@@ -118,8 +118,8 @@

 use crate::loom::cell::UnsafeCell;
 use crate::loom::sync::atomic::AtomicUsize;
-use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
-use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
+use crate::loom::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
+use crate::util::linked_list::{self, AtomicLinkedList, GuardedLinkedList};
 use crate::util::WakeList;

 use std::fmt;
@@ -310,7 +310,7 @@ struct Shared<T> {
     mask: usize,

     /// Tail of the queue. Includes the rx wait list.
-    tail: Mutex<Tail>,
+    tail: RwLock<Tail>,

     /// Number of outstanding Sender handles.
     num_tx: AtomicUsize,
@@ -328,7 +328,7 @@ struct Tail {
     closed: bool,

     /// Receivers waiting for a value.
-    waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
+    waiters: AtomicLinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
 }

 /// Slot in the buffer.
@@ -521,11 +521,11 @@ impl<T> Sender<T> {
         let shared = Arc::new(Shared {
             buffer: buffer.into_boxed_slice(),
             mask: capacity - 1,
-            tail: Mutex::new(Tail {
+            tail: RwLock::new(Tail {
                 pos: 0,
                 rx_cnt: receiver_count,
                 closed: false,
-                waiters: LinkedList::new(),
+                waiters: AtomicLinkedList::new(),
             }),
             num_tx: AtomicUsize::new(1),
         });
@@ -585,7 +585,7 @@ impl<T> Sender<T> {
     /// }
     /// ```
     pub fn send(&self, value: T) -> Result<usize, SendError<T>> {
-        let mut tail = self.shared.tail.lock();
+        let mut tail = self.shared.tail.write().unwrap();

         if tail.rx_cnt == 0 {
             return Err(SendError(value));
@@ -688,7 +688,7 @@ impl<T> Sender<T> {
     /// }
     /// ```
     pub fn len(&self) -> usize {
-        let tail = self.shared.tail.lock();
+        let tail = self.shared.tail.read().unwrap();

         let base_idx = (tail.pos & self.shared.mask as u64) as usize;
         let mut low = 0;
@@ -735,7 +735,7 @@ impl<T> Sender<T> {
     /// }
     /// ```
     pub fn is_empty(&self) -> bool {
-        let tail = self.shared.tail.lock();
+        let tail = self.shared.tail.read().unwrap();

         let idx = (tail.pos.wrapping_sub(1) & self.shared.mask as u64) as usize;
         self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0
@@ -778,7 +778,7 @@ impl<T> Sender<T> {
     /// }
     /// ```
     pub fn receiver_count(&self) -> usize {
-        let tail = self.shared.tail.lock();
+        let tail = self.shared.tail.read().unwrap();
         tail.rx_cnt
     }

@@ -806,7 +806,7 @@ impl<T> Sender<T> {
     }

     fn close_channel(&self) {
-        let mut tail = self.shared.tail.lock();
+        let mut tail = self.shared.tail.write().unwrap();
         tail.closed = true;

         self.shared.notify_rx(tail);
@@ -815,7 +815,7 @@ impl<T> Sender<T> {
 /// Create a new `Receiver` which reads starting from the tail.
 fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {
-    let mut tail = shared.tail.lock();
+    let mut tail = shared.tail.write().unwrap();

     assert!(tail.rx_cnt != MAX_RECEIVERS, "max receivers");
@@ -842,7 +842,7 @@ impl<'a, T> Drop for WaitersList<'a, T> {
         // If the list is not empty, we unlink all waiters from it.
         // We do not wake the waiters to avoid double panics.
         if !self.is_empty {
-            let _lock_guard = self.shared.tail.lock();
+            let _lock_guard = self.shared.tail.write().unwrap();
             while self.list.pop_back().is_some() {}
         }
     }
@@ -850,12 +850,12 @@ impl<'a, T> Drop for WaitersList<'a, T> {

 impl<'a, T> WaitersList<'a, T> {
     fn new(
-        unguarded_list: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
+        unguarded_list: AtomicLinkedList<Waiter, <Waiter as linked_list::Link>::Target>,
         guard: Pin<&'a Waiter>,
         shared: &'a Shared<T>,
     ) -> Self {
         let guard_ptr = NonNull::from(guard.get_ref());
-        let list = unguarded_list.into_guarded(guard_ptr);
+        let list = unguarded_list.into_list().into_guarded(guard_ptr);
         WaitersList {
             list,
             is_empty: false,
@@ -877,7 +877,7 @@ impl<'a, T> WaitersList<'a, T> {
 }

 impl<T> Shared<T> {
-    fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: MutexGuard<'a, Tail>) {
+    fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: RwLockWriteGuard<'a, Tail>) {
         // It is critical for `GuardedLinkedList` safety that the guard node is
         // pinned in memory and is not dropped until the guarded list is dropped.
         let guard = Waiter::new();
@@ -925,7 +925,7 @@ impl<T> Shared<T> {
             wakers.wake_all();

             // Acquire the lock again.
-            tail = self.tail.lock();
+            tail = self.tail.write().unwrap();
         }

         // Release the lock before waking.
@@ -987,7 +987,7 @@ impl<T> Receiver<T> {
     /// }
     /// ```
     pub fn len(&self) -> usize {
-        let next_send_pos = self.shared.tail.lock().pos;
+        let next_send_pos = self.shared.tail.read().unwrap().pos;
         (next_send_pos - self.next) as usize
     }

@@ -1065,7 +1065,7 @@ impl<T> Receiver<T> {

             let mut old_waker = None;

-            let mut tail = self.shared.tail.lock();
+            let tail = self.shared.tail.read().unwrap();

             // Acquire slot lock again
             slot = self.shared.buffer[idx].read().unwrap();
@@ -1086,7 +1086,16 @@ impl<T> Receiver<T> {

                     // Store the waker
                     if let Some((waiter, waker)) = waiter {
-                        // Safety: called while locked.
+                        // Safety: called while holding a read lock on tail.
+                        // It suffices since we only update two waiter members:
+                        // - `waiter.waker` - all other accesses of this member are
+                        //   write-lock protected,
+                        // - `waiter.queued` - all other accesses of this member are
+                        //   either write-lock protected or read-lock protected with
+                        //   exclusive reference to the `Recv` that contains the waiter.
+                        // Concurrent calls to `recv_ref` with the same waiter
+                        // are impossible because it implies ownership of the `Recv`
+                        // that contains it.
                         unsafe {
                             // Only queue if not already queued
                             waiter.with_mut(|ptr| {
@@ -1106,6 +1115,11 @@ impl<T> Receiver<T> {

                                 if !(*ptr).queued {
                                     (*ptr).queued = true;
+                                    // Safety:
+                                    // - `waiter` is not already queued,
+                                    // - calling `recv_ref` with a waiter implies ownership
+                                    //   of its `Recv`. As such, this waiter cannot be pushed
+                                    //   concurrently by some other thread.
                                     tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
                                 }
                             });
@@ -1331,7 +1345,7 @@ impl<T> Receiver<T> {

 impl<T> Drop for Receiver<T> {
     fn drop(&mut self) {
-        let mut tail = self.shared.tail.lock();
+        let mut tail = self.shared.tail.write().unwrap();

         tail.rx_cnt -= 1;
         let until = tail.pos;
@@ -1402,22 +1416,34 @@ where

 impl<'a, T> Drop for Recv<'a, T> {
     fn drop(&mut self) {
-        // Acquire the tail lock. This is required for safety before accessing
+        // Acquire a read lock on tail. This is required for safety before accessing
         // the waiter node.
-        let mut tail = self.receiver.shared.tail.lock();
+        let tail = self.receiver.shared.tail.read().unwrap();

-        // safety: tail lock is held
+        // Safety: we hold read lock on tail AND have exclusive reference to `Recv`.
         let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });

         if queued {
-            // Remove the node
+            // Optimistic check failed. To remove the waiter, we need a write lock.
+            drop(tail);
+            let mut tail = self.receiver.shared.tail.write().unwrap();
+
+            // Double check that the waiter is still enqueued,
+            // in case it was removed before we reacquired the lock.
             //
-            // safety: tail lock is held and the wait node is verified to be in
-            // the list.
-            unsafe {
-                self.waiter.with_mut(|ptr| {
-                    tail.waiters.remove((&mut *ptr).into());
-                });
+            // Safety: tail write lock is held.
+            let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
+
+            if queued {
+                // Remove the node.
+                //
+                // Safety: tail write lock is held and the wait node is verified to be in
+                // the list.
+                unsafe {
+                    self.waiter.with_mut(|ptr| {
+                        tail.waiters.remove((&mut *ptr).into());
+                    });
+                }
             }
         }
     }
diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs
index 0ed2b616456..08ef4a28904 100644
--- a/tokio/src/util/linked_list.rs
+++ b/tokio/src/util/linked_list.rs
@@ -11,6 +11,10 @@ use core::fmt;
 use core::marker::{PhantomData, PhantomPinned};
 use core::mem::ManuallyDrop;
 use core::ptr::{self, NonNull};
+use core::sync::atomic::{
+    AtomicPtr,
+    Ordering::{AcqRel, Relaxed},
+};

 /// An intrusive linked list.
 ///
@@ -108,6 +112,52 @@ struct PointersInner<T> {
 unsafe impl<T: Send> Send for Pointers<T> {}
 unsafe impl<T: Sync> Sync for Pointers<T> {}

+// ===== LinkedListBase =====
+
+// Common methods between LinkedList and AtomicLinkedList.
+trait LinkedListBase<L: Link> {
+    // NB: exclusive reference is important for AtomicLinkedList safety guarantees.
+    fn head(&mut self) -> Option<NonNull<L::Target>>;
+    fn tail(&mut self) -> Option<NonNull<L::Target>>;
+
+    fn set_head(&mut self, node: Option<NonNull<L::Target>>);
+    fn set_tail(&mut self, node: Option<NonNull<L::Target>>);
+
+    unsafe fn remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle> {
+        if let Some(prev) = L::pointers(node).as_ref().get_prev() {
+            debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node));
+            L::pointers(prev)
+                .as_mut()
+                .set_next(L::pointers(node).as_ref().get_next());
+        } else {
+            if self.head() != Some(node) {
+                return None;
+            }
+
+            self.set_head(L::pointers(node).as_ref().get_next());
+        }
+
+        if let Some(next) = L::pointers(node).as_ref().get_next() {
+            debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node));
+            L::pointers(next)
+                .as_mut()
+                .set_prev(L::pointers(node).as_ref().get_prev());
+        } else {
+            // This might be the last item in the list
+            if self.tail() != Some(node) {
+                return None;
+            }
+
+            self.set_tail(L::pointers(node).as_ref().get_prev());
+        }
+
+        L::pointers(node).as_mut().set_next(None);
+        L::pointers(node).as_mut().set_prev(None);
+
+        Some(L::from_raw(node))
+    }
+}
+
 // ===== impl LinkedList =====

 impl<L, T> LinkedList<L, T> {
@@ -121,6 +171,24 @@ impl<L, T> LinkedList<L, T> {
     }
 }

+impl<L: Link> LinkedListBase<L> for LinkedList<L, L::Target> {
+    fn head(&mut self) -> Option<NonNull<<L as Link>::Target>> {
+        self.head
+    }
+
+    fn tail(&mut self) -> Option<NonNull<<L as Link>::Target>> {
+        self.tail
+    }
+
+    fn set_head(&mut self, node: Option<NonNull<<L as Link>::Target>>) {
+        self.head = node;
+    }
+
+    fn set_tail(&mut self, node: Option<NonNull<<L as Link>::Target>>) {
+        self.tail = node;
+    }
+}
+
 impl<L: Link> LinkedList<L, L::Target> {
     /// Adds an element first in the list.
     pub(crate) fn push_front(&mut self, val: L::Handle) {
@@ -185,37 +253,7 @@ impl<L: Link> LinkedList<L, L::Target> {
     /// the caller has an exclusive access to that list. This condition is
     /// used by the linked list in `sync::Notify`.
     pub(crate) unsafe fn remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle> {
-        if let Some(prev) = L::pointers(node).as_ref().get_prev() {
-            debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node));
-            L::pointers(prev)
-                .as_mut()
-                .set_next(L::pointers(node).as_ref().get_next());
-        } else {
-            if self.head != Some(node) {
-                return None;
-            }
-
-            self.head = L::pointers(node).as_ref().get_next();
-        }
-
-        if let Some(next) = L::pointers(node).as_ref().get_next() {
-            debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node));
-            L::pointers(next)
-                .as_mut()
-                .set_prev(L::pointers(node).as_ref().get_prev());
-        } else {
-            // This might be the last item in the list
-            if self.tail != Some(node) {
-                return None;
-            }
-
-            self.tail = L::pointers(node).as_ref().get_prev();
-        }
-
-        L::pointers(node).as_mut().set_next(None);
-        L::pointers(node).as_mut().set_prev(None);
-
-        Some(L::from_raw(node))
+        LinkedListBase::remove(self, node)
     }
 }
@@ -313,6 +351,141 @@ cfg_taskdump! {
     }
 }

+// ===== impl AtomicLinkedList =====
+
+feature! {
+    #![feature = "sync"]
+
+    /// An atomic intrusive linked list. It allows pushing new nodes concurrently.
+    /// Removing nodes still requires an exclusive reference.
+    pub(crate) struct AtomicLinkedList<L, T> {
+        /// Linked list head.
+        head: AtomicPtr<T>,
+
+        /// Linked list tail.
+        tail: UnsafeCell<Option<NonNull<T>>>,
+
+        /// Node type marker.
+        _marker: PhantomData<*const L>,
+    }
+
+    unsafe impl<L: Link> Send for AtomicLinkedList<L, L::Target> where L::Target: Send {}
+    unsafe impl<L: Link> Sync for AtomicLinkedList<L, L::Target> where L::Target: Sync {}
+
+    impl<L: Link> Default for AtomicLinkedList<L, L::Target> {
+        fn default() -> Self {
+            Self::new()
+        }
+    }
+
+    impl<L, T> AtomicLinkedList<L, T> {
+        /// Creates an empty atomic linked list.
+        pub(crate) const fn new() -> AtomicLinkedList<L, T> {
+            AtomicLinkedList {
+                head: AtomicPtr::new(core::ptr::null_mut()),
+                tail: UnsafeCell::new(None),
+                _marker: PhantomData,
+            }
+        }
+
+        /// Convert an atomic linked list into a non-atomic version.
+        pub(crate) fn into_list(mut self) -> LinkedList<L, T> {
+            LinkedList {
+                head: NonNull::new(*self.head.get_mut()),
+                tail: *self.tail.get_mut(),
+                _marker: PhantomData,
+            }
+        }
+    }
+
+    impl<L: Link> LinkedListBase<L> for AtomicLinkedList<L, L::Target> {
+        fn head(&mut self) -> Option<NonNull<L::Target>> {
+            NonNull::new(*self.head.get_mut())
+        }
+
+        fn tail(&mut self) -> Option<NonNull<L::Target>> {
+            *self.tail.get_mut()
+        }
+
+        fn set_head(&mut self, node: Option<NonNull<L::Target>>) {
+            *self.head.get_mut() = match node {
+                Some(ptr) => ptr.as_ptr(),
+                None => core::ptr::null_mut(),
+            };
+        }
+
+        fn set_tail(&mut self, node: Option<NonNull<L::Target>>) {
+            *self.tail.get_mut() = node;
+        }
+    }
+
+    impl<L: Link> AtomicLinkedList<L, L::Target> {
+        /// Atomically adds an element first in the list.
+        /// This method can be called concurrently from multiple threads.
+        ///
+        /// # Safety
+        ///
+        /// The caller must ensure that:
+        /// - `val` is not pushed concurrently by multiple threads,
+        /// - `val` is not already part of some list.
+        pub(crate) unsafe fn push_front(&self, val: L::Handle) {
+            // Note that removing nodes from the list still requires
+            // an exclusive reference, so we need not worry about
+            // concurrent node removals.
+
+            // The value should not be dropped, it is being inserted into the list.
+            let val = ManuallyDrop::new(val);
+            let new_head = L::as_raw(&val);
+
+            // Safety: due to the function contract, no concurrent `push_front`
+            // is called on this particular element, so we are safe to assume
+            // ownership.
+            L::pointers(new_head).as_mut().set_prev(None);
+
+            let mut old_head = self.head.load(Relaxed);
+            loop {
+                // Safety: due to the function contract, no concurrent `push_front`
+                // is called on this particular element, and we have not
+                // inserted it into the list, so we can still assume ownership.
+                L::pointers(new_head).as_mut().set_next(NonNull::new(old_head));
+
+                let Err(actual_head) = self.head.compare_exchange_weak(
+                    old_head,
+                    new_head.as_ptr(),
+                    AcqRel,
+                    Relaxed,
+                ) else {
+                    break;
+                };
+
+                old_head = actual_head;
+            }
+
+            if old_head.is_null() {
+                // Safety: only the thread that successfully inserted the first
+                // element is granted the right to update tail.
+                *self.tail.get() = Some(new_head);
+            } else {
+                // Safety:
+                // 1. Only the thread that successfully inserted the new element
+                //    is granted the right to update the previous head's `prev`,
+                // 2. Upon successful insertion, we have synchronized with all the
+                //    previous insertions (due to `AcqRel` memory ordering), so all
+                //    the previous `set_prev` calls on `old_head` happen-before this call,
+                // 3. Due to the `push_front` contract, we can assume that `old_head`
+                //    is not pushed concurrently by another thread, as it is already
+                //    in the list. Thus, no data race on `set_prev` can happen.
+                L::pointers(NonNull::new_unchecked(old_head)).as_mut().set_prev(Some(new_head));
+            }
+        }
+
+        /// See [LinkedList::remove].
+        pub(crate) unsafe fn remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle> {
+            LinkedListBase::remove(self, node)
+        }
+    }
+}
+
 // ===== impl GuardedLinkedList =====

 feature! {
@@ -797,4 +970,45 @@ pub(crate) mod tests {
             }
         }
     }
+
+    #[cfg(feature = "sync")]
+    #[test]
+    fn atomic_push_front() {
+        use std::sync::Arc;
+
+        let atomic_list = Arc::new(AtomicLinkedList::<&Entry, <&Entry as Link>::Target>::new());
+
+        let _entries = [5, 7]
+            .into_iter()
+            .map(|x| {
+                std::thread::spawn({
+                    let atomic_list = atomic_list.clone();
+                    move || {
+                        let list_entry = entry(x);
+                        unsafe {
+                            atomic_list.push_front(list_entry.as_ref());
+                        }
+                        list_entry
+                    }
+                })
+            })
+            .collect::<Vec<_>>()
+            .into_iter()
+            .map(|handle| handle.join().unwrap())
+            .collect::<Vec<_>>();
+
+        let mut list = Arc::into_inner(atomic_list).unwrap().into_list();
+
+        assert!(!list.is_empty());
+
+        let first = list.pop_back().unwrap();
+        assert!(first.val == 5 || first.val == 7);
+
+        let second = list.pop_back().unwrap();
+        assert!(second.val == 5 || second.val == 7);
+        assert_ne!(first.val, second.val);
+
+        assert!(list.is_empty());
+        assert!(list.pop_back().is_none());
+    }
 }