Skip to content

Commit

Permalink
feat: support CancellationSafeFuture
Browse files Browse the repository at this point in the history
  • Loading branch information
ShiKaiWi committed Jul 14, 2023
1 parent 270acc2 commit 4fba378
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 61 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions common_util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ chrono = { workspace = true }
common_types = { workspace = true, features = ["test"] }
crossbeam-utils = "0.8.7"
env_logger = { workspace = true, optional = true }
futures = { workspace = true }
hex = "0.4.3"
lazy_static = { workspace = true }
libc = "0.2"
Expand Down
199 changes: 138 additions & 61 deletions common_util/src/future_cancel.rs
Original file line number Diff line number Diff line change
@@ -1,91 +1,168 @@
// Copyright 2022-2023 CeresDB Project Authors. Licensed under Apache-2.0.

/// A guard to detect whether a future is cancelled.
pub struct FutureCancelGuard<F: FnMut()> {
cancelled: bool,
on_cancel: F,
}
//! A future wrapper to ensure the wrapped future must be polled.
//!
//! This implementation is forked from: https://github.com/influxdata/influxdb_iox/blob/885767aa0a6010de592bde9992945b01389eb994/cache_system/src/cancellation_safe_future.rs
//! Here is the copyright and license disclaimer:
//! Copyright (c) 2020 InfluxData. Licensed under Apache-2.0.

use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};

use futures::future::BoxFuture;

use crate::runtime::RuntimeRef;

impl<F: FnMut()> FutureCancelGuard<F> {
/// Create a guard to assume the future will be cancelled at the following
/// await point.
/// Wrapper around a future that cannot be cancelled.
///
/// When the future is dropped/cancelled, we'll spawn a tokio task to _rescue_
/// it.
pub struct CancellationSafeFuture<F>
where
F: Future + Send + 'static,
F::Output: Send,
{
/// Mark if the inner future finished. If not, we must spawn a helper task
/// on drop.
done: bool,

/// Inner future.
///
/// If the future is really cancelled, the provided `on_cancel` callback
/// will be executed.
pub fn new_cancelled(on_cancel: F) -> Self {
Self {
cancelled: true,
on_cancel,
/// Wrapped in an `Option` so we can extract it during drop. Inside that
/// option however we also need a pinned box because once this wrapper
/// is polled, it will be pinned in memory -- even during drop. Now the
/// inner future does not necessarily implement `Unpin`, so we need a
/// heap allocation to pin it in memory even when we move it out of this
/// option.
inner: Option<BoxFuture<'static, F::Output>>,

/// The runtime to execute the dropped future.
runtime: RuntimeRef,
}

impl<F> Drop for CancellationSafeFuture<F>
where
F: Future + Send + 'static,
F::Output: Send,
{
fn drop(&mut self) {
if !self.done {
let inner = self.inner.take().unwrap();
let handle = self.runtime.spawn(async move { inner.await });
drop(handle);
}
}
}

/// Set the inner state is uncancelled to ensure the `on_cancel` callback
/// won't be executed.
pub fn uncancel(&mut self) {
self.cancelled = false;
impl<F> CancellationSafeFuture<F>
where
F: Future + Send,
F::Output: Send,
{
/// Create new future that is protected from cancellation.
///
/// If [`CancellationSafeFuture`] is cancelled (i.e. dropped) and there is
/// still some external receiver of the state left, than we will drive
/// the payload (`f`) to completion. Otherwise `f` will be cancelled.
pub fn new(fut: F, runtime: RuntimeRef) -> Self {
Self {
done: false,
inner: Some(Box::pin(fut)),
runtime,
}
}
}

impl<F: FnMut()> Drop for FutureCancelGuard<F> {
fn drop(&mut self) {
if self.cancelled {
(self.on_cancel)();
impl<F> Future for CancellationSafeFuture<F>
where
F: Future + Send,
F::Output: Send,
{
type Output = F::Output;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
assert!(!self.done, "Polling future that already returned");

match self.inner.as_mut().unwrap().as_mut().poll(cx) {
Poll::Ready(res) => {
self.done = true;
Poll::Ready(res)
}
Poll::Pending => Poll::Pending,
}
}
}

#[cfg(test)]
mod tests {
use std::{
sync::{Arc, Mutex},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};

use tokio::sync::Barrier;

use super::*;
use crate::runtime::Builder;

#[derive(Debug, PartialEq, Eq)]
enum State {
Init,
Processing,
Finished,
Cancelled,
fn rt() -> RuntimeRef {
let rt = Builder::default()
.worker_threads(2)
.thread_name("test_spawn_join")
.enable_all()
.build();
assert!(rt.is_ok());
Arc::new(rt.unwrap())
}

#[tokio::test]
async fn test_future_cancel() {
let state = Arc::new(Mutex::new(State::Init));
let lock = Arc::new(tokio::sync::Mutex::new(()));
#[test]
fn test_happy_path() {
let runtime = rt();
let runtime_clone = runtime.clone();
runtime.block_on(async move {
let done = Arc::new(AtomicBool::new(false));
let done_captured = Arc::clone(&done);

// Hold the lock before the task is spawned.
let lock_guard = lock.lock().await;
let fut = CancellationSafeFuture::new(
async move {
done_captured.store(true, Ordering::SeqCst);
},
runtime_clone,
);

let cloned_lock = lock.clone();
let cloned_state = state.clone();
let handle = tokio::spawn(async move {
{
let mut state = cloned_state.lock().unwrap();
*state = State::Processing;
}
let mut cancel_guard = FutureCancelGuard::new_cancelled(|| {
let mut state = cloned_state.lock().unwrap();
*state = State::Cancelled;
});

// It will be cancelled at this await point.
let _lock_guard = cloned_lock.lock().await;
cancel_guard.uncancel();
let mut state = cloned_state.lock().unwrap();
*state = State::Finished;
});
fut.await;

// Ensure the spawned task is started.
tokio::time::sleep(Duration::from_millis(50)).await;
handle.abort();
// Ensure the future cancel guard is dropped.
tokio::time::sleep(Duration::from_millis(0)).await;
drop(lock_guard);
assert!(done.load(Ordering::SeqCst));
})
}

#[test]
fn test_cancel_future() {
let runtime = rt();
let runtime_clone = runtime.clone();

runtime.block_on(async move {
let done = Arc::new(Barrier::new(2));
let done_captured = Arc::clone(&done);

let state = state.lock().unwrap();
assert_eq!(*state, State::Cancelled);
let fut = CancellationSafeFuture::new(
async move {
done_captured.wait().await;
},
runtime_clone,
);

drop(fut);

tokio::time::timeout(Duration::from_secs(5), done.wait())
.await
.unwrap();
});
}
}

0 comments on commit 4fba378

Please sign in to comment.