Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -10,3 +10,4 @@ temporalio/bridge/temporal_sdk_bridge*
/sdk-python.iml
/.zed
*.DS_Store
tags
49 changes: 27 additions & 22 deletions temporalio/contrib/opentelemetry.py
@@ -172,29 +172,34 @@ def _start_as_current_span(
kind: opentelemetry.trace.SpanKind,
context: Optional[Context] = None,
) -> Iterator[None]:
with self.tracer.start_as_current_span(
name,
attributes=attributes,
kind=kind,
context=context,
set_status_on_exception=False,
) as span:
if input:
input.headers = self._context_to_headers(input.headers)
try:
yield None
except Exception as exc:
if (
not isinstance(exc, ApplicationError)
or exc.category != ApplicationErrorCategory.BENIGN
):
span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exc).__name__}: {exc}",
token = opentelemetry.context.attach(context) if context else None
try:
with self.tracer.start_as_current_span(
name,
attributes=attributes,
kind=kind,
context=context,
set_status_on_exception=False,
) as span:
if input:
input.headers = self._context_to_headers(input.headers)
try:
yield None
except Exception as exc:
if (
not isinstance(exc, ApplicationError)
or exc.category != ApplicationErrorCategory.BENIGN
):
span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exc).__name__}: {exc}",
)
)
)
raise
raise
finally:
if token and context is opentelemetry.context.get_current():
opentelemetry.context.detach(token)

def _completed_workflow_span(
self, params: _CompletedWorkflowSpanParams
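The change above attaches the caller-supplied OpenTelemetry context as the implicit current context before the span is started and detaches it in a finally block, but only if that context is still current; the guard mirrors the interceptor's choice to leave the context untouched if something else has already replaced it. A minimal standalone sketch of the same attach/detach pattern, outside the Temporal interceptor (the baggage key and value are illustrative assumptions, not from this PR):

import opentelemetry.context
from opentelemetry import baggage
from opentelemetry.context import Context


def run_with_context(ctx: Context) -> None:
    # Attach the given context so implicit-context lookups (spans, baggage) see it.
    token = opentelemetry.context.attach(ctx) if ctx else None
    try:
        print(baggage.get_baggage("user.id"))  # reads from the attached context
    finally:
        # Detach only if the attached context is still the current one.
        if token and ctx is opentelemetry.context.get_current():
            opentelemetry.context.detach(token)


ctx = baggage.set_baggage("user.id", "user-123")
run_with_context(ctx)  # prints "user-123"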
290 changes: 283 additions & 7 deletions tests/contrib/test_opentelemetry.py
@@ -4,16 +4,17 @@
import gc
import logging
import queue
import sys
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Iterable, List, Optional
from typing import Callable, Dict, Generator, Iterable, List, Optional, cast

import opentelemetry.context
import pytest
from opentelemetry import baggage, context
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
@@ -31,11 +32,6 @@
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import UnsandboxedWorkflowRunner, Worker
from tests.helpers import LogCapturer
from tests.helpers.cache_eviction import (
CacheEvictionTearDownWorkflow,
WaitForeverWorkflow,
wait_forever_activity,
)


@dataclass
@@ -558,6 +554,286 @@ async def test_opentelemetry_benign_exception(client: Client):
assert all(span.status.status_code == StatusCode.UNSET for span in spans)


@contextmanager
def baggage_values(values: Dict[str, str]) -> Generator[None, None, None]:
ctx = context.get_current()
for key, value in values.items():
ctx = baggage.set_baggage(key, value, context=ctx)

token = context.attach(ctx)
try:
yield
finally:
context.detach(token)


@pytest.fixture
def client_with_tracing(client: Client) -> Client:
tracer = get_tracer(__name__, tracer_provider=TracerProvider())
client_config = client.config()
client_config["interceptors"] = [TracingInterceptor(tracer)]
return Client(**client_config)


def get_baggage_value(key: str) -> str:
return cast("str", baggage.get_baggage(key))


@activity.defn
async def read_baggage_activity() -> Dict[str, str]:
return {
"user_id": get_baggage_value("user.id"),
"tenant_id": get_baggage_value("tenant.id"),
}


@workflow.defn
class ReadBaggageTestWorkflow:
@workflow.run
async def run(self) -> Dict[str, str]:
return await workflow.execute_activity(
read_baggage_activity,
start_to_close_timeout=timedelta(seconds=10),
)


async def test_opentelemetry_baggage_propagation_basic(
client_with_tracing: Client, env: WorkflowEnvironment
):
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[ReadBaggageTestWorkflow],
activities=[read_baggage_activity],
):
with baggage_values({"user.id": "test-user-123", "tenant.id": "some-corp"}):
result = await client_with_tracing.execute_workflow(
ReadBaggageTestWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)

assert (
result["user_id"] == "test-user-123"
), "user.id baggage should propagate to activity"
assert (
result["tenant_id"] == "some-corp"
), "tenant.id baggage should propagate to activity"


@activity.defn
async def read_baggage_local_activity() -> Dict[str, str]:
return cast(
Dict[str, str],
{
"user_id": get_baggage_value("user.id"),
"tenant_id": get_baggage_value("tenant.id"),
},
)


@workflow.defn
class LocalActivityBaggageTestWorkflow:
@workflow.run
async def run(self) -> Dict[str, str]:
return await workflow.execute_local_activity(
read_baggage_local_activity,
start_to_close_timeout=timedelta(seconds=10),
)


async def test_opentelemetry_baggage_propagation_local_activity(
client_with_tracing: Client, env: WorkflowEnvironment
):
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[LocalActivityBaggageTestWorkflow],
activities=[read_baggage_local_activity],
):
with baggage_values(
{
"user.id": "test-user-456",
"tenant.id": "local-corp",
}
):
result = await client_with_tracing.execute_workflow(
LocalActivityBaggageTestWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)

assert result["user_id"] == "test-user-456"
assert result["tenant_id"] == "local-corp"


retry_attempt_baggage_values: List[str] = []


@activity.defn
async def failing_baggage_activity() -> None:
retry_attempt_baggage_values.append(get_baggage_value("user.id"))
if activity.info().attempt < 2:
raise RuntimeError("Intentional failure")


@workflow.defn
class RetryBaggageTestWorkflow:
@workflow.run
async def run(self) -> None:
await workflow.execute_activity(
failing_baggage_activity,
start_to_close_timeout=timedelta(seconds=10),
retry_policy=RetryPolicy(initial_interval=timedelta(milliseconds=1)),
)


async def test_opentelemetry_baggage_propagation_with_retries(
client_with_tracing: Client, env: WorkflowEnvironment
) -> None:
global retry_attempt_baggage_values
retry_attempt_baggage_values = []

task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[RetryBaggageTestWorkflow],
activities=[failing_baggage_activity],
):
with baggage_values({"user.id": "test-user-retry"}):
await client_with_tracing.execute_workflow(
RetryBaggageTestWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)

# Verify baggage was present on all attempts
assert len(retry_attempt_baggage_values) == 2
assert all(v == "test-user-retry" for v in retry_attempt_baggage_values)


@activity.defn
async def context_clear_noop_activity() -> None:
pass


@activity.defn
async def context_clear_exception_activity() -> None:
raise Exception("Simulated exception")


@workflow.defn
class ContextClearWorkflow:
@workflow.run
async def run(self) -> None:
await workflow.execute_activity(
context_clear_noop_activity,
start_to_close_timeout=timedelta(seconds=10),
retry_policy=RetryPolicy(
maximum_attempts=1, initial_interval=timedelta(milliseconds=1)
),
)


@pytest.mark.parametrize(
"activity,expect_failure",
[
(context_clear_noop_activity, False),
(context_clear_exception_activity, True),
],
)
async def test_opentelemetry_context_restored_after_activity(
client_with_tracing: Client,
env: WorkflowEnvironment,
activity: Callable[[], None],
expect_failure: bool,
) -> None:
attach_count = 0
detach_count = 0
original_attach = context.attach
original_detach = context.detach

def tracked_attach(ctx):
nonlocal attach_count
attach_count += 1
return original_attach(ctx)

def tracked_detach(token):
nonlocal detach_count
detach_count += 1
return original_detach(token)

context.attach = tracked_attach
context.detach = tracked_detach

try:
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[ContextClearWorkflow],
activities=[activity],
):
with baggage_values({"user.id": "test-123"}):
try:
await client_with_tracing.execute_workflow(
ContextClearWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)
assert (
not expect_failure
), "This test should have raised an exception"
except Exception:
assert expect_failure, "This test is not expeced to raise"

assert (
attach_count == detach_count
), f"Context leak detected: {attach_count} attaches vs {detach_count} detaches. "
assert attach_count > 0, "Expected at least one context attach/detach"

finally:
context.attach = original_attach
context.detach = original_detach


@activity.defn
async def simple_no_context_activity() -> str:
return "success"


@workflow.defn
class SimpleNoContextWorkflow:
@workflow.run
async def run(self) -> str:
return await workflow.execute_activity(
simple_no_context_activity,
start_to_close_timeout=timedelta(seconds=10),
)


async def test_opentelemetry_interceptor_works_if_no_context(
client_with_tracing: Client, env: WorkflowEnvironment
):
task_queue = f"task_queue_{uuid.uuid4()}"
async with Worker(
client_with_tracing,
task_queue=task_queue,
workflows=[SimpleNoContextWorkflow],
activities=[simple_no_context_activity],
):
result = await client_with_tracing.execute_workflow(
SimpleNoContextWorkflow.run,
id=f"workflow_{uuid.uuid4()}",
task_queue=task_queue,
)

assert result == "success"


# TODO(cretz): Additional tests to write
# * query without interceptor (no headers)
# * workflow without interceptor (no headers) but query with interceptor (headers)
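A hypothetical sketch of the first TODO item above — querying a workflow through a client that has no TracingInterceptor, so the query carries no tracing headers. The workflow class, query name, and test body here are assumptions, not part of this change, and the span-level assertions are omitted:

@workflow.defn
class HypotheticalQueryWorkflow:
    @workflow.run
    async def run(self) -> None:
        pass

    @workflow.query
    def state(self) -> str:
        return "done"


async def test_opentelemetry_query_without_interceptor(
    client_with_tracing: Client, client: Client
):
    task_queue = f"task_queue_{uuid.uuid4()}"
    async with Worker(
        client_with_tracing,
        task_queue=task_queue,
        workflows=[HypotheticalQueryWorkflow],
    ):
        handle = await client_with_tracing.start_workflow(
            HypotheticalQueryWorkflow.run,
            id=f"workflow_{uuid.uuid4()}",
            task_queue=task_queue,
        )
        await handle.result()
        # Query through the plain client, which has no TracingInterceptor,
        # so the query request should carry no tracing headers.
        plain_handle = client.get_workflow_handle(handle.id)
        assert await plain_handle.query(HypotheticalQueryWorkflow.state) == "done"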