time: revert "time: avoid traversing entries in the time wheel twice" #6715

Merged
merged 1 commit into from
Jul 23, 2024
99 changes: 78 additions & 21 deletions tokio/src/runtime/time/entry.rs
@@ -21,8 +21,9 @@
//!
//! Each timer has a state field associated with it. This field contains either
//! the current scheduled time, or a special flag value indicating its state.
//! This state can either indicate that the timer is firing (and thus will be fired
//! with an `Ok(())` result soon) or that it has already been fired/deregistered.
//! This state can either indicate that the timer is on the 'pending' queue (and
//! thus will be fired with an `Ok(())` result soon) or that it has already been
//! fired/deregistered.
//!
//! This single state field allows for code that is firing the timer to
//! synchronize with any racing `reset` calls reliably.
@@ -48,10 +49,10 @@
//! There is of course a race condition between timer reset and timer
//! expiration. If the driver fails to observe the updated expiration time, it
//! could trigger expiration of the timer too early. However, because
//! [`mark_firing`][mark_firing] performs a compare-and-swap, it will identify this race and
//! refuse to mark the timer as firing.
//! [`mark_pending`][mark_pending] performs a compare-and-swap, it will identify this race and
//! refuse to mark the timer as pending.
//!
//! [mark_firing]: TimerHandle::mark_firing
//! [mark_pending]: TimerHandle::mark_pending

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::AtomicU64;
@@ -69,9 +70,9 @@ use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull};

type TimerResult = Result<(), crate::time::error::Error>;

pub(super) const STATE_DEREGISTERED: u64 = u64::MAX;
const STATE_FIRING: u64 = STATE_DEREGISTERED - 1;
const STATE_MIN_VALUE: u64 = STATE_FIRING;
const STATE_DEREGISTERED: u64 = u64::MAX;
const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1;
const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE;
/// The largest safe integer to use for ticks.
///
/// This value should be updated if any other signal values are added above.
@@ -122,6 +123,10 @@ impl StateCell {
}
}

fn is_pending(&self) -> bool {
self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE
}

/// Returns the current expiration time, or None if not currently scheduled.
fn when(&self) -> Option<u64> {
let cur_state = self.state.load(Ordering::Relaxed);
@@ -157,28 +162,26 @@ impl StateCell {
}
}

/// Marks this timer firing, if its scheduled time is not after `not_after`.
/// Marks this timer as being moved to the pending list, if its scheduled
/// time is not after `not_after`.
///
/// If the timer is scheduled for a time after `not_after`, returns an Err
/// containing the current scheduled time.
///
/// SAFETY: Must hold the driver lock.
unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> {
unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
// Quick initial debug check to see if the timer is already fired. Since
// firing the timer can only happen with the driver lock held, we know
// we shouldn't be able to "miss" a transition to a fired state, even
// with relaxed ordering.
let mut cur_state = self.state.load(Ordering::Relaxed);

loop {
// Because its state is STATE_DEREGISTERED, it has been fired.
if cur_state == STATE_DEREGISTERED {
break Err(cur_state);
}
// improve the error message for things like
// https://github.com/tokio-rs/tokio/issues/3675
assert!(
cur_state < STATE_MIN_VALUE,
"mark_firing called when the timer entry is in an invalid state"
"mark_pending called when the timer entry is in an invalid state"
);

if cur_state > not_after {
@@ -187,7 +190,7 @@ impl StateCell {

match self.state.compare_exchange_weak(
cur_state,
STATE_FIRING,
STATE_PENDING_FIRE,
Ordering::AcqRel,
Ordering::Acquire,
) {
@@ -334,6 +337,11 @@ pub(crate) struct TimerShared {
/// Only accessed under the entry lock.
pointers: linked_list::Pointers<TimerShared>,

/// The expiration time for which this entry is currently registered.
/// Generally owned by the driver, but is accessed by the entry when not
/// registered.
cached_when: AtomicU64,

/// Current state. This records whether the timer entry is currently under
/// the ownership of the driver, and if not, its current state (not
/// complete, fired, error, etc).
@@ -348,6 +356,7 @@ unsafe impl Sync for TimerShared {}
impl std::fmt::Debug for TimerShared {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TimerShared")
.field("cached_when", &self.cached_when.load(Ordering::Relaxed))
.field("state", &self.state)
.finish()
}
@@ -365,12 +374,40 @@ impl TimerShared {
pub(super) fn new(shard_id: u32) -> Self {
Self {
shard_id,
cached_when: AtomicU64::new(0),
pointers: linked_list::Pointers::new(),
state: StateCell::default(),
_p: PhantomPinned,
}
}

/// Gets the cached time-of-expiration value.
pub(super) fn cached_when(&self) -> u64 {
// Cached-when is only accessed under the driver lock, so we can use relaxed
self.cached_when.load(Ordering::Relaxed)
}

/// Gets the true time-of-expiration value, and copies it into the cached
/// time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
pub(super) unsafe fn sync_when(&self) -> u64 {
let true_when = self.true_when();

self.cached_when.store(true_when, Ordering::Relaxed);

true_when
}

/// Sets the cached time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
unsafe fn set_cached_when(&self, when: u64) {
self.cached_when.store(when, Ordering::Relaxed);
}

/// Returns the true time-of-expiration value, with relaxed memory ordering.
pub(super) fn true_when(&self) -> u64 {
self.state.when().expect("Timer already fired")
@@ -383,6 +420,7 @@ impl TimerShared {
/// in the timer wheel.
pub(super) unsafe fn set_expiration(&self, t: u64) {
self.state.set_expiration(t);
self.cached_when.store(t, Ordering::Relaxed);
}

/// Sets the true time-of-expiration only if it is after the current.
@@ -552,8 +590,16 @@ impl TimerEntry {
}

impl TimerHandle {
pub(super) unsafe fn true_when(&self) -> u64 {
unsafe { self.inner.as_ref().true_when() }
pub(super) unsafe fn cached_when(&self) -> u64 {
unsafe { self.inner.as_ref().cached_when() }
}

pub(super) unsafe fn sync_when(&self) -> u64 {
unsafe { self.inner.as_ref().sync_when() }
}

pub(super) unsafe fn is_pending(&self) -> bool {
unsafe { self.inner.as_ref().state.is_pending() }
}

/// Forcibly sets the true and cached expiration times to the given tick.
@@ -564,16 +610,27 @@ impl TimerHandle {
self.inner.as_ref().set_expiration(tick);
}

/// Attempts to mark this entry as firing. If the expiration time is after
/// Attempts to mark this entry as pending. If the expiration time is after
/// `not_after`, however, returns an Err with the current expiration time.
///
/// If an `Err` is returned, the `cached_when` value will be updated to this
/// new expiration time.
///
/// SAFETY: The caller must ensure that the handle remains valid, the driver
/// lock is held, and that the timer is not in any wheel linked lists.
pub(super) unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> {
self.inner.as_ref().state.mark_firing(not_after)
/// After returning Ok, the entry must be added to the pending list.
pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
match self.inner.as_ref().state.mark_pending(not_after) {
Ok(()) => {
// mark this as being on the pending queue in cached_when
self.inner.as_ref().set_cached_when(u64::MAX);
Ok(())
}
Err(tick) => {
self.inner.as_ref().set_cached_when(tick);
Err(tick)
}
}
}

/// Attempts to transition to a terminal state. If the state is already a
Expand Down
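The doc comments in this file describe `mark_pending` as a compare-and-swap that refuses to mark a timer pending when a racing `reset` has moved its deadline. A minimal sketch of that loop follows, assuming a plain `std` atomic instead of the `loom` wrapper; `SketchState` is a hypothetical stand-in for `StateCell`, not the type in this diff, and it omits the driver-lock requirement.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Sentinel values mirroring the constants restored by this revert.
const STATE_DEREGISTERED: u64 = u64::MAX;
const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1;
const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE;

/// Hypothetical stand-in for `StateCell`: the field holds either the
/// scheduled tick or one of the sentinel values above.
struct SketchState {
    state: AtomicU64,
}

impl SketchState {
    fn new(when: u64) -> Self {
        Self { state: AtomicU64::new(when) }
    }

    fn is_pending(&self) -> bool {
        self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE
    }

    /// Marks the timer pending unless it is scheduled after `not_after`,
    /// in which case the currently scheduled tick is returned in `Err`.
    fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
        let mut cur = self.state.load(Ordering::Relaxed);
        loop {
            assert!(cur < STATE_MIN_VALUE, "timer entry is in an invalid state");

            if cur > not_after {
                break Err(cur);
            }

            match self.state.compare_exchange_weak(
                cur,
                STATE_PENDING_FIRE,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => break Ok(()),
                // Spurious failure or a racing reset: retry with the value
                // that was actually observed.
                Err(actual) => cur = actual,
            }
        }
    }
}
```

In this sketch, a timer created with `SketchState::new(100)` rejects `mark_pending(50)` with `Err(100)` and accepts `mark_pending(100)`. If a reset stores a later tick before the swap, the loop observes the new value and returns it in `Err` instead of marking the timer pending.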
60 changes: 15 additions & 45 deletions tokio/src/runtime/time/mod.rs
@@ -8,7 +8,7 @@

mod entry;
pub(crate) use entry::TimerEntry;
use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION, STATE_DEREGISTERED};
use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION};

mod handle;
pub(crate) use self::handle::Handle;
@@ -324,53 +324,23 @@ impl Handle {
now = lock.elapsed();
}

while let Some(expiration) = lock.poll(now) {
lock.set_elapsed(expiration.deadline);
// It is critical for `GuardedLinkedList` safety that the guard node is
// pinned in memory and is not dropped until the guarded list is dropped.
let guard = TimerShared::new(id);
pin!(guard);
let guard_handle = guard.as_ref().get_ref().handle();

// * This list will be still guarded by the lock of the Wheel with the specified id.
// `EntryWaitersList` wrapper makes sure we hold the lock to modify it.
// * This wrapper will empty the list on drop. It is critical for safety
// that we will not leave any list entry with a pointer to the local
// guard node after this function returns / panics.
// Safety: The `TimerShared` inside this `TimerHandle` is pinned in the memory.
let mut list = unsafe { lock.get_waiters_list(&expiration, guard_handle, id, self) };

while let Some(entry) = list.pop_back_locked(&mut lock) {
let deadline = expiration.deadline;
// Try to expire the entry; this is cheap (doesn't synchronize) if
// the timer is not expired, and updates cached_when.
match unsafe { entry.mark_firing(deadline) } {
Ok(()) => {
// Entry was expired.
// SAFETY: We hold the driver lock, and just removed the entry from any linked lists.
if let Some(waker) = unsafe { entry.fire(Ok(())) } {
waker_list.push(waker);

if !waker_list.can_push() {
// Wake a batch of wakers. To avoid deadlock,
// we must do this with the lock temporarily dropped.
drop(lock);
waker_list.wake_all();

lock = self.inner.lock_sharded_wheel(id);
}
}
}
Err(state) if state == STATE_DEREGISTERED => {}
Err(state) => {
// Safety: This Entry has not expired.
unsafe { lock.reinsert_entry(entry, deadline, state) };
}
while let Some(entry) = lock.poll(now) {
debug_assert!(unsafe { entry.is_pending() });

// SAFETY: We hold the driver lock, and just removed the entry from any linked lists.
if let Some(waker) = unsafe { entry.fire(Ok(())) } {
waker_list.push(waker);

if !waker_list.can_push() {
// Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped.
drop(lock);

waker_list.wake_all();

lock = self.inner.lock_sharded_wheel(id);
}
}
lock.occupied_bit_maintain(&expiration);
}

let next_wake_up = lock.poll_at();
drop(lock);

Expand Down
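The loop in this hunk collects wakers into a bounded list and, once the list is full, drops the wheel lock before waking them to avoid deadlock. A minimal sketch of that pattern follows, assuming a `std::sync::Mutex` in place of the sharded wheel lock and plain integers in place of wakers; `BATCH` and `drain_in_batches` are hypothetical names, not part of Tokio.

```rust
use std::sync::Mutex;

/// Hypothetical batch capacity; stands in for the point at which
/// `waker_list.can_push()` turns false.
const BATCH: usize = 4;

/// Pops every id from `queue`, "waking" them (appending to `woken`) in
/// batches while the lock is released. Returns the number of batches.
fn drain_in_batches(queue: &Mutex<Vec<u32>>, woken: &mut Vec<u32>) -> usize {
    let mut batches = 0;
    let mut batch: Vec<u32> = Vec::with_capacity(BATCH);
    let mut lock = queue.lock().unwrap();

    while let Some(id) = lock.pop() {
        batch.push(id);

        if batch.len() == BATCH {
            // Release the lock before running anything that might try to
            // take it again, then reacquire it to continue draining.
            drop(lock);
            woken.extend(batch.drain(..));
            batches += 1;
            lock = queue.lock().unwrap();
        }
    }

    drop(lock);

    // Wake whatever is left over, again without holding the lock.
    if !batch.is_empty() {
        woken.extend(batch.drain(..));
        batches += 1;
    }

    batches
}
```

Because the lock is released between batches, other threads can modify the queue in the meantime; the `while let` re-reads the queue after each reacquisition, so it never relies on state observed before the drop.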
22 changes: 10 additions & 12 deletions tokio/src/runtime/time/wheel/level.rs
@@ -20,6 +20,7 @@ pub(crate) struct Level {
}

/// Indicates when a slot must be processed next.
#[derive(Debug)]
pub(crate) struct Expiration {
/// The level containing the slot.
pub(crate) level: usize,
@@ -80,7 +81,7 @@ impl Level {
// pseudo-ring buffer, and we rotate around them indefinitely. If we
// compute a deadline before now, and it's the top level, it
// therefore means we're actually looking at a slot in the future.
debug_assert_eq!(self.level, super::MAX_LEVEL_INDEX);
debug_assert_eq!(self.level, super::NUM_LEVELS - 1);

deadline += level_range;
}
@@ -119,33 +120,30 @@ impl Level {
}

pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) {
let slot = slot_for(item.true_when(), self.level);
let slot = slot_for(item.cached_when(), self.level);

self.slot[slot].push_front(item);

self.occupied |= occupied_bit(slot);
}

pub(crate) unsafe fn remove_entry(&mut self, item: NonNull<TimerShared>) {
let slot = slot_for(unsafe { item.as_ref().true_when() }, self.level);
let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level);

unsafe { self.slot[slot].remove(item) };
if self.slot[slot].is_empty() {
// The bit is currently set
debug_assert!(self.occupied & occupied_bit(slot) != 0);

// Unset the bit
self.occupied ^= occupied_bit(slot);
}
}

pub(super) fn take_slot(&mut self, slot: usize) -> EntryList {
std::mem::take(&mut self.slot[slot])
}
pub(crate) fn take_slot(&mut self, slot: usize) -> EntryList {
self.occupied &= !occupied_bit(slot);

pub(super) fn occupied_bit_maintain(&mut self, slot: usize) {
if self.slot[slot].is_empty() {
self.occupied &= !occupied_bit(slot);
} else {
self.occupied |= occupied_bit(slot);
}
std::mem::take(&mut self.slot[slot])
}
}

Expand Down
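After this revert, `take_slot` clears the slot's bit in `occupied` and takes the whole list in one step, rather than leaving the bit to a separate `occupied_bit_maintain` call. A minimal sketch of that bookkeeping follows, assuming 64 slots per level and 6 bits per level; `SketchLevel` is hypothetical, the bodies of `slot_for` and `occupied_bit` are assumptions since the diff does not show them, and a `Vec<u64>` of deadlines stands in for the intrusive `EntryList`.

```rust
/// Assumed helper: bit mask for one slot in the `occupied` bitfield.
fn occupied_bit(slot: usize) -> u64 {
    1u64 << slot
}

/// Assumed helper: slot index for a deadline at a given level
/// (6 bits per level, 64 slots).
fn slot_for(when: u64, level: usize) -> usize {
    ((when >> (level * 6)) % 64) as usize
}

/// Hypothetical stand-in for `Level`.
struct SketchLevel {
    level: usize,
    occupied: u64,
    slot: Vec<Vec<u64>>,
}

impl SketchLevel {
    fn new(level: usize) -> Self {
        Self { level, occupied: 0, slot: vec![Vec::new(); 64] }
    }

    /// Files the deadline under its slot and marks the slot occupied.
    fn add_entry(&mut self, when: u64) {
        let slot = slot_for(when, self.level);
        self.slot[slot].push(when);
        self.occupied |= occupied_bit(slot);
    }

    /// Clears the occupied bit and hands back every entry in the slot,
    /// leaving an empty list behind.
    fn take_slot(&mut self, slot: usize) -> Vec<u64> {
        self.occupied &= !occupied_bit(slot);
        std::mem::take(&mut self.slot[slot])
    }
}
```

Clearing the bit inside `take_slot` keeps the bitfield consistent with the slot contents at every point where the lock could be released, which is the invariant `remove_entry` checks with its `debug_assert!`.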