Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -850,8 +850,8 @@ activities no special worker parameters are needed.

Cancellation for asynchronous activities is done via
[`asyncio.Task.cancel`](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel). This means that
`asyncio.CancelledError` will be raised (and can be caught, but it is not recommended). An activity must heartbeat to
receive cancellation and there are other ways to be notified about cancellation (see "Activity Context" and
`asyncio.CancelledError` will be raised (and can be caught, but it is not recommended). A non-local activity must
heartbeat to receive cancellation and there are other ways to be notified about cancellation (see "Activity Context" and
"Heartbeating and Cancellation" later).

##### Synchronous Activities
Expand All @@ -860,10 +860,10 @@ Synchronous activities, i.e. functions that do not have `async def`, can be used
`activity_executor` worker parameter must be set with a `concurrent.futures.Executor` instance to use for executing the
activities.

Cancellation for synchronous activities is done in the background and the activity must choose to listen for it and
react appropriately. If after cancellation is obtained an unwrapped `temporalio.exceptions.CancelledError` is raised,
the activity will be marked cancelled. An activity must heartbeat to receive cancellation and there are other ways to be
notified about cancellation (see "Activity Context" and "Heartbeating and Cancellation" later).
All long running, non-local activities should heartbeat so they can be cancelled. Cancellation in threaded activities
throws but multiprocess/other activities does not. The sections below on each synchronous type explain further. There
are also calls on the context that can check for cancellation. For more information, see "Activity Context" and
"Heartbeating and Cancellation" sections later.

Note, all calls from an activity to functions in the `temporalio.activity` package are powered by
[contextvars](https://docs.python.org/3/library/contextvars.html). Therefore, new threads starting _inside_ of
Expand All @@ -876,6 +876,15 @@ If `activity_executor` is set to an instance of `concurrent.futures.ThreadPoolEx
are considered multithreaded activities. Besides `activity_executor`, no other worker parameters are required for
synchronous multithreaded activities.

By default, cancellation of a synchronous multithreaded activity is done via a `temporalio.exceptions.CancelledError`
thrown into the activity thread. Activities that do not wish to have cancellation thrown can set
`no_thread_cancel_exception=True` in the `@activity.defn` decorator.

Code that wishes to be temporarily shielded from the cancellation exception can run inside
`with activity.shield_thread_cancel_exception():`. But once the last nested form of that block is finished, even if
there is a return statement within, it will throw the cancellation if there was one. A `try` +
`except temporalio.exceptions.CancelledError` would have to surround the `with` to handle the cancellation explicitly.

###### Synchronous Multiprocess/Other Activities

If `activity_executor` is set to an instance of `concurrent.futures.Executor` that is _not_
Expand All @@ -901,6 +910,8 @@ calls in the `temporalio.activity` package make use of it. Specifically:
* `is_cancelled()` - Whether a cancellation has been requested on this activity
* `wait_for_cancelled()` - `async` call to wait for cancellation request
* `wait_for_cancelled_sync(timeout)` - Synchronous blocking call to wait for cancellation request
* `shield_thread_cancel_exception()` - Context manager for use in `with` clauses by synchronous multithreaded activities
to prevent cancel exception from being thrown during the block of code
* `is_worker_shutdown()` - Whether the worker has started graceful shutdown
* `wait_for_worker_shutdown()` - `async` call to wait for start of graceful worker shutdown
* `wait_for_worker_shutdown_sync(timeout)` - Synchronous blocking call to wait for start of graceful worker shutdown
Expand All @@ -912,15 +923,17 @@ occurs. Synchronous activities cannot call any of the `async` functions.

##### Heartbeating and Cancellation

In order for an activity to be notified of cancellation requests, they must invoke `temporalio.activity.heartbeat()`.
It is strongly recommended that all but the fastest executing activities call this function regularly. "Types of
Activities" has specifics on cancellation for asynchronous and synchronous activities.
In order for a non-local activity to be notified of cancellation requests, it must invoke
`temporalio.activity.heartbeat()`. It is strongly recommended that all but the fastest executing activities call this
function regularly. "Types of Activities" has specifics on cancellation for asynchronous and synchronous activities.

In addition to obtaining cancellation information, heartbeats also support detail data that is persisted on the server
for retrieval during activity retry. If an activity calls `temporalio.activity.heartbeat(123, 456)` and then fails and
is retried, `temporalio.activity.info().heartbeat_details` will return an iterable containing `123` and `456` on the
next run.

Heartbeating has no effect on local activities.

##### Worker Shutdown

An activity can react to a worker shutdown. Using `is_worker_shutdown` or one of the `wait_for_worker_shutdown`
Expand Down
72 changes: 58 additions & 14 deletions temporalio/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import inspect
import logging
import threading
from contextlib import AbstractContextManager, contextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import partial
from typing import (
Any,
Callable,
Iterator,
List,
Mapping,
MutableMapping,
Expand All @@ -34,7 +35,6 @@
)

import temporalio.common
import temporalio.exceptions

from .types import CallableType

Expand All @@ -45,32 +45,41 @@ def defn(fn: CallableType) -> CallableType:


@overload
def defn(*, name: str) -> Callable[[CallableType], CallableType]:
def defn(
*, name: Optional[str] = None, no_thread_cancel_exception: bool = False
) -> Callable[[CallableType], CallableType]:
...


def defn(fn: Optional[CallableType] = None, *, name: Optional[str] = None):
def defn(
fn: Optional[CallableType] = None,
*,
name: Optional[str] = None,
no_thread_cancel_exception: bool = False,
):
"""Decorator for activity functions.

Activities can be async or non-async.

Args:
fn: The function to decorate.
name: Name to use for the activity. Defaults to function ``__name__``.
no_thread_cancel_exception: If set to true, an exception will not be
raised in synchronous, threaded activities upon cancellation.
"""

def with_name(name: str, fn: CallableType) -> CallableType:
def decorator(fn: CallableType) -> CallableType:
# This performs validation
_Definition._apply_to_callable(fn, name)
_Definition._apply_to_callable(
fn,
activity_name=name or fn.__name__,
no_thread_cancel_exception=no_thread_cancel_exception,
)
return fn

# If name option is available, return decorator function
if name is not None:
return partial(with_name, name)
if fn is None:
raise RuntimeError("Cannot invoke defn without function or name")
# Otherwise just run decorator function
return with_name(fn.__name__, fn)
if fn is not None:
return decorator(fn)
return decorator


@dataclass(frozen=True)
Expand Down Expand Up @@ -122,6 +131,7 @@ class _Context:
heartbeat: Optional[Callable[..., None]]
cancelled_event: _CompositeEvent
worker_shutdown_event: _CompositeEvent
shield_thread_cancel_exception: Optional[Callable[[], AbstractContextManager]]
_logger_details: Optional[Mapping[str, Any]] = None

@staticmethod
Expand Down Expand Up @@ -221,6 +231,36 @@ def is_cancelled() -> bool:
return _Context.current().cancelled_event.is_set()


@contextmanager
def shield_thread_cancel_exception() -> Iterator[None]:
"""Context manager for synchronous multithreaded activities to delay
cancellation exceptions.

By default, synchronous multithreaded activities have an exception thrown
inside when cancellation occurs. Code within a "with" block of this context
manager will delay that throwing until the end. Even if the block returns a
value or throws its own exception, if a cancellation exception is pending,
it is thrown instead. Therefore users are encouraged to not throw out of
this block and can surround this with a try/except if they wish to catch a
cancellation.

This properly supports nested calls and will only throw after the last one.

This just runs the blocks with no extra effects for async activities or
synchronous multiprocess/other activities.

Raises:
temporalio.exceptions.CancelledError: If a cancellation occurs anytime
during this block and this is not nested in another shield block.
"""
shield_context = _Context.current().shield_thread_cancel_exception
if not shield_context:
yield None
else:
with shield_context():
yield None


async def wait_for_cancelled() -> None:
"""Asynchronously wait for this activity to get a cancellation request.

Expand Down Expand Up @@ -353,6 +393,7 @@ class _Definition:
name: str
fn: Callable
is_async: bool
no_thread_cancel_exception: bool
# Types loaded on post init if both are None
arg_types: Optional[List[Type]] = None
ret_type: Optional[Type] = None
Expand All @@ -379,7 +420,9 @@ def must_from_callable(fn: Callable) -> _Definition:
)

@staticmethod
def _apply_to_callable(fn: Callable, activity_name: str) -> None:
def _apply_to_callable(
fn: Callable, *, activity_name: str, no_thread_cancel_exception: bool = False
) -> None:
# Validate the activity
if hasattr(fn, "__temporal_activity_definition"):
raise ValueError("Function already contains activity definition")
Expand All @@ -399,6 +442,7 @@ def _apply_to_callable(fn: Callable, activity_name: str) -> None:
# iscoroutinefunction does not return true for async __call__
# TODO(cretz): Why can't MyPy handle this?
is_async=inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__), # type: ignore
no_thread_cancel_exception=no_thread_cancel_exception,
),
)

Expand Down
9 changes: 8 additions & 1 deletion temporalio/bridge/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import ClassVar, Mapping, Optional
from typing import ClassVar, Mapping, Optional, Type

import temporalio.bridge.temporal_sdk_bridge

Expand Down Expand Up @@ -54,6 +54,13 @@ def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None:
raise RuntimeError("Runtime default already set")
_default_runtime = runtime

@staticmethod
def _raise_in_thread(thread_id: int, exc_type: Type[BaseException]) -> bool:
"""Internal helper for raising an exception in thread."""
return temporalio.bridge.temporal_sdk_bridge.raise_in_thread(
thread_id, exc_type
)

def __init__(self, *, telemetry: TelemetryConfig) -> None:
"""Create a default runtime with the given telemetry config.

Expand Down
6 changes: 6 additions & 0 deletions temporalio/bridge/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fn temporal_sdk_bridge(py: Python, m: &PyModule) -> PyResult<()> {
// Runtime stuff
m.add_class::<runtime::RuntimeRef>()?;
m.add_function(wrap_pyfunction!(init_runtime, m)?)?;
m.add_function(wrap_pyfunction!(raise_in_thread, m)?)?;

// Testing stuff
m.add_class::<testing::EphemeralServerRef>()?;
Expand Down Expand Up @@ -48,6 +49,11 @@ fn init_runtime(telemetry_config: runtime::TelemetryConfig) -> PyResult<runtime:
runtime::init_runtime(telemetry_config)
}

#[pyfunction]
fn raise_in_thread<'a>(py: Python<'a>, thread_id: std::os::raw::c_long, exc: &PyAny) -> bool {
runtime::raise_in_thread(py, thread_id, exc)
}

#[pyfunction]
fn start_temporalite<'a>(
py: Python<'a>,
Expand Down
5 changes: 5 additions & 0 deletions temporalio/bridge/src/runtime.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::AsPyPointer;
use std::collections::HashMap;
use std::future::Future;
use std::net::SocketAddr;
Expand Down Expand Up @@ -75,6 +76,10 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult<RuntimeRef> {
})
}

pub fn raise_in_thread<'a>(_py: Python<'a>, thread_id: std::os::raw::c_long, exc: &PyAny) -> bool {
unsafe { pyo3::ffi::PyThreadState_SetAsyncExc(thread_id, exc.as_ptr()) == 1 }
}

impl Runtime {
pub fn future_into_py<'a, F, T>(&self, py: Python<'a>, fut: F) -> PyResult<&'a PyAny>
where
Expand Down
2 changes: 1 addition & 1 deletion temporalio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def non_retryable(self) -> bool:
class CancelledError(FailureError):
"""Error raised on workflow/activity cancellation."""

def __init__(self, message: str, *details: Any) -> None:
def __init__(self, message: str = "Cancelled", *details: Any) -> None:
"""Initialize a cancelled error."""
super().__init__(message)
self._details = details
Expand Down
25 changes: 25 additions & 0 deletions temporalio/testing/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from typing_extensions import ParamSpec

import temporalio.activity
import temporalio.exceptions
import temporalio.worker._activity

_Params = ParamSpec("_Params")
_Return = TypeVar("_Return")
Expand Down Expand Up @@ -111,6 +113,17 @@ def __init__(
self.env = env
self.fn = fn
self.is_async = inspect.iscoroutinefunction(fn)
self.cancel_thread_raiser: Optional[
temporalio.worker._activity._ThreadExceptionRaiser
] = None
if not self.is_async:
# If there is a definition and they disable thread raising, don't
# set
defn = temporalio.activity._Definition.from_callable(fn)
if not defn or not defn.no_thread_cancel_exception:
self.cancel_thread_raiser = (
temporalio.worker._activity._ThreadExceptionRaiser()
)
# Create context
self.context = temporalio.activity._Context(
info=lambda: env.info,
Expand All @@ -123,10 +136,18 @@ def __init__(
thread_event=threading.Event(),
async_event=asyncio.Event() if self.is_async else None,
),
shield_thread_cancel_exception=None
if not self.cancel_thread_raiser
else self.cancel_thread_raiser.shielded,
)
self.task: Optional[asyncio.Task] = None

def run(self, *args, **kwargs) -> Any:
if self.cancel_thread_raiser:
thread_id = threading.current_thread().ident
if thread_id is not None:
self.cancel_thread_raiser.set_thread_id(thread_id)

@contextmanager
def activity_context():
# Set cancelled and shutdown if already so in environment
Expand Down Expand Up @@ -163,6 +184,10 @@ async def run_async():
def cancel(self) -> None:
if not self.context.cancelled_event.is_set():
self.context.cancelled_event.set()
if self.cancel_thread_raiser:
self.cancel_thread_raiser.raise_in_thread(
temporalio.exceptions.CancelledError
)
if self.task and not self.task.done():
self.task.cancel()

Expand Down
Loading