From 4fba37881ecd9120ab652905b5323ca7f0288244 Mon Sep 17 00:00:00 2001 From: "xikai.wxk" Date: Thu, 13 Jul 2023 09:53:33 +0800 Subject: [PATCH] feat: support CancellationSafeFuture --- Cargo.lock | 1 + common_util/Cargo.toml | 1 + common_util/src/future_cancel.rs | 199 +++++++++++++++++++++---------- 3 files changed, 140 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8cd113cde2..3fe0415f2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1319,6 +1319,7 @@ dependencies = [ "common_types", "crossbeam-utils 0.8.15", "env_logger", + "futures 0.3.28", "gag", "hex", "lazy_static", diff --git a/common_util/Cargo.toml b/common_util/Cargo.toml index bc2032af19..bef0814525 100644 --- a/common_util/Cargo.toml +++ b/common_util/Cargo.toml @@ -22,6 +22,7 @@ chrono = { workspace = true } common_types = { workspace = true, features = ["test"] } crossbeam-utils = "0.8.7" env_logger = { workspace = true, optional = true } +futures = { workspace = true } hex = "0.4.3" lazy_static = { workspace = true } libc = "0.2" diff --git a/common_util/src/future_cancel.rs b/common_util/src/future_cancel.rs index d7d504b259..93b8eaf9bc 100644 --- a/common_util/src/future_cancel.rs +++ b/common_util/src/future_cancel.rs @@ -1,35 +1,97 @@ // Copyright 2022-2023 CeresDB Project Authors. Licensed under Apache-2.0. -/// A guard to detect whether a future is cancelled. -pub struct FutureCancelGuard { - cancelled: bool, - on_cancel: F, -} +//! A future wrapper to ensure the wrapped future must be polled. +//! +//! This implementation is forked from: https://github.com/influxdata/influxdb_iox/blob/885767aa0a6010de592bde9992945b01389eb994/cache_system/src/cancellation_safe_future.rs +//! Here is the copyright and license disclaimer: +//! Copyright (c) 2020 InfluxData. Licensed under Apache-2.0. + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::future::BoxFuture; + +use crate::runtime::RuntimeRef; -impl FutureCancelGuard { - /// Create a guard to assume the future will be cancelled at the following - /// await point. +/// Wrapper around a future that cannot be cancelled. +/// +/// When the future is dropped/cancelled, we'll spawn a tokio task to _rescue_ +/// it. +pub struct CancellationSafeFuture +where + F: Future + Send + 'static, + F::Output: Send, +{ + /// Mark if the inner future finished. If not, we must spawn a helper task + /// on drop. + done: bool, + + /// Inner future. /// - /// If the future is really cancelled, the provided `on_cancel` callback - /// will be executed. - pub fn new_cancelled(on_cancel: F) -> Self { - Self { - cancelled: true, - on_cancel, + /// Wrapped in an `Option` so we can extract it during drop. Inside that + /// option however we also need a pinned box because once this wrapper + /// is polled, it will be pinned in memory -- even during drop. Now the + /// inner future does not necessarily implement `Unpin`, so we need a + /// heap allocation to pin it in memory even when we move it out of this + /// option. + inner: Option>, + + /// The runtime to execute the dropped future. + runtime: RuntimeRef, +} + +impl Drop for CancellationSafeFuture +where + F: Future + Send + 'static, + F::Output: Send, +{ + fn drop(&mut self) { + if !self.done { + let inner = self.inner.take().unwrap(); + let handle = self.runtime.spawn(async move { inner.await }); + drop(handle); } } +} - /// Set the inner state is uncancelled to ensure the `on_cancel` callback - /// won't be executed. - pub fn uncancel(&mut self) { - self.cancelled = false; +impl CancellationSafeFuture +where + F: Future + Send, + F::Output: Send, +{ + /// Create new future that is protected from cancellation. + /// + /// If [`CancellationSafeFuture`] is cancelled (i.e. dropped) and there is + /// still some external receiver of the state left, than we will drive + /// the payload (`f`) to completion. Otherwise `f` will be cancelled. + pub fn new(fut: F, runtime: RuntimeRef) -> Self { + Self { + done: false, + inner: Some(Box::pin(fut)), + runtime, + } } } -impl Drop for FutureCancelGuard { - fn drop(&mut self) { - if self.cancelled { - (self.on_cancel)(); +impl Future for CancellationSafeFuture +where + F: Future + Send, + F::Output: Send, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + assert!(!self.done, "Polling future that already returned"); + + match self.inner.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready(res) => { + self.done = true; + Poll::Ready(res) + } + Poll::Pending => Poll::Pending, } } } @@ -37,55 +99,70 @@ impl Drop for FutureCancelGuard { #[cfg(test)] mod tests { use std::{ - sync::{Arc, Mutex}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, time::Duration, }; + use tokio::sync::Barrier; + use super::*; + use crate::runtime::Builder; - #[derive(Debug, PartialEq, Eq)] - enum State { - Init, - Processing, - Finished, - Cancelled, + fn rt() -> RuntimeRef { + let rt = Builder::default() + .worker_threads(2) + .thread_name("test_spawn_join") + .enable_all() + .build(); + assert!(rt.is_ok()); + Arc::new(rt.unwrap()) } - #[tokio::test] - async fn test_future_cancel() { - let state = Arc::new(Mutex::new(State::Init)); - let lock = Arc::new(tokio::sync::Mutex::new(())); + #[test] + fn test_happy_path() { + let runtime = rt(); + let runtime_clone = runtime.clone(); + runtime.block_on(async move { + let done = Arc::new(AtomicBool::new(false)); + let done_captured = Arc::clone(&done); - // Hold the lock before the task is spawned. - let lock_guard = lock.lock().await; + let fut = CancellationSafeFuture::new( + async move { + done_captured.store(true, Ordering::SeqCst); + }, + runtime_clone, + ); - let cloned_lock = lock.clone(); - let cloned_state = state.clone(); - let handle = tokio::spawn(async move { - { - let mut state = cloned_state.lock().unwrap(); - *state = State::Processing; - } - let mut cancel_guard = FutureCancelGuard::new_cancelled(|| { - let mut state = cloned_state.lock().unwrap(); - *state = State::Cancelled; - }); - - // It will be cancelled at this await point. - let _lock_guard = cloned_lock.lock().await; - cancel_guard.uncancel(); - let mut state = cloned_state.lock().unwrap(); - *state = State::Finished; - }); + fut.await; - // Ensure the spawned task is started. - tokio::time::sleep(Duration::from_millis(50)).await; - handle.abort(); - // Ensure the future cancel guard is dropped. - tokio::time::sleep(Duration::from_millis(0)).await; - drop(lock_guard); + assert!(done.load(Ordering::SeqCst)); + }) + } + + #[test] + fn test_cancel_future() { + let runtime = rt(); + let runtime_clone = runtime.clone(); + + runtime.block_on(async move { + let done = Arc::new(Barrier::new(2)); + let done_captured = Arc::clone(&done); - let state = state.lock().unwrap(); - assert_eq!(*state, State::Cancelled); + let fut = CancellationSafeFuture::new( + async move { + done_captured.wait().await; + }, + runtime_clone, + ); + + drop(fut); + + tokio::time::timeout(Duration::from_secs(5), done.wait()) + .await + .unwrap(); + }); } }