Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ license = "MIT"
license-files = ["LICENSE"]
keywords = ["temporal", "workflow"]
dependencies = [
"nexus-rpc==1.1.0",
"nexus-rpc==1.2.0",
"protobuf>=3.20,<7.0.0",
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
"types-protobuf>=3.20",
Expand Down
2 changes: 1 addition & 1 deletion temporalio/nexus/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def _start(
return WorkflowRunOperationHandler(_start)

method_name = get_callable_name(start)
nexusrpc.set_operation_definition(
nexusrpc.set_operation(
operation_handler_factory,
nexusrpc.Operation(
name=name or method_name,
Expand Down
19 changes: 0 additions & 19 deletions temporalio/nexus/_operation_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@
HandlerError,
HandlerErrorType,
InputT,
OperationInfo,
OutputT,
)
from nexusrpc.handler import (
CancelOperationContext,
FetchOperationInfoContext,
FetchOperationResultContext,
OperationHandler,
StartOperationContext,
StartOperationResultAsync,
Expand Down Expand Up @@ -81,22 +78,6 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
"""Cancel the operation, by cancelling the workflow."""
await _cancel_workflow(token)

async def fetch_info(
self, ctx: FetchOperationInfoContext, token: str
) -> OperationInfo:
"""Fetch operation info (not supported for Temporal Nexus operations)."""
raise NotImplementedError(
"Temporal Nexus operation handlers do not support fetching operation info."
)

async def fetch_result(
self, ctx: FetchOperationResultContext, token: str
) -> OutputT:
"""Fetch operation result (not supported for Temporal Nexus operations)."""
raise NotImplementedError(
"Temporal Nexus operation handlers do not support fetching the operation result."
)


async def _cancel_workflow(
token: str,
Expand Down
4 changes: 2 additions & 2 deletions temporalio/nexus/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ def get_operation_factory(

``obj`` should be a decorated operation start method.
"""
op_defn = nexusrpc.get_operation_definition(obj)
op_defn = nexusrpc.get_operation(obj)
if op_defn:
factory = obj
else:
if factory := getattr(obj, "__nexus_operation_factory__", None):
op_defn = nexusrpc.get_operation_definition(factory)
op_defn = nexusrpc.get_operation(factory)
if not isinstance(op_defn, nexusrpc.Operation):
return None, None
return factory, op_defn
Expand Down
96 changes: 88 additions & 8 deletions temporalio/worker/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import concurrent.futures
import json
import threading
from dataclasses import dataclass
from typing import (
Any,
Expand Down Expand Up @@ -32,7 +33,10 @@
import temporalio.common
import temporalio.converter
import temporalio.nexus
from temporalio.exceptions import ApplicationError, WorkflowAlreadyStartedError
from temporalio.exceptions import (
ApplicationError,
WorkflowAlreadyStartedError,
)
from temporalio.nexus import Info, logger
from temporalio.service import RPCError, RPCStatusCode

Expand All @@ -41,6 +45,16 @@
_TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure"


@dataclass
class _RunningNexusTask:
    """Bundle an in-flight Nexus task with the cancellation object given to it.

    The worker keeps one of these per task token so that a later cancel
    request can both record the reason and interrupt the asyncio task.
    """

    # asyncio task executing the handler coroutine for this Nexus task.
    task: asyncio.Task[Any]
    # Cancellation object that was passed to the handler for this task.
    cancellation: _NexusTaskCancellation

    def cancel(self, reason: str) -> None:
        """Mark the task as cancelled with ``reason`` and cancel the asyncio task.

        Args:
            reason: Human-readable cancellation reason forwarded to the
                cancellation object.
        """
        flag, running = self.cancellation, self.task
        # Set the flag before interrupting the task: the handlers read
        # ``is_cancelled()`` while handling the resulting ``CancelledError``.
        flag.cancel(reason)
        running.cancel()


class _NexusWorker:
def __init__(
self,
Expand All @@ -65,7 +79,7 @@ def __init__(
self._interceptors = interceptors
# TODO(nexus-preview): metric_meter
self._metric_meter = metric_meter
self._running_tasks: dict[bytes, asyncio.Task[Any]] = {}
self._running_tasks: dict[bytes, _RunningNexusTask] = {}
self._fail_worker_exception_queue: asyncio.Queue[Exception] = asyncio.Queue()

async def run(self) -> None:
Expand All @@ -90,21 +104,31 @@ async def raise_from_exception_queue() -> NoReturn:
if nexus_task.HasField("task"):
task = nexus_task.task
if task.request.HasField("start_operation"):
self._running_tasks[task.task_token] = asyncio.create_task(
task_cancellation = _NexusTaskCancellation()
start_op_task = asyncio.create_task(
self._handle_start_operation_task(
task.task_token,
task.request.start_operation,
dict(task.request.header),
task_cancellation,
)
)
self._running_tasks[task.task_token] = _RunningNexusTask(
start_op_task, task_cancellation
)
elif task.request.HasField("cancel_operation"):
self._running_tasks[task.task_token] = asyncio.create_task(
task_cancellation = _NexusTaskCancellation()
cancel_op_task = asyncio.create_task(
self._handle_cancel_operation_task(
task.task_token,
task.request.cancel_operation,
dict(task.request.header),
task_cancellation,
)
)
self._running_tasks[task.task_token] = _RunningNexusTask(
cancel_op_task, task_cancellation
)
else:
raise NotImplementedError(
f"Invalid Nexus task request: {task.request}"
Expand All @@ -113,8 +137,12 @@ async def raise_from_exception_queue() -> NoReturn:
if running_task := self._running_tasks.get(
nexus_task.cancel_task.task_token
):
# TODO(nexus-prerelease): when do we remove the entry from _running_operations?
running_task.cancel()
reason = (
temporalio.bridge.proto.nexus.NexusTaskCancelReason.Name(
nexus_task.cancel_task.reason
)
)
running_task.cancel(reason)
else:
logger.debug(
f"Received cancel_task but no running task exists for "
Expand Down Expand Up @@ -147,7 +175,10 @@ async def drain_poll_queue(self) -> None:
# Only call this after run()/drain_poll_queue() have returned. This will not
# raise an exception.
async def wait_all_completed(self) -> None:
await asyncio.gather(*self._running_tasks.values(), return_exceptions=True)
running_tasks = [
running_task.task for running_task in self._running_tasks.values()
]
await asyncio.gather(*running_tasks, return_exceptions=True)

# TODO(nexus-preview): stack trace pruning. See sdk-typescript NexusHandler.execute:
# "Any call up to this function and including this one will be trimmed out of stack traces."
Expand All @@ -157,6 +188,7 @@ async def _handle_cancel_operation_task(
task_token: bytes,
request: temporalio.api.nexus.v1.CancelOperationRequest,
headers: Mapping[str, str],
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
) -> None:
"""Handle a cancel operation task.

Expand All @@ -168,6 +200,7 @@ async def _handle_cancel_operation_task(
service=request.service,
operation=request.operation,
headers=headers,
task_cancellation=task_cancellation,
)
temporalio.nexus._operation_context._TemporalCancelOperationContext(
info=lambda: Info(task_queue=self._task_queue),
Expand All @@ -177,6 +210,11 @@ async def _handle_cancel_operation_task(
try:
try:
await self._handler.cancel_operation(ctx, request.operation_token)
except asyncio.CancelledError:
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
task_token=task_token,
ack_cancel=task_cancellation.is_cancelled(),
)
except BaseException as err:
logger.warning("Failed to execute Nexus cancel operation method")
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
Expand Down Expand Up @@ -209,6 +247,7 @@ async def _handle_start_operation_task(
task_token: bytes,
start_request: temporalio.api.nexus.v1.StartOperationRequest,
headers: Mapping[str, str],
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
) -> None:
"""Handle a start operation task.

Expand All @@ -217,7 +256,14 @@ async def _handle_start_operation_task(
"""
try:
try:
start_response = await self._start_operation(start_request, headers)
start_response = await self._start_operation(
start_request, headers, task_cancellation
)
except asyncio.CancelledError:
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
task_token=task_token,
ack_cancel=task_cancellation.is_cancelled(),
)
except BaseException as err:
logger.warning("Failed to execute Nexus start operation method")
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
Expand All @@ -226,6 +272,7 @@ async def _handle_start_operation_task(
_exception_to_handler_error(err)
),
)

if isinstance(err, concurrent.futures.BrokenExecutor):
self._fail_worker_exception_queue.put_nowait(err)
else:
Expand All @@ -235,6 +282,7 @@ async def _handle_start_operation_task(
start_operation=start_response
),
)

await self._bridge_worker().complete_nexus_task(completion)
except Exception:
logger.exception("Failed to send Nexus task completion")
Expand All @@ -250,6 +298,7 @@ async def _start_operation(
self,
start_request: temporalio.api.nexus.v1.StartOperationRequest,
headers: Mapping[str, str],
cancellation: nexusrpc.handler.OperationTaskCancellation,
) -> temporalio.api.nexus.v1.StartOperationResponse:
"""Invoke the Nexus handler's start_operation method and construct the StartOperationResponse.

Expand All @@ -268,6 +317,7 @@ async def _start_operation(
for link in start_request.links
],
callback_headers=dict(start_request.callback_header),
task_cancellation=cancellation,
)
temporalio.nexus._operation_context._TemporalStartOperationContext(
nexus_context=ctx,
Expand Down Expand Up @@ -517,3 +567,33 @@ def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError:
)
handler_err.__cause__ = err
return handler_err


class _NexusTaskCancellation(nexusrpc.handler.OperationTaskCancellation):
    """Cancellation signal for one Nexus task, usable from sync and async code.

    The cancelled state is mirrored in a ``threading.Event`` (for blocking
    waiters and ``is_cancelled``) and an ``asyncio.Event`` (for coroutines).
    Both are set together by :meth:`cancel`; the reason string is read and
    written under a lock.
    """

    def __init__(self):
        """Create a signal that is not yet cancelled and has no reason."""
        self._reason: Optional[str] = None
        self._lock = threading.Lock()
        self._thread_evt = threading.Event()
        self._async_evt = asyncio.Event()

    def is_cancelled(self) -> bool:
        """Return ``True`` once :meth:`cancel` has been called."""
        cancelled = self._thread_evt.is_set()
        return cancelled

    def cancellation_reason(self) -> Optional[str]:
        """Return the reason given to :meth:`cancel`, or ``None`` if not cancelled."""
        with self._lock:
            reason = self._reason
        return reason

    def wait_until_cancelled_sync(self, timeout: float | None = None) -> bool:
        """Block the calling thread until cancelled or until ``timeout`` elapses.

        Args:
            timeout: Maximum number of seconds to wait; ``None`` waits
                indefinitely.

        Returns:
            ``True`` if the signal is cancelled, ``False`` if the wait timed out.
        """
        cancelled = self._thread_evt.wait(timeout)
        return cancelled

    async def wait_until_cancelled(self) -> None:
        """Suspend the calling coroutine until the signal is cancelled."""
        await self._async_evt.wait()

    def cancel(self, reason: str) -> bool:
        """Cancel the signal, recording ``reason`` and waking all waiters.

        Only the first call has any effect, so the first reason wins.

        Args:
            reason: Description of why the task is being cancelled.

        Returns:
            ``True`` if this call performed the cancellation, ``False`` if the
            signal had already been cancelled.
        """
        with self._lock:
            first_cancel = not self._thread_evt.is_set()
            if first_cancel:
                self._reason = reason
                self._thread_evt.set()
                self._async_evt.set()
        return first_cancel
33 changes: 33 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5417,6 +5417,22 @@ async def start_operation(
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...

# Overload for operation_handler
@overload
@abstractmethod
async def start_operation(
self,
operation: Callable[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure about this.

At the workflow level, we should be seeing the ServiceDefinition (i.e. not the implementation), which should not expose the notion of factory-style operation handlers (that's an implementation detail, not part of the public contract of a service).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean. Removing this and updating the test to define the service and use that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, let's move forward with this. We'll review the pertinence of all the overloads in a separate issue.

[ServiceHandlerT], nexusrpc.handler.OperationHandler[InputT, OutputT]
],
input: InputT,
*,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...

@abstractmethod
async def start_operation(
self,
Expand Down Expand Up @@ -5527,6 +5543,23 @@ async def execute_operation(
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...

# Overload for operation_handler: ``operation`` is a factory-style callable
# that takes only its receiver and returns an ``OperationHandler``.
#
# The receiver is typed with ``ServiceHandlerT`` (not the client's ``ServiceT``)
# so that this overload accepts the same callables as the matching
# ``start_operation`` overload for operation_handler.
#
# Parameters:
#   operation: callable returning the ``OperationHandler[InputT, OutputT]``.
#   input: value passed to the operation.
#   output_type: optional explicit result type.
#   schedule_to_close_timeout: optional timeout for the operation.
#   cancellation_type: cancellation behaviour; defaults to WAIT_COMPLETED.
#   headers: optional headers to attach to the operation request.
# Returns the operation result (``OutputT``).
@overload
@abstractmethod
async def execute_operation(
    self,
    operation: Callable[
        [ServiceHandlerT],
        nexusrpc.handler.OperationHandler[InputT, OutputT],
    ],
    input: InputT,
    *,
    output_type: Optional[Type[OutputT]] = None,
    schedule_to_close_timeout: Optional[timedelta] = None,
    cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
    headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...

@abstractmethod
async def execute_operation(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from agents import Runner, custom_span, gen_trace_id, trace

import temporalio.workflow
from tests.contrib.openai_agents.research_agents.planner_agent import (
WebSearchItem,
WebSearchPlan,
Expand Down Expand Up @@ -45,7 +46,7 @@ async def _perform_searches(self, search_plan: WebSearchPlan) -> list[str]:
asyncio.create_task(self._search(item)) for item in search_plan.searches
]
results = []
for task in asyncio.as_completed(tasks):
for task in temporalio.workflow.as_completed(tasks):
result = await task
if result is not None:
results.append(result)
Expand Down
18 changes: 2 additions & 16 deletions tests/nexus/test_dynamic_creation_of_user_handler_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,6 @@ async def cancel(
) -> None:
raise NotImplementedError

async def fetch_info(
self,
ctx: nexusrpc.handler.FetchOperationInfoContext,
token: str,
) -> nexusrpc.OperationInfo:
raise NotImplementedError

async def fetch_result(
self,
ctx: nexusrpc.handler.FetchOperationResultContext,
token: str,
) -> int:
raise NotImplementedError


@nexusrpc.handler.service_handler
class MyServiceHandlerWithWorkflowRunOperation:
Expand All @@ -78,8 +64,8 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(
service_handler = nexusrpc.handler._core.ServiceHandler(
service=nexusrpc.ServiceDefinition(
name="MyService",
operations={
"increment": nexusrpc.Operation[int, int](
operation_definitions={
"increment": nexusrpc.OperationDefinition[int, int](
name="increment",
method_name="increment",
input_type=int,
Expand Down
Loading