1 change: 1 addition & 0 deletions scripts/gen_bridge_client.py
@@ -171,6 +171,7 @@ def generate_rust_service_call(service_descriptor: ServiceDescriptor) -> str:
py: Python<'p>,
call: RpcCall,
) -> PyResult<Bound<'p, PyAny>> {
self.runtime.assert_same_process("use client")?;
use temporal_client::${descriptor_name};
let mut retry_client = self.retry_client.clone();
self.runtime.future_into_py(py, async move {
1 change: 1 addition & 0 deletions temporalio/bridge/src/client.rs
@@ -92,6 +92,7 @@ pub fn connect_client<'a>(
config: ClientConfig,
) -> PyResult<Bound<'a, PyAny>> {
let opts: ClientOptions = config.try_into()?;
runtime_ref.runtime.assert_same_process("create client")?;
let runtime = runtime_ref.runtime.clone();
runtime_ref.runtime.future_into_py(py, async move {
Ok(ClientRef {
5 changes: 5 additions & 0 deletions temporalio/bridge/src/client_rpc_generated.rs
@@ -15,6 +15,7 @@ impl ClientRef {
py: Python<'p>,
call: RpcCall,
) -> PyResult<Bound<'p, PyAny>> {
self.runtime.assert_same_process("use client")?;
use temporal_client::WorkflowService;
let mut retry_client = self.retry_client.clone();
self.runtime.future_into_py(py, async move {
@@ -566,6 +567,7 @@ impl ClientRef {
py: Python<'p>,
call: RpcCall,
) -> PyResult<Bound<'p, PyAny>> {
self.runtime.assert_same_process("use client")?;
use temporal_client::OperatorService;
let mut retry_client = self.retry_client.clone();
self.runtime.future_into_py(py, async move {
@@ -628,6 +630,7 @@ impl ClientRef {
}

fn call_cloud_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<Bound<'p, PyAny>> {
self.runtime.assert_same_process("use client")?;
use temporal_client::CloudService;
let mut retry_client = self.retry_client.clone();
self.runtime.future_into_py(py, async move {
@@ -842,6 +845,7 @@ impl ClientRef {
}

fn call_test_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<Bound<'p, PyAny>> {
self.runtime.assert_same_process("use client")?;
use temporal_client::TestService;
let mut retry_client = self.retry_client.clone();
self.runtime.future_into_py(py, async move {
@@ -881,6 +885,7 @@ impl ClientRef {
}

fn call_health_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<Bound<'p, PyAny>> {
self.runtime.assert_same_process("use client")?;
use temporal_client::HealthService;
let mut retry_client = self.retry_client.clone();
self.runtime.future_into_py(py, async move {
16 changes: 15 additions & 1 deletion temporalio/bridge/src/runtime.rs
@@ -1,5 +1,5 @@
use futures::channel::mpsc::Receiver;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::exceptions::{PyAssertionError, PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pythonize::pythonize;
use std::collections::HashMap;
@@ -33,6 +33,7 @@ pub struct RuntimeRef {

#[derive(Clone)]
pub(crate) struct Runtime {
pub(crate) pid: u32,
pub(crate) core: Arc<CoreRuntime>,
metrics_call_buffer: Option<Arc<MetricsCallBuffer<BufferedMetricRef>>>,
log_forwarder_handle: Option<Arc<JoinHandle<()>>>,
@@ -173,6 +174,7 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult<RuntimeRef> {

Ok(RuntimeRef {
runtime: Runtime {
pid: std::process::id(),
core: Arc::new(core),
metrics_call_buffer,
log_forwarder_handle,
@@ -197,6 +199,18 @@ impl Runtime {
let _guard = self.core.tokio_handle().enter();
pyo3_async_runtimes::generic::future_into_py::<TokioRuntime, _, T>(py, fut)
}

pub(crate) fn assert_same_process(&self, action: &'static str) -> PyResult<()> {
let current_pid = std::process::id();
if self.pid != current_pid {
Err(PyAssertionError::new_err(format!(
"Cannot {} across forks (original runtime PID is {}, current is {})",
action, self.pid, current_pid,
)))
} else {
Ok(())
}
}
}

impl Drop for Runtime {
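For orientation, here is a minimal Python sketch (not part of this diff) of what the new guard surfaces: a client created before a fork now fails fast in the child with the AssertionError formatted above, instead of hanging on the parent runtime's thread pool. The server address and the list_workflows call are illustrative assumptions.

# Sketch only: the address, RPC, and PIDs below are assumptions for
# illustration, not part of this change.
import asyncio
import multiprocessing

from temporalio.client import Client


def child(client: Client) -> None:
    async def use() -> None:
        # Any RPC on the inherited client should now raise, e.g.
        # AssertionError: Cannot use client across forks (original runtime
        # PID is 1234, current is 1235)
        async for _ in client.list_workflows():
            break

    asyncio.new_event_loop().run_until_complete(use())


async def main() -> None:
    client = await Client.connect("localhost:7233")
    ctx = multiprocessing.get_context("fork")
    proc = ctx.Process(target=child, args=(client,))
    proc.start()
    proc.join()


if __name__ == "__main__":
    asyncio.run(main())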
8 changes: 8 additions & 0 deletions temporalio/bridge/src/worker.rs
@@ -474,6 +474,7 @@ pub fn new_worker(
config: WorkerConfig,
) -> PyResult<WorkerRef> {
enter_sync!(runtime_ref.runtime);
runtime_ref.runtime.assert_same_process("create worker")?;
let event_loop_task_locals = Arc::new(OnceLock::new());
let config = convert_worker_config(config, event_loop_task_locals.clone())?;
let worker = temporal_sdk_core::init_worker(
@@ -495,6 +496,9 @@ pub fn new_replay_worker<'a>(
config: WorkerConfig,
) -> PyResult<Bound<'a, PyTuple>> {
enter_sync!(runtime_ref.runtime);
runtime_ref
.runtime
.assert_same_process("create replay worker")?;
let event_loop_task_locals = Arc::new(OnceLock::new());
let config = convert_worker_config(config, event_loop_task_locals.clone())?;
let (history_pusher, stream) = HistoryPusher::new(runtime_ref.runtime.clone());
@@ -519,6 +523,7 @@
#[pymethods]
impl WorkerRef {
fn validate<'p>(&self, py: Python<'p>) -> PyResult<Bound<PyAny, 'p>> {
self.runtime.assert_same_process("use worker")?;
let worker = self.worker.as_ref().unwrap().clone();
// Set custom slot supplier task locals so they can run futures.
// Event loop is assumed to be running at this point.
@@ -538,6 +543,7 @@
}

fn poll_workflow_activation<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
self.runtime.assert_same_process("use worker")?;
let worker = self.worker.as_ref().unwrap().clone();
self.runtime.future_into_py(py, async move {
let bytes = match worker.poll_workflow_activation().await {
@@ -550,6 +556,7 @@
}

fn poll_activity_task<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
self.runtime.assert_same_process("use worker")?;
let worker = self.worker.as_ref().unwrap().clone();
self.runtime.future_into_py(py, async move {
let bytes = match worker.poll_activity_task().await {
@@ -562,6 +569,7 @@
}

fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
self.runtime.assert_same_process("use worker")?;
let worker = self.worker.as_ref().unwrap().clone();
self.runtime.future_into_py(py, async move {
let bytes = match worker.poll_nexus_task().await {
73 changes: 61 additions & 12 deletions temporalio/runtime.py
@@ -24,33 +24,83 @@
import temporalio.bridge.runtime
import temporalio.common

_default_runtime: Optional[Runtime] = None

class _RuntimeRef:
def __init__(
self,
) -> None:
self._default_runtime: Runtime | None = None
self._prevent_default = False
self._default_created = False
Member:

Is this field needed if _default_runtime's None-ness can be used as basically the same thing? And therefore if you have default runtime and prevent default, do they deserve an entirely new class/encapsulation as opposed to just adding a single _prevent_default global alongside the existing global?

Contributor Author:

My original thinking for _default_created was to differentiate between the default being lazily created and one being set via set_default. I can see that not being worth tracking and just raising in prevent_default if an existing _default_runtime exists. I think I'll make that change because it similarly makes sense to avoid calling prevent_default after set_default. Even though that would be harmless, it is a weird thing to do.

The primary purpose of the encapsulation was just to write tests in a way that plays nicely with existing fixtures. Definitely happy to consider alternatives here.

Member (@cretz, Nov 13, 2025):

Meh, if we have to extract out to a new class for test only, ok. It makes for a bit uglier stack trace for users, but not a big deal.


def default(self) -> Runtime:
if not self._default_runtime:
if self._prevent_default:
raise RuntimeError(
"Cannot create default Runtime after Runtime.prevent_default has been called"
)
self._default_runtime = Runtime(telemetry=TelemetryConfig())
self._default_created = True
return self._default_runtime

def prevent_default(self) -> None:
if self._default_created:
raise RuntimeError(
"Runtime.prevent_default called after default runtime has been created"
)
self._prevent_default = True

def set_default(
self, runtime: Runtime, *, error_if_already_set: bool = True
) -> None:
if self._default_runtime and error_if_already_set:
raise RuntimeError("Runtime default already set")

self._default_runtime = runtime


_runtime_ref: _RuntimeRef = _RuntimeRef()


class Runtime:
"""Runtime for Temporal Python SDK.

Users are encouraged to use :py:meth:`default`. It can be set with
Most users are encouraged to use :py:meth:`default`. It can be set with
:py:meth:`set_default`. Every time a new runtime is created, a new internal
thread pool is created.

Runtimes do not work across forks.
Runtimes do not work across forks. Advanced users should consider using
:py:meth:`prevent_default` and :py:meth:`set_default` to ensure each
fork creates its own runtime.

"""

@classmethod
def default(cls) -> Runtime:
"""Get the default runtime, creating if not already created.
"""Get the default runtime, creating if not already created. If :py:meth:`prevent_default`
is called before this method, it will raise a RuntimeError instead of creating a default
runtime.

If the default runtime needs to be different, it should be done with
:py:meth:`set_default` before this is called or ever used.

Returns:
The default runtime.
"""
global _default_runtime
if not _default_runtime:
_default_runtime = cls(telemetry=TelemetryConfig())
return _default_runtime
global _runtime_ref
return _runtime_ref.default()

@classmethod
def prevent_default(cls) -> None:
"""Prevent :py:meth:`default` from lazily creating a :py:class:`Runtime`.

Raises a RuntimeError if a default :py:class:`Runtime` has already been created.

Explicitly setting a default runtime with :py:meth:`set_default` bypasses this setting and
future calls to :py:meth:`default` will return the provided runtime.
"""
global _runtime_ref
_runtime_ref.prevent_default()

@staticmethod
def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None:
Expand All @@ -65,10 +115,9 @@ def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None:
error_if_already_set: If True and default is already set, this will
raise a RuntimeError.
"""
global _default_runtime
if _default_runtime and error_if_already_set:
raise RuntimeError("Runtime default already set")
_default_runtime = runtime
global _runtime_ref
_runtime_ref.set_default(runtime, error_if_already_set=error_if_already_set)
return
Member:

Suggested change: remove the trailing `return`.

def __init__(self, *, telemetry: TelemetryConfig) -> None:
"""Create a default runtime with the given telemetry config.
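A minimal sketch of the advanced pattern the updated docstring describes (the fork start method and the bare TelemetryConfig are assumptions): the parent forbids lazy creation of a shared default, and each forked child installs its own runtime before any client or worker is created.

# Sketch of the prevent_default/set_default pattern; assumes a platform
# where the "fork" start method is available.
import multiprocessing

from temporalio.runtime import Runtime, TelemetryConfig


def child_main() -> None:
    # Each child creates and installs its own Runtime, so no runtime (or
    # its internal thread pool) is shared across the fork boundary.
    Runtime.set_default(Runtime(telemetry=TelemetryConfig()))
    # ... connect clients / start workers here ...


if __name__ == "__main__":
    # Parent: fail loudly if anything would lazily create a shared default.
    Runtime.prevent_default()
    ctx = multiprocessing.get_context("fork")
    proc = ctx.Process(target=child_main)
    proc.start()
    proc.join()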
20 changes: 19 additions & 1 deletion tests/conftest.py
@@ -1,7 +1,8 @@
import asyncio
import multiprocessing.context
import os
import sys
from typing import AsyncGenerator
from typing import AsyncGenerator, Iterator

import pytest
import pytest_asyncio
@@ -133,6 +134,23 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]:
await env.shutdown()


@pytest.fixture(scope="session")
def mp_fork_ctx() -> Iterator[multiprocessing.context.BaseContext | None]:
mp_ctx = None
try:
mp_ctx = multiprocessing.get_context("fork")
except ValueError:
pass

try:
yield mp_ctx
finally:
if mp_ctx:
for p in mp_ctx.active_children():
p.terminate()
p.join()


@pytest_asyncio.fixture
async def client(env: WorkflowEnvironment) -> Client:
return env.client
80 changes: 80 additions & 0 deletions tests/helpers/fork.py
@@ -0,0 +1,80 @@
from __future__ import annotations

import asyncio
import multiprocessing
import multiprocessing.context
import sys
from dataclasses import dataclass
from typing import Any

import pytest


@dataclass
class _ForkTestResult:
status: str
err_name: str | None
err_msg: str | None

def __eq__(self, value: object) -> bool:
if not isinstance(value, _ForkTestResult):
return False

# Treat error messages as equal when either contains the other; fall back
# to strict equality so two messageless (None) results still compare equal.
valid_err_msg = self.err_msg == value.err_msg

if self.err_msg and value.err_msg:
    valid_err_msg = (
        self.err_msg in value.err_msg or value.err_msg in self.err_msg
    )

return (
    value.status == self.status
    and value.err_name == self.err_name
    and valid_err_msg
)

@staticmethod
def assertion_error(message: str) -> _ForkTestResult:
return _ForkTestResult(
status="error", err_name="AssertionError", err_msg=message
)


class _TestFork:
_expected: _ForkTestResult

async def coro(self) -> Any:
raise NotImplementedError()

def entry(self):
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
try:
event_loop.run_until_complete(self.coro())
payload = _ForkTestResult(status="ok", err_name=None, err_msg=None)
except BaseException as err:
payload = _ForkTestResult(
status="error", err_name=err.__class__.__name__, err_msg=str(err)
)

self._child_conn.send(payload)
self._child_conn.close()

def run(self, mp_fork_context: multiprocessing.context.BaseContext | None):
process_factory = getattr(mp_fork_context, "Process", None)

if not mp_fork_context or not process_factory:
pytest.skip("fork context not available")

self._parent_conn, self._child_conn = mp_fork_context.Pipe(duplex=False)
# start fork
child_process = process_factory(target=self.entry, args=(), daemon=False)
child_process.start()
# close parent's handle on child_conn
self._child_conn.close()

# get run info from pipe
payload = self._parent_conn.recv()
self._parent_conn.close()

assert payload == self._expected
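A hypothetical test built on this helper and the mp_fork_ctx fixture above; the test name, client fixture, RPC, and expected message fragment are illustrative, not part of this diff.

# Hypothetical usage sketch for _TestFork: run an RPC on a pre-fork client
# inside the forked child and expect the bridge's PID-guard AssertionError.
from temporalio.client import Client

from tests.helpers.fork import _ForkTestResult, _TestFork


def test_use_client_after_fork(client: Client, mp_fork_ctx) -> None:
    class UseClientAfterFork(_TestFork):
        _expected = _ForkTestResult.assertion_error(
            "Cannot use client across forks"
        )

        async def coro(self):
            # Any call on the inherited client should trip assert_same_process.
            async for _ in client.list_workflows():
                break

    UseClientAfterFork().run(mp_fork_ctx)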