Skip to content

Commit 27fb0fc

Browse files
committed
fix: #1942 Enable async tool calling in Realtime sessions
1 parent 4bc33e3 commit 27fb0fc

File tree

5 files changed

+97
-8
lines changed

5 files changed

+97
-8
lines changed

examples/realtime/app/agent.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
from agents import function_tool
24
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
35
from agents.realtime import RealtimeAgent, realtime_handoff
@@ -13,20 +15,26 @@
1315
name_override="faq_lookup_tool", description_override="Lookup frequently asked questions."
1416
)
1517
async def faq_lookup_tool(question: str) -> str:
16-
if "bag" in question or "baggage" in question:
18+
print("faq_lookup_tool called with question:", question)
19+
20+
# Simulate a slow API call
21+
await asyncio.sleep(3)
22+
23+
q = question.lower()
24+
if "wifi" in q or "wi-fi" in q:
25+
return "We have free wifi on the plane, join Airline-Wifi"
26+
elif "bag" in q or "baggage" in q:
1727
return (
1828
"You are allowed to bring one bag on the plane. "
1929
"It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
2030
)
21-
elif "seats" in question or "plane" in question:
31+
elif "seats" in q or "plane" in q:
2232
return (
2333
"There are 120 seats on the plane. "
2434
"There are 22 business class seats and 98 economy seats. "
2535
"Exit rows are rows 4 and 16. "
2636
"Rows 5-8 are Economy Plus, with extra legroom. "
2737
)
28-
elif "wifi" in question:
29-
return "We have free wifi on the plane, join Airline-Wifi"
3038
return "I'm sorry, I don't know the answer to that question."
3139

3240

examples/realtime/app/server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing_extensions import assert_never
1313

1414
from agents.realtime import RealtimeRunner, RealtimeSession, RealtimeSessionEvent
15-
from agents.realtime.config import RealtimeUserInputMessage
15+
from agents.realtime.config import RealtimeRunConfig, RealtimeUserInputMessage
1616
from agents.realtime.items import RealtimeItem
1717
from agents.realtime.model import RealtimeModelConfig
1818
from agents.realtime.model_inputs import RealtimeModelSendRawMessage
@@ -47,6 +47,9 @@ async def connect(self, websocket: WebSocket, session_id: str):
4747

4848
agent = get_starting_agent()
4949
runner = RealtimeRunner(agent)
50+
# If you want to customize the runner behavior, you can pass options:
51+
# runner_config = RealtimeRunConfig(async_tool_calls=False)
52+
# runner = RealtimeRunner(agent, config=runner_config)
5053
model_config: RealtimeModelConfig = {
5154
"initial_model_settings": {
5255
"turn_detection": {

src/agents/realtime/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ class RealtimeRunConfig(TypedDict):
184184
tracing_disabled: NotRequired[bool]
185185
"""Whether tracing is disabled for this run."""
186186

187+
async_tool_calls: NotRequired[bool]
188+
"""Whether function tool calls should run asynchronously. Defaults to True."""
189+
187190
# TODO (rm) Add history audio storage config
188191

189192

src/agents/realtime/session.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def __init__(
123123
)
124124

125125
self._guardrail_tasks: set[asyncio.Task[Any]] = set()
126+
self._tool_call_tasks: set[asyncio.Task[Any]] = set()
127+
self._async_tool_calls: bool = bool(self._run_config.get("async_tool_calls", True))
126128

127129
@property
128130
def model(self) -> RealtimeModel:
@@ -216,7 +218,10 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
216218
if event.type == "error":
217219
await self._put_event(RealtimeError(info=self._event_info, error=event.error))
218220
elif event.type == "function_call":
219-
await self._handle_tool_call(event)
221+
if self._async_tool_calls:
222+
self._enqueue_tool_call_task(event)
223+
else:
224+
await self._handle_tool_call(event)
220225
elif event.type == "audio":
221226
await self._put_event(
222227
RealtimeAudio(
@@ -752,10 +757,47 @@ def _cleanup_guardrail_tasks(self) -> None:
752757
task.cancel()
753758
self._guardrail_tasks.clear()
754759

760+
def _enqueue_tool_call_task(self, event: RealtimeModelToolCallEvent) -> None:
761+
"""Run tool calls in the background to avoid blocking realtime transport."""
762+
task = asyncio.create_task(self._handle_tool_call(event))
763+
self._tool_call_tasks.add(task)
764+
task.add_done_callback(self._on_tool_call_task_done)
765+
766+
def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
767+
self._tool_call_tasks.discard(task)
768+
769+
if task.cancelled():
770+
return
771+
772+
exception = task.exception()
773+
if exception is None:
774+
return
775+
776+
logger.exception("Realtime tool call task failed", exc_info=exception)
777+
778+
if self._stored_exception is None:
779+
self._stored_exception = exception
780+
781+
asyncio.create_task(
782+
self._put_event(
783+
RealtimeError(
784+
info=self._event_info,
785+
error={"message": f"Tool call task failed: {exception}"},
786+
)
787+
)
788+
)
789+
790+
def _cleanup_tool_call_tasks(self) -> None:
791+
for task in self._tool_call_tasks:
792+
if not task.done():
793+
task.cancel()
794+
self._tool_call_tasks.clear()
795+
755796
async def _cleanup(self) -> None:
756797
"""Clean up all resources and mark session as closed."""
757798
# Cancel and cleanup guardrail tasks
758799
self._cleanup_guardrail_tasks()
800+
self._cleanup_tool_call_tasks()
759801

760802
# Remove ourselves as a listener
761803
self._model.remove_listener(self)

tests/realtime/test_session.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,8 +561,13 @@ async def test_ignored_events_only_generate_raw_events(self, mock_model, mock_ag
561561

562562
@pytest.mark.asyncio
563563
async def test_function_call_event_triggers_tool_handling(self, mock_model, mock_agent):
564-
"""Test that function_call events trigger tool call handling"""
565-
session = RealtimeSession(mock_model, mock_agent, None)
564+
"""Test that function_call events trigger tool call handling synchronously when disabled"""
565+
session = RealtimeSession(
566+
mock_model,
567+
mock_agent,
568+
None,
569+
run_config={"async_tool_calls": False},
570+
)
566571

567572
# Create function call event
568573
function_call_event = RealtimeModelToolCallEvent(
@@ -586,6 +591,34 @@ async def test_function_call_event_triggers_tool_handling(self, mock_model, mock
586591
assert isinstance(raw_event, RealtimeRawModelEvent)
587592
assert raw_event.data == function_call_event
588593

594+
@pytest.mark.asyncio
595+
async def test_function_call_event_runs_async_by_default(self, mock_model, mock_agent):
596+
"""Function call handling should be scheduled asynchronously by default"""
597+
session = RealtimeSession(mock_model, mock_agent, None)
598+
599+
function_call_event = RealtimeModelToolCallEvent(
600+
name="test_function",
601+
call_id="call_async",
602+
arguments='{"param": "value"}',
603+
)
604+
605+
with pytest.MonkeyPatch().context() as m:
606+
handle_tool_call_mock = AsyncMock()
607+
m.setattr(session, "_handle_tool_call", handle_tool_call_mock)
608+
609+
await session.on_event(function_call_event)
610+
611+
# Let the background task run
612+
await asyncio.sleep(0)
613+
614+
handle_tool_call_mock.assert_awaited_once_with(function_call_event)
615+
616+
# Raw event still enqueued
617+
assert session._event_queue.qsize() == 1
618+
raw_event = await session._event_queue.get()
619+
assert isinstance(raw_event, RealtimeRawModelEvent)
620+
assert raw_event.data == function_call_event
621+
589622

590623
class TestHistoryManagement:
591624
"""Test suite for history management and audio transcription in

0 commit comments

Comments
 (0)