Skip to content

Commit

Permalink
feat: complete task only after all auxiliary coroutines completed (#48)
Browse files Browse the repository at this point in the history
Closes #46.
  • Loading branch information
kezhuw authored Apr 16, 2024
1 parent 702cdfe commit 9d41ce5
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ more-asserts = "0.2.2"

[dev-dependencies]
pretty_assertions = "1.2.0"
scopeguard = "1.2.0"
test-case = "2.0.2"

[workspace]
Expand Down
16 changes: 11 additions & 5 deletions src/coroutine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,15 @@ impl ThisThread {
}
}

#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub(crate) enum Status {
Running,
Completed,
}

pub(crate) struct Coroutine {
status: Status,
context: Box<Context>,
completed: bool,
f: Option<Box<dyn FnOnce()>>,
}

Expand All @@ -83,8 +89,8 @@ impl Coroutine {
#[allow(invalid_value)]
let mut co = Box::new(Coroutine {
f: Option::Some(f),
status: Status::Running,
context: unsafe { mem::MaybeUninit::zeroed().assume_init() },
completed: false,
});
let entry = Entry { f: Self::main, arg: (co.as_mut() as *mut Coroutine) as *mut libc::c_void, stack_size };
mem::forget(mem::replace(&mut co.context, Context::new(&entry, None)));
Expand All @@ -94,7 +100,7 @@ impl Coroutine {
extern "C" fn main(arg: *mut libc::c_void) {
let co = unsafe { &mut *(arg as *mut Coroutine) };
co.run();
co.completed = true;
co.status = Status::Completed;
ThisThread::restore();
}

Expand All @@ -106,10 +112,10 @@ impl Coroutine {
/// Resumes coroutine.
///
/// Returns whether this coroutine should be resumed again.
pub fn resume(&mut self) -> bool {
pub fn resume(&mut self) -> Status {
let _scope = Scope::enter(self);
ThisThread::resume(&self.context);
!self.completed
self.status
}

pub fn suspend(&mut self) {
Expand Down
8 changes: 8 additions & 0 deletions src/coroutine/suspension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,14 @@ mod tests {
assert_eq!(resumption.joint.is_ready(), true);
}

#[crate::test(crate = "crate")]
#[should_panic(expected = "deadlock suspending coroutines")]
fn suspension_deadlock() {
let (suspension, resumption) = coroutine::suspension::<()>();
suspension.suspend();
drop(resumption);
}

#[crate::test(crate = "crate")]
fn join_handle_join() {
let join_handle = coroutine::spawn(|| 5);
Expand Down
55 changes: 42 additions & 13 deletions src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ impl Builder<'_> {
let handle = JoinHandle::new(session);
let main: FnMain = Box::new(move || {
let result = panic::catch_unwind(AssertUnwindSafe(f));
let task = unsafe { current().as_mut() };
let co = unsafe { coroutine::current().as_mut() };
task.status = Status::Completing;
co.suspend();
waker.set_result(result);
});
let task = Task::new(main, self.stack_size);
Expand Down Expand Up @@ -129,15 +133,20 @@ pub(crate) trait Yielding {
fn interrupt(&self, reason: &'static str) -> bool;
}

#[derive(PartialEq, Eq, Clone, Copy, Debug)]
enum Status {
Running,
Completing,
Completed,
}

pub(crate) struct Task {
id: u64,

// main is special as task will terminate after main terminated
main: ptr::NonNull<Coroutine>,

running: Cell<bool>,

aborting: bool,
status: Status,

yielding: bool,

Expand All @@ -155,6 +164,9 @@ pub(crate) struct Task {

// Unblocking by events from outside this task
unblocking_coroutines: Mutex<Vec<ptr::NonNull<Coroutine>>>,

// Guarded by above mutex.
running: Cell<bool>,
}

unsafe impl Sync for Task {}
Expand All @@ -178,14 +190,14 @@ impl Task {
Task {
id,
main: co,
running: Cell::new(true),
aborting: false,
status: Status::Running,
yielding: false,
running_coroutines: VecDeque::from([co]),
yielding_coroutines: Vec::with_capacity(5),
suspending_coroutines: HashMap::new(),
blocking_coroutines: HashMap::new(),
unblocking_coroutines: Mutex::new(Default::default()),
running: Cell::new(true),
}
}

Expand All @@ -197,14 +209,16 @@ impl Task {
drop(unsafe { Box::from_raw(co.as_ptr()) });
}

pub fn run_coroutine(&mut self, mut co: ptr::NonNull<Coroutine>) {
if unsafe { co.as_mut() }.resume() {
return;
fn run_coroutine(&mut self, mut co: ptr::NonNull<Coroutine>) -> coroutine::Status {
let status = unsafe { co.as_mut().resume() };
if status != coroutine::Status::Completed {
return status;
}
if co == self.main && !self.aborting {
self.abort("task main terminated");
if co == self.main {
self.status = Status::Completed;
}
Self::drop_coroutine(co);
status
}

pub fn unblock(&mut self, block: bool) -> bool {
Expand Down Expand Up @@ -242,7 +256,6 @@ impl Task {
}

pub fn abort(&mut self, msg: &'static str) {
self.aborting = true;
loop {
self.interrupt(msg);
if self.running_coroutines.is_empty() {
Expand All @@ -257,6 +270,9 @@ impl Task {
self.run_coroutine(co);
}
}
let status = self.run_coroutine(self.main);
assert_eq!(status, coroutine::Status::Completed);
assert_eq!(self.status, Status::Completed);
}

// Grab this task to runq. Return false if waker win.
Expand All @@ -269,10 +285,13 @@ impl Task {
let _scope = Scope::enter(self);
self.running_coroutines.extend(self.yielding_coroutines.drain(..));
self.unblock(false);
while !self.yielding && !self.running_coroutines.is_empty() {
while !self.yielding && self.status == Status::Running && !self.running_coroutines.is_empty() {
let co = unsafe { self.running_coroutines.pop_front().unwrap_unchecked() };
self.run_coroutine(co);
}
if self.status == Status::Completing {
self.abort("task main terminated");
}
self.yielding = false;
if !self.yielding_coroutines.is_empty() {
SchedFlow::Yield
Expand Down Expand Up @@ -336,7 +355,7 @@ impl Task {
}

pub fn spawn(&mut self, f: impl FnOnce() + 'static, stack_size: StackSize) {
if self.aborting {
if self.status != Status::Running {
return;
}
let f = Box::new(f);
Expand Down Expand Up @@ -409,13 +428,22 @@ mod tests {
fn main_coroutine() {
use std::cell::Cell;
use std::rc::Rc;
use std::sync::atomic::{AtomicI32, Ordering};
use std::time::Duration;

use scopeguard::defer;

let read = Arc::new(AtomicI32::new(0));
let write = read.clone();
let t = task::spawn(|| {
let cell = Rc::new(Cell::new(0));
coroutine::spawn({
let cell = cell.clone();
move || {
defer! {
std::thread::sleep(Duration::from_secs(2));
write.store(cell.get(), Ordering::Relaxed);
}
time::sleep(Duration::from_secs(20));
cell.set(10);
}
Expand All @@ -431,5 +459,6 @@ mod tests {
cell.get()
});
assert_eq!(t.join().unwrap(), 5);
assert_eq!(read.load(Ordering::Relaxed), 5);
}
}

0 comments on commit 9d41ce5

Please sign in to comment.