45 changes: 19 additions & 26 deletions temporalio/contrib/openai_agents/_trace_interceptor.py
@@ -13,7 +13,7 @@
     get_trace_provider,
 )
 from agents.tracing.scope import Scope
-from agents.tracing.spans import NoOpSpan, Span
+from agents.tracing.spans import NoOpSpan
 
 import temporalio.activity
 import temporalio.api.common.v1
@@ -402,48 +402,41 @@ async def signal_external_workflow(
     def start_activity(
         self, input: temporalio.worker.StartActivityInput
     ) -> temporalio.workflow.ActivityHandle:
+        # Use synchronous span pattern to avoid context detachment errors.
+        # Async callbacks (add_done_callback) fire in different context instances,
+        # breaking OTel's token validation. Instead, complete span immediately
+        # at orchestration point - execution time is captured by activity worker.
         trace = get_trace_provider().get_current_trace()
-        span: Span | None = None
         if trace:
-            span = custom_span(
+            with custom_span(
                 name="temporal:startActivity", data={"activity": input.activity}
-            )
-            span.start(mark_as_current=True)
-
+            ):
+                pass  # Span completes immediately in same context
         set_header_from_context(input, temporalio.workflow.payload_converter())
-        handle = self.next.start_activity(input)
-        if span:
-            handle.add_done_callback(lambda _: span.finish())  # type: ignore
-        return handle
+        return self.next.start_activity(input)
 
     async def start_child_workflow(
         self, input: temporalio.worker.StartChildWorkflowInput
     ) -> temporalio.workflow.ChildWorkflowHandle:
+        # Use synchronous span pattern - see start_activity for explanation
         trace = get_trace_provider().get_current_trace()
-        span: Span | None = None
         if trace:
-            span = custom_span(
+            with custom_span(
                 name="temporal:startChildWorkflow", data={"workflow": input.workflow}
-            )
-            span.start(mark_as_current=True)
+            ):
+                pass  # Span completes immediately in same context
         set_header_from_context(input, temporalio.workflow.payload_converter())
-        handle = await self.next.start_child_workflow(input)
-        if span:
-            handle.add_done_callback(lambda _: span.finish())  # type: ignore
-        return handle
+        return await self.next.start_child_workflow(input)
 
     def start_local_activity(
         self, input: temporalio.worker.StartLocalActivityInput
     ) -> temporalio.workflow.ActivityHandle:
+        # Use synchronous span pattern - see start_activity for explanation
         trace = get_trace_provider().get_current_trace()
-        span: Span | None = None
         if trace:
-            span = custom_span(
+            with custom_span(
                 name="temporal:startLocalActivity", data={"activity": input.activity}
-            )
-            span.start(mark_as_current=True)
+            ):
+                pass  # Span completes immediately in same context
         set_header_from_context(input, temporalio.workflow.payload_converter())
-        handle = self.next.start_local_activity(input)
-        if span:
-            handle.add_done_callback(lambda _: span.finish())  # type: ignore
-        return handle
+        return self.next.start_local_activity(input)
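
The new comment in start_activity is the heart of this change: OpenTelemetry-style context tracking hands out a token when a span is attached, and that token can only be reset from the context that created it. Done-callbacks run in a copy of the registering context, so finishing the span there trips the token check. The stdlib-only sketch below (a hypothetical SketchSpan, not the SDK's span type) reproduces the failure mode with contextvars:

import contextvars

_current_span = contextvars.ContextVar("current_span")


class SketchSpan:
    """Toy span that attaches itself to the current context, like OTel attach/detach."""

    def __init__(self, name):
        self.name = name
        self._token = None

    def start(self):
        # set() returns a token tied to the context in which start() ran
        self._token = _current_span.set(self.name)

    def finish(self):
        # reset() rejects tokens created in a different context
        _current_span.reset(self._token)


span = SketchSpan("temporal:startActivity")
span.start()
try:
    # Done-callbacks are invoked in a copy of the context; simulate that
    contextvars.copy_context().run(span.finish)
except ValueError as err:
    print(f"callback-style finish fails: {err}")

# The synchronous pattern the diff adopts: start and finish in one context
span2 = SketchSpan("temporal:startActivity")
span2.start()
span2.finish()  # same context, so the token round-trips cleanly

asyncio futures behave like the copy_context().run call here, since add_done_callback captures a copied context at registration time, which is why the removed handle.add_done_callback(lambda _: span.finish()) could raise the same error inside the tracing provider's detach.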
139 changes: 73 additions & 66 deletions tests/worker/test_workflow.py
@@ -1992,77 +1992,84 @@ def last_signal(self) -> str:
 
 
 async def test_workflow_logging(client: Client):
+    original_full_workflow_info_on_extra = workflow.logger.full_workflow_info_on_extra
     workflow.logger.full_workflow_info_on_extra = True
-    with LogCapturer().logs_captured(
-        workflow.logger.base_logger, activity.logger.base_logger
-    ) as capturer:
-        # Log two signals and kill worker before completing. Need to disable
-        # workflow cache since we restart the worker and don't want to pay the
-        # sticky queue penalty.
-        async with new_worker(
-            client, LoggingWorkflow, max_cached_workflows=0
-        ) as worker:
-            handle = await client.start_workflow(
-                LoggingWorkflow.run,
-                id=f"workflow-{uuid.uuid4()}",
-                task_queue=worker.task_queue,
-            )
-            # Send some signals and updates
-            await handle.signal(LoggingWorkflow.my_signal, "signal 1")
-            await handle.signal(LoggingWorkflow.my_signal, "signal 2")
-            await handle.execute_update(
-                LoggingWorkflow.my_update, "update 1", id="update-1"
-            )
-            await handle.execute_update(
-                LoggingWorkflow.my_update, "update 2", id="update-2"
-            )
-            assert "signal 2" == await handle.query(LoggingWorkflow.last_signal)
-
-        # Confirm logs were produced
-        assert capturer.find_log("Signal: signal 1 ({'attempt':")
-        assert capturer.find_log("Signal: signal 2")
-        assert capturer.find_log("Update: update 1")
-        assert capturer.find_log("Update: update 2")
-        assert capturer.find_log("Query called")
-        assert not capturer.find_log("Signal: signal 3")
-        # Also make sure it has some workflow info and correct funcName
-        record = capturer.find_log("Signal: signal 1")
-        assert (
-            record
-            and record.__dict__["temporal_workflow"]["workflow_type"]
-            == "LoggingWorkflow"
-            and record.funcName == "my_signal"
-        )
-        # Since we enabled full info, make sure it's there
-        assert isinstance(record.__dict__["workflow_info"], workflow.Info)
-        # Check the log emitted by the update execution.
-        record = capturer.find_log("Update: update 1")
-        assert (
-            record
-            and record.__dict__["temporal_workflow"]["update_id"] == "update-1"
-            and record.__dict__["temporal_workflow"]["update_name"] == "my_update"
-            and "'update_id': 'update-1'" in record.message
-            and "'update_name': 'my_update'" in record.message
-        )
-
-        # Clear queue and start a new one with more signals
-        capturer.log_queue.queue.clear()
-        async with new_worker(
-            client,
-            LoggingWorkflow,
-            task_queue=worker.task_queue,
-            max_cached_workflows=0,
-        ) as worker:
-            # Send signals and updates
-            await handle.signal(LoggingWorkflow.my_signal, "signal 3")
-            await handle.signal(LoggingWorkflow.my_signal, "finish")
-            await handle.result()
-
-        # Confirm replayed logs are not present but new ones are
-        assert not capturer.find_log("Signal: signal 1")
-        assert not capturer.find_log("Signal: signal 2")
-        assert capturer.find_log("Signal: signal 3")
-        assert capturer.find_log("Signal: finish")
+    try:
+        with LogCapturer().logs_captured(
+            workflow.logger.base_logger, activity.logger.base_logger
+        ) as capturer:
+            # --- First execution: logs should appear ---
+            # Disable workflow cache so worker restart triggers replay
+            async with new_worker(
+                client, LoggingWorkflow, max_cached_workflows=0
+            ) as worker:
+                handle = await client.start_workflow(
+                    LoggingWorkflow.run,
+                    id=f"workflow-{uuid.uuid4()}",
+                    task_queue=worker.task_queue,
+                )
+                await handle.signal(LoggingWorkflow.my_signal, "signal 1")
+                await handle.signal(LoggingWorkflow.my_signal, "signal 2")
+                await handle.execute_update(
+                    LoggingWorkflow.my_update, "update 1", id="update-1"
+                )
+                await handle.execute_update(
+                    LoggingWorkflow.my_update, "update 2", id="update-2"
+                )
+                assert "signal 2" == await handle.query(LoggingWorkflow.last_signal)
+
+            # Verify logs from first execution
+            assert capturer.find_log("Signal: signal 1 ({'attempt':")
+            assert capturer.find_log("Signal: signal 2")
+            assert capturer.find_log("Update: update 1")
+            assert capturer.find_log("Update: update 2")
+            assert capturer.find_log("Query called")
+            assert not capturer.find_log("Signal: signal 3")
+
+            # Verify workflow info and funcName
+            record = capturer.find_log("Signal: signal 1")
+            assert (
+                record
+                and record.__dict__["temporal_workflow"]["workflow_type"]
+                == "LoggingWorkflow"
+                and record.funcName == "my_signal"
+            )
+
+            # Verify full_workflow_info_on_extra
+            assert isinstance(record.__dict__["workflow_info"], workflow.Info)
+
+            # Verify update logging
+            record = capturer.find_log("Update: update 1")
+            assert (
+                record
+                and record.__dict__["temporal_workflow"]["update_id"] == "update-1"
+                and record.__dict__["temporal_workflow"]["update_name"] == "my_update"
+                and "'update_id': 'update-1'" in record.message
+                and "'update_name': 'my_update'" in record.message
+            )
+
+            # --- Clear logs and continue execution (replay path) ---
+            # When the new worker starts, it replays the workflow history (signals 1 & 2).
+            # Replay suppression should prevent those logs from appearing again.
+            capturer.log_queue.queue.clear()
+
+            async with new_worker(
+                client,
+                LoggingWorkflow,
+                task_queue=worker.task_queue,
+                max_cached_workflows=0,
+            ) as worker:
+                await handle.signal(LoggingWorkflow.my_signal, "signal 3")
+                await handle.signal(LoggingWorkflow.my_signal, "finish")
+                await handle.result()
+
+            # --- Replay execution: no duplicate logs ---
+            assert not capturer.find_log("Signal: signal 1")
+            assert not capturer.find_log("Signal: signal 2")
+            assert capturer.find_log("Signal: signal 3")
+            assert capturer.find_log("Signal: finish")
+    finally:
+        workflow.logger.full_workflow_info_on_extra = original_full_workflow_info_on_extra
 
 
 @activity.defn
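
The replay assertions at the end of test_workflow_logging depend on the SDK's logger dropping records while history is being replayed, so signals 1 and 2 are not logged a second time when the new worker rebuilds workflow state. A minimal sketch of that idea (a hypothetical logging filter built on the public temporalio API, not the SDK's actual LoggerAdapter internals):

import logging

from temporalio import workflow


class ReplaySafeFilter(logging.Filter):
    """Drop log records emitted while a workflow is replaying history."""

    def filter(self, record):
        # Suppress only inside a workflow that is currently replaying
        if workflow.in_workflow() and workflow.unsafe.is_replaying():
            return False
        return True


# Usage sketch: an ordinary logger gains replay safety similar to workflow.logger
logger = logging.getLogger("replay_safe_example")
logger.addFilter(ReplaySafeFilter())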
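
Since the new try/finally exists only to save and restore a module-global flag, the same bookkeeping could be factored out if more tests need it. A sketch assuming pytest, which this suite already uses; the fixture name is hypothetical:

import pytest

from temporalio import workflow


@pytest.fixture
def full_workflow_info_on_extra():
    # Save the global logger flag, enable it for the test, restore on teardown
    original = workflow.logger.full_workflow_info_on_extra
    workflow.logger.full_workflow_info_on_extra = True
    try:
        yield
    finally:
        workflow.logger.full_workflow_info_on_extra = original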