Skip to content

Activity worker: refactoring part 2 #899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 31 additions & 26 deletions temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,24 +138,21 @@ def __init__(
self._dynamic_activity = defn

async def run(self) -> None:
# Create a task that fails when we get a failure on the queue
async def raise_from_queue() -> NoReturn:
"""Continually poll for activity tasks and dispatch to handlers."""

async def raise_from_exception_queue() -> NoReturn:
raise await self._fail_worker_exception_queue.get()

exception_task = asyncio.create_task(raise_from_queue())
exception_task = asyncio.create_task(raise_from_exception_queue())

# Continually poll for activity work
while True:
try:
# Poll for a task
poll_task = asyncio.create_task(
self._bridge_worker().poll_activity_task()
)
await asyncio.wait(
[poll_task, exception_task], return_when=asyncio.FIRST_COMPLETED
) # type: ignore
# If exception for failing the worker happened, raise it.
# Otherwise, the poll succeeded.
)
if exception_task.done():
poll_task.cancel()
await exception_task
Expand All @@ -167,11 +164,14 @@ async def raise_from_queue() -> NoReturn:
# size of 1000 should be plenty for the heartbeat queue.
activity = _RunningActivity(pending_heartbeats=asyncio.Queue(1000))
activity.task = asyncio.create_task(
self._run_activity(task.task_token, task.start, activity)
self._handle_start_activity_task(
task.task_token, task.start, activity
)
)
self._running_activities[task.task_token] = activity
elif task.HasField("cancel"):
self._cancel(task.task_token, task.cancel)
# TODO(nexus-prerelease): does the task get removed from running_activities?
self._handle_cancel_activity_task(task.task_token, task.cancel)
else:
raise RuntimeError(f"Unrecognized activity task: {task}")
except temporalio.bridge.worker.PollShutdownError:
Expand Down Expand Up @@ -208,9 +208,10 @@ async def wait_all_completed(self) -> None:
if running_tasks:
await asyncio.gather(*running_tasks, return_exceptions=False)

def _cancel(
def _handle_cancel_activity_task(
self, task_token: bytes, cancel: temporalio.bridge.proto.activity_task.Cancel
) -> None:
"""Request cancellation of a running activity task."""
activity = self._running_activities.get(task_token)
if not activity:
warnings.warn(f"Cannot find activity to cancel for token {task_token!r}")
Expand Down Expand Up @@ -275,12 +276,17 @@ async def _heartbeat_async(
)
activity.cancel(cancelled_due_to_heartbeat_error=err)

async def _run_activity(
async def _handle_start_activity_task(
self,
task_token: bytes,
start: temporalio.bridge.proto.activity_task.Start,
running_activity: _RunningActivity,
) -> None:
"""Handle a start activity task.

Attempt to execute the user activity function and invoke the data converter on
the result. Handle errors and send the task completion.
"""
logger.debug("Running activity %s (token %s)", start.activity_type, task_token)
# We choose to surround interceptor creation and activity invocation in
# a try block so we can mark the workflow as failed on any error instead
Expand All @@ -289,7 +295,9 @@ async def _run_activity(
task_token=task_token
)
try:
await self._execute_activity(start, running_activity, completion)
result = await self._execute_activity(start, running_activity, task_token)
[payload] = await self._data_converter.encode([result])
completion.result.completed.result.CopyFrom(payload)
except BaseException as err:
try:
if isinstance(err, temporalio.activity._CompleteAsyncError):
Expand Down Expand Up @@ -318,7 +326,7 @@ async def _run_activity(
and running_activity.cancellation_details.details.paused
):
temporalio.activity.logger.warning(
f"Completing as failure due to unhandled cancel error produced by activity pause",
"Completing as failure due to unhandled cancel error produced by activity pause",
)
await self._data_converter.encode_failure(
temporalio.exceptions.ApplicationError(
Expand Down Expand Up @@ -402,8 +410,12 @@ async def _execute_activity(
self,
start: temporalio.bridge.proto.activity_task.Start,
running_activity: _RunningActivity,
completion: temporalio.bridge.proto.ActivityTaskCompletion,
):
task_token: bytes,
) -> Any:
"""Invoke the user's activity function.

Exceptions are handled by a caller of this function.
"""
# Find activity or fail
activity_def = self._activities.get(start.activity_type, self._dynamic_activity)
if not activity_def:
Expand Down Expand Up @@ -523,7 +535,7 @@ async def _execute_activity(
else None,
started_time=_proto_to_datetime(start.started_time),
task_queue=self._task_queue,
task_token=completion.task_token,
task_token=task_token,
workflow_id=start.workflow_execution.workflow_id,
workflow_namespace=start.workflow_namespace,
workflow_run_id=start.workflow_execution.run_id,
Expand Down Expand Up @@ -562,16 +574,9 @@ async def _execute_activity(
impl: ActivityInboundInterceptor = _ActivityInboundImpl(self, running_activity)
for interceptor in reversed(list(self._interceptors)):
impl = interceptor.intercept_activity(impl)
# Init

impl.init(_ActivityOutboundImpl(self, running_activity.info))
# Exec
result = await impl.execute_activity(input)
# Convert result even if none. Since Python essentially only
# supports single result types (even if they are tuples), we will do
# the same.
completion.result.completed.result.CopyFrom(
(await self._data_converter.encode([result]))[0]
)
return await impl.execute_activity(input)

def assert_activity_valid(self, activity) -> None:
if self._dynamic_activity:
Expand Down
Loading