Skip to content

Commit

Permalink
feat: support uninterruptible session for asynchronous operations (#52)
Browse files Browse the repository at this point in the history
This way, handing over onstack buffers to asynchronous party will free
from invalid memory access. It is also preferred to perform actively
cancellation to cancel asynchronous operation in task completion.

Closes #45.
  • Loading branch information
kezhuw authored Apr 17, 2024
1 parent 9d41ce5 commit 629ebef
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 42 deletions.
21 changes: 20 additions & 1 deletion src/coroutine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ thread_local! {
static THREAD_CONTEXT: UnsafeCell<Context> = UnsafeCell::new(Context::empty());
}

pub(crate) fn try_current() -> Option<ptr::NonNull<Coroutine>> {
COROUTINE.with(|p| p.get())
}

pub(crate) fn current() -> ptr::NonNull<Coroutine> {
COROUTINE.with(|p| p.get()).expect("no running coroutine")
}
Expand Down Expand Up @@ -73,11 +77,22 @@ impl ThisThread {
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub(crate) enum Status {
Running,
Aborting,
Cancelling,
Completed,
}

impl Status {
pub fn into_abort(self) -> Self {
match self {
Self::Running | Self::Aborting => Self::Aborting,
_ => self,
}
}
}

pub(crate) struct Coroutine {
status: Status,
pub status: Status,
context: Box<Context>,
f: Option<Box<dyn FnOnce()>>,
}
Expand Down Expand Up @@ -121,6 +136,10 @@ impl Coroutine {
pub fn suspend(&mut self) {
ThisThread::suspend(&mut self.context);
}

pub fn is_cancelling(&self) -> bool {
self.status == Status::Cancelling
}
}

/// Spawns a cooperative task and returns a [JoinHandle] for it.
Expand Down
71 changes: 63 additions & 8 deletions src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::sync::{Arc, Mutex};
use std::{mem, ptr};

use hashbrown::HashMap;
use ignore_result::Ignore;
use static_assertions::assert_impl_all;

pub use self::session::{session, Session, SessionWaker};
Expand Down Expand Up @@ -124,7 +125,7 @@ impl<T: Send + 'static> JoinHandle<T> {
/// Waits for associated task to finish and returns its result.
pub fn join(self) -> Result<T, JoinError> {
let joint = unsafe { self.session.into_joint() };
joint.join().map_err(JoinError::new)
joint.join(None::<fn()>).map_err(JoinError::new)
}
}

Expand All @@ -133,6 +134,30 @@ pub(crate) trait Yielding {
fn interrupt(&self, reason: &'static str) -> bool;
}

static INTERRUPTIBLE: Interruptible = Interruptible {};
static UNINTERRUPTIBLE: Uninterruptible = Uninterruptible {};

struct Interruptible {}
struct Uninterruptible {}

impl Yielding for Interruptible {
fn interrupt(&self, _reason: &'static str) -> bool {
true
}
}

impl Yielding for Uninterruptible {
fn interrupt(&self, _reason: &'static str) -> bool {
false
}
}

pub(crate) enum Interruption<'a, T: FnOnce() = fn()> {
Cancellation(T),
Interruptible(&'a dyn Yielding),
Uninterruptible,
}

#[derive(PartialEq, Eq, Clone, Copy, Debug)]
enum Status {
Running,
Expand Down Expand Up @@ -245,12 +270,19 @@ impl Task {
for co in self.yielding_coroutines.drain(..) {
self.running_coroutines.push_back(co);
}
for (co, yielding) in self.suspending_coroutines.drain() {
for (mut co, yielding) in self.suspending_coroutines.drain() {
yielding.interrupt(msg);
self.running_coroutines.push_back(co);
let co = unsafe { co.as_mut() };
co.status = co.status.into_abort();
}
for (co, _) in self.blocking_coroutines.drain_filter(|_, yielding| yielding.interrupt(msg)) {
for (mut co, _) in self
.blocking_coroutines
.drain_filter(|co, yielding| unsafe { !co.as_ref().is_cancelling() && yielding.interrupt(msg) })
{
self.running_coroutines.push_back(co);
let co = unsafe { co.as_mut() };
co.status = co.status.into_abort();
}
self.unblock(false);
}
Expand Down Expand Up @@ -321,15 +353,38 @@ impl Task {
co.suspend();
}

fn block(&mut self, mut co: ptr::NonNull<Coroutine>, yielding: &dyn Yielding) {
fn block<F: FnOnce()>(&mut self, mut co: ptr::NonNull<Coroutine>, interruption: Interruption<'_, F>) {
assert!(co == coroutine::current(), "Session.block: running coroutine changed");
let yielding = unsafe { std::mem::transmute::<&dyn Yielding, &'_ dyn Yielding>(yielding) };
self.blocking_coroutines.insert(co, yielding);
let co = unsafe { co.as_mut() };
co.suspend();
match interruption {
Interruption::Interruptible(yielding) => {
let yielding = unsafe { std::mem::transmute::<&dyn Yielding, &'_ dyn Yielding>(yielding) };
self.blocking_coroutines.insert(co, yielding);
let co = unsafe { co.as_mut() };
co.suspend();
},
Interruption::Uninterruptible => {
self.blocking_coroutines.insert(co, &UNINTERRUPTIBLE);
let co = unsafe { co.as_mut() };
co.suspend();
},
Interruption::Cancellation(cancellation) => {
self.blocking_coroutines.insert(co, &INTERRUPTIBLE);
let co = unsafe { co.as_mut() };
co.suspend();
if co.status == coroutine::Status::Aborting {
co.status = coroutine::Status::Cancelling;
std::panic::catch_unwind(AssertUnwindSafe(cancellation)).ignore();
co.status = coroutine::Status::Aborting;
// Cancellation completed, wakeup caller to check its completion.
}
},
}
}

pub fn yield_coroutine(&mut self, mut co: ptr::NonNull<Coroutine>) {
if self.status != Status::Running {
return;
}
self.yielding_coroutines.push(co);
let co = unsafe { co.as_mut() };
co.suspend();
Expand Down
Loading

0 comments on commit 629ebef

Please sign in to comment.