Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/kyron/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,15 @@ rust_binary(
visibility = ["//visibility:public"],
deps = _EXAMPLE_DEPS,
)

rust_binary(
name = "safety_task",
srcs = [
"examples/safety_task.rs",
],
proc_macro_deps = [
"//src/kyron-macros:runtime_macros",
],
visibility = ["//visibility:public"],
deps = _EXAMPLE_DEPS,
)
18 changes: 12 additions & 6 deletions src/kyron/src/safety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ pub fn ensure_safety_enabled() {
///
/// # Safety
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in either `SafetyWorker` or regular worker.
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
///
pub fn spawn<F, T, E>(future: F) -> JoinHandle<F::Output>
where
Expand All @@ -64,7 +65,8 @@ where
///
/// # Safety
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in either `SafetyWorker` or regular worker.
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
///
pub fn spawn_from_boxed<T, E>(boxed: FutureBox<SafetyResult<T, E>>) -> JoinHandle<SafetyResult<T, E>>
where
Expand All @@ -88,7 +90,8 @@ where
///
/// # Safety
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in either `SafetyWorker` or regular worker.
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
///
pub fn spawn_from_reusable<T, E>(reusable: ReusableBoxFuture<SafetyResult<T, E>>) -> JoinHandle<SafetyResult<T, E>>
where
Expand All @@ -113,7 +116,8 @@ where
///
/// # Safety
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in either `SafetyWorker` or regular worker.
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
///
pub fn spawn_on_dedicated<F, T, E>(future: F, worker_id: UniqueWorkerId) -> JoinHandle<F::Output>
where
Expand All @@ -131,7 +135,8 @@ where
///
/// # Safety
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in either `SafetyWorker` or regular worker.
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
///
pub fn spawn_from_boxed_on_dedicated<T, E>(boxed: FutureBox<SafetyResult<T, E>>, worker_id: UniqueWorkerId) -> JoinHandle<SafetyResult<T, E>>
where
Expand All @@ -155,7 +160,8 @@ where
///
/// # Safety
/// This API is intended to provide a way to ensure that user can react on errors within a `task` independent of other workers state (ie. being busy looping etc).
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in `SafetyWorker`.
/// This means that if the `task` (aka provided Future) will return Err(_), then the task that is awaiting on JoinHandle will be woken up in either `SafetyWorker` or regular worker.
/// Assumption of Use is that the task that is running on SafetyWorker never blocks.
///
pub fn spawn_from_reusable_on_dedicated<T, E>(
reusable: ReusableBoxFuture<SafetyResult<T, E>>,
Expand Down
21 changes: 21 additions & 0 deletions src/kyron/src/scheduler/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ pub(crate) fn ctx_get_drivers() -> Drivers {
.unwrap()
}

#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
///
/// Sets currently running `task`
///
Expand All @@ -494,6 +495,7 @@ pub(super) fn ctx_set_running_task(task: TaskRef) {
});
}

#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
///
/// Clears currently running `task`
///
Expand All @@ -505,6 +507,7 @@ pub(super) fn ctx_unset_running_task() {
.map_err(|_| {});
}

#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
///
/// Gets currently running `task id`
///
Expand All @@ -523,6 +526,24 @@ pub(crate) fn ctx_get_running_task_id() -> Option<TaskId> {
})
}

#[allow(dead_code)] // Mock function is used instead of this if mock runtime feature is enabled
///
/// Returns `true` if the running task resulted in safety error
///
pub(crate) fn ctx_get_task_safety_error() -> bool {
CTX.try_with(|ctx| {
// This funcation can be called from a thread outside of Kyron runtime through wake()/wake_by_ref(), so we need to check for ctx presence
if let Some(cx) = ctx.borrow().as_ref() {
cx.running_task.borrow().as_ref().is_some_and(|task| task.get_task_safety_error())
} else {
false
}
})
.unwrap_or_else(|e| {
panic!("Something is really bad here, error {}!", e);
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 2 additions & 1 deletion src/kyron/src/scheduler/execution_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,8 @@ mod tests {
#[test]
#[cfg(not(miri))] // Provenance issues
fn create_engine_with_worker_and_verify_ids() {
use crate::scheduler::context::{ctx_get_running_task_id, ctx_get_worker_id};
use crate::scheduler::context::ctx_get_worker_id;
use crate::testing::mock_context::ctx_get_running_task_id;
let mut engine = ExecutionEngineBuilder::new().workers(1).task_queue_size(16).set_engine_id(1).build();
let result: Result<bool, ()> = engine
.run_in_engine(async move {
Expand Down
71 changes: 44 additions & 27 deletions src/kyron/src/scheduler/join_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
///
fn poll(self: ::core::pin::Pin<&mut Self>, cx: &mut ::core::task::Context<'_>) -> Poll<Self::Output> {
let res: FutureInternalReturn<JoinResult<T>> = match self.state {
FutureState::New => {
FutureState::New | FutureState::Polled => {
let waker = cx.waker();

// Set the waker, return values tells what have happen and took care about correct synchronization
Expand All @@ -80,28 +80,14 @@ impl<T: Send + 'static> Future for JoinHandle<T> {

match ret {
Ok(v) => FutureInternalReturn::ready(Ok(v)),
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
Err(e) => {
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
}
}
}
}
FutureState::Polled => {
// Safety belows forms AqrRel so waker is really written before we do marking
let mut ret: Result<T, CommonErrors> = Err(CommonErrors::NoData);
let ret_as_ptr = &mut ret as *mut _;
self.for_task.get_return_val(ret_as_ptr as *mut u8);

match ret {
Ok(v) => FutureInternalReturn::ready(Ok(v)),
Err(CommonErrors::NoData) => FutureInternalReturn::polled(),
Err(CommonErrors::OperationAborted) => FutureInternalReturn::ready(Err(CommonErrors::OperationAborted)),
Err(e) => {
not_recoverable_error!(with e, "There has been an error in a task that is not recoverable ({})!");
}
}
}
FutureState::Finished => {
not_recoverable_error!("Future polled after it finished!");
}
Expand Down Expand Up @@ -262,6 +248,41 @@ mod tests {
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
}
}

#[test]
fn test_join_handle_waker_is_set_in_polled_state_also() {
let scheduler = create_mock_scheduler();

{
// Data is present before first poll of join handle
let worker_id = create_mock_worker_id(0, 1);
let task = ArcInternal::new(AsyncTask::new(box_future(test_function::<u32>()), &worker_id, scheduler.clone()));

let handle = JoinHandle::<u32>::new(TaskRef::new(task.clone()));

let mut poller = TestingFuturePoller::new(handle);

let waker_mock1 = TrackableWaker::new();
let waker1 = waker_mock1.get_waker();

let waker_mock2 = TrackableWaker::new();
let waker2 = waker_mock2.get_waker();

let _ = poller.poll_with_waker(&waker1);
// Now in polled state, poll again with waker2
let _ = poller.poll_with_waker(&waker2);
{
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
task.poll(&mut cx); // task done
}

assert!(!waker_mock1.was_waked());
// this should be TRUE
assert!(waker_mock2.was_waked());
assert_eq!(poller.poll(), ::core::task::Poll::Ready(Ok(0)));
}
}
}

#[cfg(test)]
Expand All @@ -284,8 +305,9 @@ mod tests {

#[test]
fn test_join_handler_mt_get_result() {
let builder = Builder::new();

let mut builder = Builder::new();
// Limit preemption to avoid loom error "Model exceeded maximum number of branches."
builder.preemption_bound = Some(4);
builder.check(|| {
let scheduler = create_mock_scheduler();

Expand All @@ -307,22 +329,17 @@ mod tests {

let waker_mock = TrackableWaker::new();
let waker = waker_mock.get_waker();
let mut was_pending = false;

loop {
match poller.poll_with_waker(&waker) {
Poll::Ready(v) => {
assert_eq!(v, Ok(1234));

if was_pending {
assert!(waker_mock.was_waked());
}
// Note:
// Cannot check whether the waker was woken or not since the waker is set in the join handle poll every time if task is not yet done.
// So depending on the interleaving, the task may finish before the waker is set.

break;
}
Poll::Pending => {
was_pending = true;
}
Poll::Pending => {}
}
loom::hint::spin_loop();
}
Expand Down
8 changes: 3 additions & 5 deletions src/kyron/src/scheduler/safety_waker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,9 @@ static VTABLE: RawWakerVTable = RawWakerVTable::new(clone_waker, wake, wake_by_r
///
/// Waker will store internally a pointer to the ref counted Task.
///
pub(crate) unsafe fn create_safety_waker(waker: Waker) -> Waker {
let raw_waker = RawWaker::new(waker.data(), &VTABLE);

// Forget original as we took over the ownership, so ref count
::core::mem::forget(waker);
pub(crate) fn create_safety_waker(ptr: TaskRef) -> Waker {
let ptr = TaskRef::into_raw(ptr); // Extracts the pointer from TaskRef not decreasing it's reference count. Since we have a clone here, ref cnt was already increased
let raw_waker = RawWaker::new(ptr as *const (), &VTABLE);

// Convert RawWaker to Waker
unsafe { Waker::from_raw(raw_waker) }
Expand Down
Loading
Loading