Skip to content

Commit b7b1bf7

Browse files
committed
Realtime: run guardrails without blockign event loop
1 parent 741da67 commit b7b1bf7

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

src/agents/realtime/events.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ class RealtimeGuardrailTripped:
197197

198198
type: Literal["guardrail_tripped"] = "guardrail_tripped"
199199

200+
200201
RealtimeSessionEvent: TypeAlias = Union[
201202
RealtimeAgentStartEvent,
202203
RealtimeAgentEndEvent,

src/agents/realtime/session.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def __init__(
9393
"debounce_text_length", 100
9494
)
9595

96+
self._guardrail_tasks: set[asyncio.Task[Any]] = set()
97+
9698
async def __aenter__(self) -> RealtimeSession:
9799
"""Start the session by connecting to the model. After this, you will be able to stream
98100
events from the model and send messages and audio to the model.
@@ -136,6 +138,7 @@ async def __aiter__(self) -> AsyncIterator[RealtimeSessionEvent]:
136138
async def close(self) -> None:
137139
"""Close the session."""
138140
self._closed = True
141+
self._cleanup_guardrail_tasks()
139142
self._model.remove_listener(self)
140143
await self._model.close()
141144

@@ -185,7 +188,7 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
185188

186189
if current_length >= next_run_threshold:
187190
self._item_guardrail_run_counts[item_id] += 1
188-
await self._run_output_guardrails(self._item_transcripts[item_id])
191+
self._enqueue_guardrail_task(self._item_transcripts[item_id])
189192
elif event.type == "item_updated":
190193
is_new = not any(item.item_id == event.item.item_id for item in self._history)
191194
self._history = self._get_new_history(self._history, event.item)
@@ -366,3 +369,37 @@ async def _run_output_guardrails(self, text: str) -> bool:
366369
return True
367370

368371
return False
372+
373+
def _enqueue_guardrail_task(self, text: str) -> None:
374+
# Runs the guardrails in a separate task to avoid blocking the main loop
375+
376+
task = asyncio.create_task(self._run_output_guardrails(text))
377+
self._guardrail_tasks.add(task)
378+
379+
# Add callback to remove completed tasks and handle exceptions
380+
task.add_done_callback(self._on_guardrail_task_done)
381+
382+
def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
383+
"""Handle completion of a guardrail task."""
384+
# Remove from tracking set
385+
self._guardrail_tasks.discard(task)
386+
387+
# Check for exceptions and propagate as events
388+
if not task.cancelled():
389+
exception = task.exception()
390+
if exception:
391+
# Create an exception event instead of raising
392+
asyncio.create_task(
393+
self._put_event(
394+
RealtimeError(
395+
info=self._event_info,
396+
error={"message": f"Guardrail task failed: {str(exception)}"},
397+
)
398+
)
399+
)
400+
401+
def _cleanup_guardrail_tasks(self) -> None:
402+
for task in self._guardrail_tasks:
403+
if not task.done():
404+
task.cancel()
405+
self._guardrail_tasks.clear()

tests/realtime/test_session.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,12 @@ async def test_mixed_tool_types_filtering(self, mock_model, mock_agent):
971971
class TestGuardrailFunctionality:
972972
"""Test suite for output guardrail functionality in RealtimeSession"""
973973

974+
async def _wait_for_guardrail_tasks(self, session):
975+
"""Wait for all pending guardrail tasks to complete."""
976+
import asyncio
977+
if session._guardrail_tasks:
978+
await asyncio.gather(*session._guardrail_tasks, return_exceptions=True)
979+
974980
@pytest.fixture
975981
def triggered_guardrail(self):
976982
"""Creates a guardrail that always triggers"""
@@ -1009,6 +1015,9 @@ async def test_transcript_delta_triggers_guardrail_at_threshold(
10091015
)
10101016

10111017
await session.on_event(transcript_event)
1018+
1019+
# Wait for async guardrail tasks to complete
1020+
await self._wait_for_guardrail_tasks(session)
10121021

10131022
# Should have triggered guardrail and interrupted
10141023
assert session._interrupted_by_guardrail is True
@@ -1046,6 +1055,9 @@ async def test_transcript_delta_multiple_thresholds_same_item(
10461055
await session.on_event(RealtimeModelTranscriptDeltaEvent(
10471056
item_id="item_1", delta="67890", response_id="resp_1"
10481057
))
1058+
1059+
# Wait for async guardrail tasks to complete
1060+
await self._wait_for_guardrail_tasks(session)
10491061

10501062
# Should only trigger once due to interrupted_by_guardrail flag
10511063
assert mock_model.interrupts_called == 1
@@ -1104,6 +1116,9 @@ async def test_turn_ended_clears_guardrail_state(
11041116
await session.on_event(RealtimeModelTranscriptDeltaEvent(
11051117
item_id="item_1", delta="trigger", response_id="resp_1"
11061118
))
1119+
1120+
# Wait for async guardrail tasks to complete
1121+
await self._wait_for_guardrail_tasks(session)
11071122

11081123
assert session._interrupted_by_guardrail is True
11091124
assert len(session._item_transcripts) == 1
@@ -1142,6 +1157,9 @@ def guardrail_func(context, agent, output):
11421157
await session.on_event(RealtimeModelTranscriptDeltaEvent(
11431158
item_id="item_1", delta="trigger", response_id="resp_1"
11441159
))
1160+
1161+
# Wait for async guardrail tasks to complete
1162+
await self._wait_for_guardrail_tasks(session)
11451163

11461164
# Should have interrupted and sent message with both guardrail names
11471165
assert mock_model.interrupts_called == 1

0 commit comments

Comments
 (0)