Skip to content

Commit 1e53211

Browse files
committed
Nexus
1 parent 39f8e84 commit 1e53211

24 files changed

+3544
-562
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ informal introduction to the features and their implementation.
9494
- [Heartbeating and Cancellation](#heartbeating-and-cancellation)
9595
- [Worker Shutdown](#worker-shutdown)
9696
- [Testing](#testing-1)
97+
- [Nexus](#nexus)
9798
- [Workflow Replay](#workflow-replay)
9899
- [Observability](#observability)
99100
- [Metrics](#metrics)
@@ -1313,6 +1314,7 @@ affect calls activity code might make to functions on the `temporalio.activity`
13131314
* `cancel()` can be invoked to simulate a cancellation of the activity
13141315
* `worker_shutdown()` can be invoked to simulate a worker shutdown during execution of the activity
13151316

1317+
13161318
### Workflow Replay
13171319

13181320
Given a workflow's history, it can be replayed locally to check for things like non-determinism errors. For example,

pyproject.toml

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ keywords = [
1111
"workflow",
1212
]
1313
dependencies = [
14-
"protobuf>=3.20",
14+
"nexus-rpc",
15+
"protobuf==5.29.3",
1516
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
1617
"types-protobuf>=3.20",
1718
"typing-extensions>=4.2.0,<5",
@@ -40,7 +41,7 @@ dev = [
4041
"psutil>=5.9.3,<6",
4142
"pydocstyle>=6.3.0,<7",
4243
"pydoctor>=24.11.1,<25",
43-
"pyright==1.1.377",
44+
"pyright==1.1.400",
4445
"pytest~=7.4",
4546
"pytest-asyncio>=0.21,<0.22",
4647
"pytest-timeout~=2.2",
@@ -49,6 +50,8 @@ dev = [
4950
"twine>=4.0.1,<5",
5051
"ruff>=0.5.0,<0.6",
5152
"maturin>=1.8.2",
53+
"pytest-cov>=6.1.1",
54+
"httpx>=0.28.1",
5255
"pytest-pretty>=1.3.0",
5356
]
5457

@@ -158,6 +161,7 @@ exclude = [
158161
"tests/worker/workflow_sandbox/testmodules/proto",
159162
"temporalio/bridge/worker.py",
160163
"temporalio/contrib/opentelemetry.py",
164+
"temporalio/contrib/pydantic.py",
161165
"temporalio/converter.py",
162166
"temporalio/testing/_workflow.py",
163167
"temporalio/worker/_activity.py",
@@ -169,6 +173,10 @@ exclude = [
169173
"tests/api/test_grpc_stub.py",
170174
"tests/conftest.py",
171175
"tests/contrib/test_opentelemetry.py",
176+
"tests/contrib/pydantic/models.py",
177+
"tests/contrib/pydantic/models_2.py",
178+
"tests/contrib/pydantic/test_pydantic.py",
179+
"tests/contrib/pydantic/workflows.py",
172180
"tests/test_converter.py",
173181
"tests/test_service.py",
174182
"tests/test_workflow.py",
@@ -187,6 +195,9 @@ exclude = [
187195
[tool.ruff]
188196
target-version = "py39"
189197

198+
[tool.ruff.lint]
199+
extend-ignore = ["E741"] # Allow single-letter variable names like I, O
200+
190201
[build-system]
191202
requires = ["maturin>=1.0,<2.0"]
192203
build-backend = "maturin"
@@ -203,3 +214,6 @@ exclude = [
203214
[tool.uv]
204215
# Prevent uv commands from building the package by default
205216
package = false
217+
218+
[tool.uv.sources]
219+
nexus-rpc = { path = "../nexus-sdk-python", editable = true }

temporalio/bridge/src/worker.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use temporal_sdk_core_api::worker::{
2020
};
2121
use temporal_sdk_core_api::Worker;
2222
use temporal_sdk_core_protos::coresdk::workflow_completion::WorkflowActivationCompletion;
23-
use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion};
23+
use temporal_sdk_core_protos::coresdk::{ActivityHeartbeat, ActivityTaskCompletion, nexus::NexusTaskCompletion};
2424
use temporal_sdk_core_protos::temporal::api::history::v1::History;
2525
use tokio::sync::mpsc::{channel, Sender};
2626
use tokio_stream::wrappers::ReceiverStream;
@@ -570,6 +570,19 @@ impl WorkerRef {
570570
})
571571
}
572572

573+
fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
574+
let worker = self.worker.as_ref().unwrap().clone();
575+
self.runtime.future_into_py(py, async move {
576+
let bytes = match worker.poll_nexus_task().await {
577+
Ok(task) => task.encode_to_vec(),
578+
Err(PollError::ShutDown) => return Err(PollShutdownError::new_err(())),
579+
Err(err) => return Err(PyRuntimeError::new_err(format!("Poll failure: {}", err))),
580+
};
581+
let bytes: &[u8] = &bytes;
582+
Ok(Python::with_gil(|py| bytes.into_py(py)))
583+
})
584+
}
585+
573586
fn complete_workflow_activation<'p>(
574587
&self,
575588
py: Python<'p>,
@@ -600,6 +613,19 @@ impl WorkerRef {
600613
})
601614
}
602615

616+
fn complete_nexus_task<'p>(&self, py: Python<'p>, proto: &PyBytes) -> PyResult<&'p PyAny> {
617+
let worker = self.worker.as_ref().unwrap().clone();
618+
let completion = NexusTaskCompletion::decode(proto.as_bytes())
619+
.map_err(|err| PyValueError::new_err(format!("Invalid proto: {}", err)))?;
620+
self.runtime.future_into_py(py, async move {
621+
worker
622+
.complete_nexus_task(completion)
623+
.await
624+
.context("Completion failure")
625+
.map_err(Into::into)
626+
})
627+
}
628+
603629
fn record_activity_heartbeat(&self, proto: &PyBytes) -> PyResult<()> {
604630
enter_sync!(self.runtime);
605631
let heartbeat = ActivityHeartbeat::decode(proto.as_bytes())

temporalio/bridge/worker.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import temporalio.bridge.client
2727
import temporalio.bridge.proto
2828
import temporalio.bridge.proto.activity_task
29+
import temporalio.bridge.proto.nexus
2930
import temporalio.bridge.proto.workflow_activation
3031
import temporalio.bridge.proto.workflow_completion
3132
import temporalio.bridge.runtime
@@ -35,7 +36,7 @@
3536
from temporalio.bridge.temporal_sdk_bridge import (
3637
CustomSlotSupplier as BridgeCustomSlotSupplier,
3738
)
38-
from temporalio.bridge.temporal_sdk_bridge import PollShutdownError
39+
from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore
3940

4041

4142
@dataclass
@@ -216,6 +217,14 @@ async def poll_activity_task(
216217
await self._ref.poll_activity_task()
217218
)
218219

220+
async def poll_nexus_task(
221+
self,
222+
) -> temporalio.bridge.proto.nexus.NexusTask:
223+
"""Poll for a nexus task."""
224+
return temporalio.bridge.proto.nexus.NexusTask.FromString(
225+
await self._ref.poll_nexus_task()
226+
)
227+
219228
async def complete_workflow_activation(
220229
self,
221230
comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
@@ -229,6 +238,12 @@ async def complete_activity_task(
229238
"""Complete an activity task."""
230239
await self._ref.complete_activity_task(comp.SerializeToString())
231240

241+
async def complete_nexus_task(
242+
self, comp: temporalio.bridge.proto.nexus.NexusTaskCompletion
243+
) -> None:
244+
"""Complete a nexus task."""
245+
await self._ref.complete_nexus_task(comp.SerializeToString())
246+
232247
def record_activity_heartbeat(
233248
self, comp: temporalio.bridge.proto.ActivityHeartbeat
234249
) -> None:

temporalio/client.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,16 @@ async def start_workflow(
464464
rpc_metadata: Mapping[str, str] = {},
465465
rpc_timeout: Optional[timedelta] = None,
466466
request_eager_start: bool = False,
467-
stack_level: int = 2,
468467
priority: temporalio.common.Priority = temporalio.common.Priority.default,
469468
versioning_override: Optional[temporalio.common.VersioningOverride] = None,
469+
# The following options are deliberately not exposed in overloads
470+
stack_level: int = 2,
471+
nexus_completion_callbacks: Sequence[
472+
temporalio.common.NexusCompletionCallback
473+
] = [],
474+
workflow_event_links: Sequence[
475+
temporalio.api.common.v1.Link.WorkflowEvent
476+
] = [],
470477
) -> WorkflowHandle[Any, Any]:
471478
"""Start a workflow and return its handle.
472479
@@ -557,6 +564,8 @@ async def start_workflow(
557564
rpc_timeout=rpc_timeout,
558565
request_eager_start=request_eager_start,
559566
priority=priority,
567+
nexus_completion_callbacks=nexus_completion_callbacks,
568+
workflow_event_links=workflow_event_links,
560569
)
561570
)
562571

@@ -5193,6 +5202,8 @@ class StartWorkflowInput:
51935202
rpc_timeout: Optional[timedelta]
51945203
request_eager_start: bool
51955204
priority: temporalio.common.Priority
5205+
nexus_completion_callbacks: Sequence[temporalio.common.NexusCompletionCallback]
5206+
workflow_event_links: Sequence[temporalio.api.common.v1.Link.WorkflowEvent]
51965207
versioning_override: Optional[temporalio.common.VersioningOverride] = None
51975208

51985209

@@ -5809,6 +5820,16 @@ async def _build_start_workflow_execution_request(
58095820
req = temporalio.api.workflowservice.v1.StartWorkflowExecutionRequest()
58105821
req.request_eager_execution = input.request_eager_start
58115822
await self._populate_start_workflow_execution_request(req, input)
5823+
for callback in input.nexus_completion_callbacks:
5824+
c = temporalio.api.common.v1.Callback()
5825+
c.nexus.url = callback.url
5826+
c.nexus.header.update(callback.header)
5827+
req.completion_callbacks.append(c)
5828+
5829+
req.links.extend(
5830+
temporalio.api.common.v1.Link(workflow_event=link)
5831+
for link in input.workflow_event_links
5832+
)
58125833
return req
58135834

58145835
async def _build_signal_with_start_workflow_execution_request(

temporalio/common.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from abc import ABC, abstractmethod
99
from dataclasses import dataclass
1010
from datetime import datetime, timedelta
11-
from enum import Enum, IntEnum
11+
from enum import IntEnum
1212
from typing import (
1313
Any,
1414
Callable,
@@ -197,6 +197,37 @@ def __setstate__(self, state: object) -> None:
197197
)
198198

199199

200+
@dataclass(frozen=True)
201+
class NexusCompletionCallback:
202+
"""Nexus callback to attach to events such as workflow completion."""
203+
204+
url: str
205+
"""Callback URL."""
206+
207+
header: Mapping[str, str]
208+
"""Header to attach to callback request."""
209+
210+
211+
@dataclass(frozen=True)
212+
class WorkflowEventLink:
213+
"""A link to a history event that can be attached to a different history event."""
214+
215+
namespace: str
216+
"""Namespace of the workflow to link to."""
217+
218+
workflow_id: str
219+
"""ID of the workflow to link to."""
220+
221+
run_id: str
222+
"""Run ID of the workflow to link to."""
223+
224+
event_type: temporalio.api.enums.v1.EventType
225+
"""Type of the event to link to."""
226+
227+
event_id: int
228+
"""ID of the event to link to."""
229+
230+
200231
# We choose to make this a list instead of an sequence so we can catch if people
201232
# are not sending lists each time but maybe accidentally sending a string (which
202233
# is a sequence)

temporalio/converter.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,12 @@ def _error_to_failure(
911911
failure.child_workflow_execution_failure_info.retry_state = (
912912
temporalio.api.enums.v1.RetryState.ValueType(error.retry_state or 0)
913913
)
914+
# TODO(nexus-prerelease): test coverage for this
915+
elif isinstance(error, temporalio.exceptions.NexusOperationError):
916+
failure.nexus_operation_execution_failure_info.SetInParent()
917+
failure.nexus_operation_execution_failure_info.operation_token = (
918+
error.operation_token
919+
)
914920

915921
def from_failure(
916922
self,
@@ -1006,6 +1012,26 @@ def from_failure(
10061012
if child_info.retry_state
10071013
else None,
10081014
)
1015+
elif failure.HasField("nexus_handler_failure_info"):
1016+
nexus_handler_failure_info = failure.nexus_handler_failure_info
1017+
err = temporalio.exceptions.NexusHandlerError(
1018+
failure.message or "Nexus handler error",
1019+
type=nexus_handler_failure_info.type,
1020+
retryable={
1021+
temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE: True,
1022+
temporalio.api.enums.v1.NexusHandlerErrorRetryBehavior.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: False,
1023+
}.get(nexus_handler_failure_info.retry_behavior),
1024+
)
1025+
elif failure.HasField("nexus_operation_execution_failure_info"):
1026+
nexus_op_failure_info = failure.nexus_operation_execution_failure_info
1027+
err = temporalio.exceptions.NexusOperationError(
1028+
failure.message or "Nexus operation error",
1029+
scheduled_event_id=nexus_op_failure_info.scheduled_event_id,
1030+
endpoint=nexus_op_failure_info.endpoint,
1031+
service=nexus_op_failure_info.service,
1032+
operation=nexus_op_failure_info.operation,
1033+
operation_token=nexus_op_failure_info.operation_token,
1034+
)
10091035
else:
10101036
err = temporalio.exceptions.FailureError(failure.message or "Failure error")
10111037
err._failure = failure

temporalio/exceptions.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,69 @@ def retry_state(self) -> Optional[RetryState]:
362362
return self._retry_state
363363

364364

365+
class NexusHandlerError(FailureError):
366+
"""Error raised on Nexus handler failure."""
367+
368+
def __init__(
369+
self,
370+
message: str,
371+
*,
372+
type: str,
373+
retryable: Optional[bool] = None,
374+
):
375+
"""Initialize a Nexus handler error."""
376+
super().__init__(message)
377+
self._type = type
378+
self._retryable = retryable
379+
380+
381+
class NexusOperationError(FailureError):
382+
"""Error raised on Nexus operation failure."""
383+
384+
def __init__(
385+
self,
386+
message: str,
387+
*,
388+
scheduled_event_id: int,
389+
endpoint: str,
390+
service: str,
391+
operation: str,
392+
operation_token: str,
393+
):
394+
"""Initialize a Nexus operation error."""
395+
super().__init__(message)
396+
self._scheduled_event_id = scheduled_event_id
397+
self._endpoint = endpoint
398+
self._service = service
399+
self._operation = operation
400+
self._operation_token = operation_token
401+
402+
@property
403+
def scheduled_event_id(self) -> int:
404+
"""The NexusOperationScheduled event ID for the failed operation."""
405+
return self._scheduled_event_id
406+
407+
@property
408+
def endpoint(self) -> str:
409+
"""The endpoint name for the failed operation."""
410+
return self._endpoint
411+
412+
@property
413+
def service(self) -> str:
414+
"""The service name for the failed operation."""
415+
return self._service
416+
417+
@property
418+
def operation(self) -> str:
419+
"""The name of the failed operation."""
420+
return self._operation
421+
422+
@property
423+
def operation_token(self) -> str:
424+
"""The operation token returned by the failed operation."""
425+
return self._operation_token
426+
427+
365428
def is_cancelled_exception(exception: BaseException) -> bool:
366429
"""Check whether the given exception is considered a cancellation exception
367430
according to Temporal.

0 commit comments

Comments
 (0)