Skip to content

Realtime: run guardrails without blockign event loop #1104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/agents/realtime/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class RealtimeGuardrailTripped:

type: Literal["guardrail_tripped"] = "guardrail_tripped"


RealtimeSessionEvent: TypeAlias = Union[
RealtimeAgentStartEvent,
RealtimeAgentEndEvent,
Expand Down
39 changes: 38 additions & 1 deletion src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def __init__(
"debounce_text_length", 100
)

self._guardrail_tasks: set[asyncio.Task[Any]] = set()

async def __aenter__(self) -> RealtimeSession:
"""Start the session by connecting to the model. After this, you will be able to stream
events from the model and send messages and audio to the model.
Expand Down Expand Up @@ -136,6 +138,7 @@ async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]:
async def close(self) -> None:
"""Close the session."""
self._closed = True
self._cleanup_guardrail_tasks()
self._model.remove_listener(self)
await self._model.close()

Expand Down Expand Up @@ -185,7 +188,7 @@ async def on_event(self, event: RealtimeModelEvent) -> None:

if current_length >= next_run_threshold:
self._item_guardrail_run_counts[item_id] += 1
await self._run_output_guardrails(self._item_transcripts[item_id])
self._enqueue_guardrail_task(self._item_transcripts[item_id])
elif event.type == "item_updated":
is_new = not any(item.item_id == event.item.item_id for item in self._history)
self._history = self._get_new_history(self._history, event.item)
Expand Down Expand Up @@ -366,3 +369,37 @@ async def _run_output_guardrails(self, text: str) -> bool:
return True

return False

def _enqueue_guardrail_task(self, text: str) -> None:
# Runs the guardrails in a separate task to avoid blocking the main loop

task = asyncio.create_task(self._run_output_guardrails(text))
self._guardrail_tasks.add(task)

# Add callback to remove completed tasks and handle exceptions
task.add_done_callback(self._on_guardrail_task_done)

def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
"""Handle completion of a guardrail task."""
# Remove from tracking set
self._guardrail_tasks.discard(task)

# Check for exceptions and propagate as events
if not task.cancelled():
exception = task.exception()
if exception:
# Create an exception event instead of raising
asyncio.create_task(
self._put_event(
RealtimeError(
info=self._event_info,
error={"message": f"Guardrail task failed: {str(exception)}"},
)
)
)

def _cleanup_guardrail_tasks(self) -> None:
for task in self._guardrail_tasks:
if not task.done():
task.cancel()
self._guardrail_tasks.clear()
18 changes: 18 additions & 0 deletions tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,12 @@ async def test_mixed_tool_types_filtering(self, mock_model, mock_agent):
class TestGuardrailFunctionality:
"""Test suite for output guardrail functionality in RealtimeSession"""

async def _wait_for_guardrail_tasks(self, session):
"""Wait for all pending guardrail tasks to complete."""
import asyncio
if session._guardrail_tasks:
await asyncio.gather(*session._guardrail_tasks, return_exceptions=True)

@pytest.fixture
def triggered_guardrail(self):
"""Creates a guardrail that always triggers"""
Expand Down Expand Up @@ -1010,6 +1016,9 @@ async def test_transcript_delta_triggers_guardrail_at_threshold(

await session.on_event(transcript_event)

# Wait for async guardrail tasks to complete
await self._wait_for_guardrail_tasks(session)

# Should have triggered guardrail and interrupted
assert session._interrupted_by_guardrail is True
assert mock_model.interrupts_called == 1
Expand Down Expand Up @@ -1047,6 +1056,9 @@ async def test_transcript_delta_multiple_thresholds_same_item(
item_id="item_1", delta="67890", response_id="resp_1"
))

# Wait for async guardrail tasks to complete
await self._wait_for_guardrail_tasks(session)

# Should only trigger once due to interrupted_by_guardrail flag
assert mock_model.interrupts_called == 1
assert len(mock_model.sent_messages) == 1
Expand Down Expand Up @@ -1105,6 +1117,9 @@ async def test_turn_ended_clears_guardrail_state(
item_id="item_1", delta="trigger", response_id="resp_1"
))

# Wait for async guardrail tasks to complete
await self._wait_for_guardrail_tasks(session)

assert session._interrupted_by_guardrail is True
assert len(session._item_transcripts) == 1

Expand Down Expand Up @@ -1143,6 +1158,9 @@ def guardrail_func(context, agent, output):
item_id="item_1", delta="trigger", response_id="resp_1"
))

# Wait for async guardrail tasks to complete
await self._wait_for_guardrail_tasks(session)

# Should have interrupted and sent message with both guardrail names
assert mock_model.interrupts_called == 1
assert len(mock_model.sent_messages) == 1
Expand Down