Skip to content

Commit a599876

Browse files
committed
Realtime: send session.update event at connection
1 parent 9364bb0 commit a599876

File tree

4 files changed

+175
-116
lines changed

4 files changed

+175
-116
lines changed

examples/realtime/demo.py

Lines changed: 34 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -58,55 +58,45 @@ async def run(self) -> None:
5858
self.session = session
5959
self.ui.set_is_connected(True)
6060
async for event in session:
61-
await self.on_event(event)
61+
await self._on_event(event)
62+
print("done")
6263

6364
# Wait for UI task to complete when session ends
6465
await ui_task
6566

6667
async def on_audio_recorded(self, audio_bytes: bytes) -> None:
67-
"""Called when audio is recorded by the UI."""
68-
try:
69-
# Send the audio to the session
70-
assert self.session is not None
71-
await self.session.send_audio(audio_bytes)
72-
except Exception as e:
73-
self.ui.log_message(f"Error sending audio: {e}")
74-
75-
async def on_event(self, event: RealtimeSessionEvent) -> None:
76-
# Display event in the UI
77-
try:
78-
if event.type == "agent_start":
79-
self.ui.add_transcript(f"Agent started: {event.agent.name}")
80-
elif event.type == "agent_end":
81-
self.ui.add_transcript(f"Agent ended: {event.agent.name}")
82-
elif event.type == "handoff":
83-
self.ui.add_transcript(
84-
f"Handoff from {event.from_agent.name} to {event.to_agent.name}"
85-
)
86-
elif event.type == "tool_start":
87-
self.ui.add_transcript(f"Tool started: {event.tool.name}")
88-
elif event.type == "tool_end":
89-
self.ui.add_transcript(f"Tool ended: {event.tool.name}; output: {event.output}")
90-
elif event.type == "audio_end":
91-
self.ui.add_transcript("Audio ended")
92-
elif event.type == "audio":
93-
np_audio = np.frombuffer(event.audio.data, dtype=np.int16)
94-
self.ui.play_audio(np_audio)
95-
elif event.type == "audio_interrupted":
96-
self.ui.add_transcript("Audio interrupted")
97-
elif event.type == "error":
98-
self.ui.add_transcript(f"Error: {event.error}")
99-
elif event.type == "history_updated":
100-
pass
101-
elif event.type == "history_added":
102-
pass
103-
elif event.type == "raw_model_event":
104-
self.ui.log_message(f"Raw model event: {event.data}")
105-
else:
106-
self.ui.log_message(f"Unknown event type: {event.type}")
107-
except Exception as e:
108-
# This can happen if the UI has already exited
109-
self.ui.log_message(f"Event handling error: {str(e)}")
68+
# Send the audio to the session
69+
assert self.session is not None
70+
await self.session.send_audio(audio_bytes)
71+
72+
async def _on_event(self, event: RealtimeSessionEvent) -> None:
73+
if event.type == "agent_start":
74+
self.ui.add_transcript(f"Agent started: {event.agent.name}")
75+
elif event.type == "agent_end":
76+
self.ui.add_transcript(f"Agent ended: {event.agent.name}")
77+
elif event.type == "handoff":
78+
self.ui.add_transcript(f"Handoff from {event.from_agent.name} to {event.to_agent.name}")
79+
elif event.type == "tool_start":
80+
self.ui.add_transcript(f"Tool started: {event.tool.name}")
81+
elif event.type == "tool_end":
82+
self.ui.add_transcript(f"Tool ended: {event.tool.name}; output: {event.output}")
83+
elif event.type == "audio_end":
84+
self.ui.add_transcript("Audio ended")
85+
elif event.type == "audio":
86+
np_audio = np.frombuffer(event.audio.data, dtype=np.int16)
87+
self.ui.play_audio(np_audio)
88+
elif event.type == "audio_interrupted":
89+
self.ui.add_transcript("Audio interrupted")
90+
elif event.type == "error":
91+
self.ui.add_transcript(f"Error: {event.error}")
92+
elif event.type == "history_updated":
93+
pass
94+
elif event.type == "history_added":
95+
pass
96+
elif event.type == "raw_model_event":
97+
self.ui.log_message(f"Raw model event: {event.data}")
98+
else:
99+
self.ui.log_message(f"Unknown event type: {event.type}")
110100

111101

112102
if __name__ == "__main__":

examples/realtime/ui.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,7 @@ async def capture_audio(self) -> None:
239239

240240
# Call audio callback if set
241241
if self.audio_callback:
242-
try:
243-
await self.audio_callback(audio_bytes)
244-
except Exception as e:
245-
self.log_message(f"Audio callback error: {e}")
242+
await self.audio_callback(audio_bytes)
246243

247244
# Yield control back to event loop
248245
await asyncio.sleep(0)

src/agents/realtime/openai_realtime.py

Lines changed: 130 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,22 @@
88
from datetime import datetime
99
from typing import Any, Callable, Literal
1010

11+
import pydantic
1112
import websockets
1213
from openai.types.beta.realtime.conversation_item import ConversationItem
1314
from openai.types.beta.realtime.realtime_server_event import (
1415
RealtimeServerEvent as OpenAIRealtimeServerEvent,
1516
)
1617
from openai.types.beta.realtime.response_audio_delta_event import ResponseAudioDeltaEvent
18+
from openai.types.beta.realtime.session_update_event import (
19+
Session as OpenAISessionObject,
20+
SessionTool as OpenAISessionTool,
21+
)
1722
from pydantic import TypeAdapter
1823
from typing_extensions import assert_never
1924
from websockets.asyncio.client import ClientConnection
2025

26+
from agents.tool import FunctionTool, Tool
2127
from agents.util._types import MaybeAwaitable
2228

2329
from ..exceptions import UserError
@@ -56,6 +62,17 @@
5662
RealtimeModelSendUserInput,
5763
)
5864

65+
DEFAULT_MODEL_SETTINGS: RealtimeSessionModelSettings = {
66+
"voice": "ash",
67+
"modalities": ["text", "audio"],
68+
"input_audio_format": "pcm16",
69+
"output_audio_format": "pcm16",
70+
"input_audio_transcription": {
71+
"model": "gpt-4o-mini-transcribe",
72+
},
73+
"turn_detection": {"type": "semantic_vad"},
74+
}
75+
5976

6077
async def get_api_key(key: str | Callable[[], MaybeAwaitable[str]] | None) -> str | None:
6178
if isinstance(key, str):
@@ -110,6 +127,7 @@ async def connect(self, options: RealtimeModelConfig) -> None:
110127
}
111128
self._websocket = await websockets.connect(url, additional_headers=headers)
112129
self._websocket_task = asyncio.create_task(self._listen_for_messages())
130+
await self._update_session_config(model_settings)
113131

114132
async def _send_tracing_config(
115133
self, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None
@@ -127,11 +145,13 @@ async def _send_tracing_config(
127145

128146
def add_listener(self, listener: RealtimeModelListener) -> None:
129147
"""Add a listener to the model."""
130-
self._listeners.append(listener)
148+
if listener not in self._listeners:
149+
self._listeners.append(listener)
131150

132151
def remove_listener(self, listener: RealtimeModelListener) -> None:
133152
"""Remove a listener from the model."""
134-
self._listeners.remove(listener)
153+
if listener in self._listeners:
154+
self._listeners.remove(listener)
135155

136156
async def _emit_event(self, event: RealtimeModelEvent) -> None:
137157
"""Emit an event to the listeners."""
@@ -195,78 +215,55 @@ async def _send_raw_message(self, event: RealtimeModelSendRawMessage) -> None:
195215
"""Send a raw message to the model."""
196216
assert self._websocket is not None, "Not connected"
197217

198-
try:
199-
converted_event = {
200-
"type": event.message["type"],
201-
}
218+
converted_event = {
219+
"type": event.message["type"],
220+
}
202221

203-
converted_event.update(event.message.get("other_data", {}))
222+
converted_event.update(event.message.get("other_data", {}))
204223

205-
await self._websocket.send(json.dumps(converted_event))
206-
except Exception as e:
207-
await self._emit_event(
208-
RealtimeModelExceptionEvent(
209-
exception=e,
210-
context=f"Failed to send event: {event.message.get('type', 'unknown')}",
211-
)
212-
)
224+
await self._websocket.send(json.dumps(converted_event))
213225

214226
async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None:
215-
"""Send a user input to the model."""
216-
try:
217-
message = (
218-
event.user_input
219-
if isinstance(event.user_input, dict)
220-
else {
221-
"type": "message",
222-
"role": "user",
223-
"content": [{"type": "input_text", "text": event.user_input}],
224-
}
225-
)
226-
other_data = {
227-
"item": message,
227+
message = (
228+
event.user_input
229+
if isinstance(event.user_input, dict)
230+
else {
231+
"type": "message",
232+
"role": "user",
233+
"content": [{"type": "input_text", "text": event.user_input}],
228234
}
235+
)
236+
other_data = {
237+
"item": message,
238+
}
229239

230-
await self._send_raw_message(
231-
RealtimeModelSendRawMessage(
232-
message={"type": "conversation.item.create", "other_data": other_data}
233-
)
234-
)
235-
await self._send_raw_message(
236-
RealtimeModelSendRawMessage(message={"type": "response.create"})
237-
)
238-
except Exception as e:
239-
await self._emit_event(
240-
RealtimeModelExceptionEvent(exception=e, context="Failed to send message")
240+
await self._send_raw_message(
241+
RealtimeModelSendRawMessage(
242+
message={"type": "conversation.item.create", "other_data": other_data}
241243
)
244+
)
245+
await self._send_raw_message(
246+
RealtimeModelSendRawMessage(message={"type": "response.create"})
247+
)
242248

243249
async def _send_audio(self, event: RealtimeModelSendAudio) -> None:
244-
"""Send audio to the model."""
245-
assert self._websocket is not None, "Not connected"
246-
247-
try:
248-
base64_audio = base64.b64encode(event.audio).decode("utf-8")
249-
await self._send_raw_message(
250-
RealtimeModelSendRawMessage(
251-
message={
252-
"type": "input_audio_buffer.append",
253-
"other_data": {
254-
"audio": base64_audio,
255-
},
256-
}
257-
)
250+
base64_audio = base64.b64encode(event.audio).decode("utf-8")
251+
await self._send_raw_message(
252+
RealtimeModelSendRawMessage(
253+
message={
254+
"type": "input_audio_buffer.append",
255+
"other_data": {
256+
"audio": base64_audio,
257+
},
258+
}
258259
)
259-
if event.commit:
260-
await self._send_raw_message(
261-
RealtimeModelSendRawMessage(message={"type": "input_audio_buffer.commit"})
262-
)
263-
except Exception as e:
264-
await self._emit_event(
265-
RealtimeModelExceptionEvent(exception=e, context="Failed to send audio")
260+
)
261+
if event.commit:
262+
await self._send_raw_message(
263+
RealtimeModelSendRawMessage(message={"type": "input_audio_buffer.commit"})
266264
)
267265

268266
async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
269-
"""Send tool output to the model."""
270267
await self._send_raw_message(
271268
RealtimeModelSendRawMessage(
272269
message={
@@ -299,7 +296,6 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
299296
)
300297

301298
async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
302-
"""Send an interrupt to the model."""
303299
if not self._current_item_id or not self._audio_start_time:
304300
return
305301

@@ -418,8 +414,17 @@ async def _handle_ws_event(self, event: dict[str, Any]):
418414
parsed: OpenAIRealtimeServerEvent = TypeAdapter(
419415
OpenAIRealtimeServerEvent
420416
).validate_python(event)
417+
except pydantic.ValidationError as e:
418+
logger.error(f"Failed to validate server event: {event}", exc_info=True)
419+
await self._emit_event(
420+
RealtimeModelErrorEvent(
421+
error=e,
422+
)
423+
)
424+
return
421425
except Exception as e:
422426
event_type = event.get("type", "unknown") if isinstance(event, dict) else "unknown"
427+
logger.error(f"Failed to validate server event: {event}", exc_info=True)
423428
await self._emit_event(
424429
RealtimeModelExceptionEvent(
425430
exception=e,
@@ -492,3 +497,66 @@ async def _handle_ws_event(self, event: dict[str, Any]):
492497
or parsed.type == "response.output_item.done"
493498
):
494499
await self._handle_output_item(parsed.item)
500+
501+
async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None:
502+
session_config = self._get_session_config(model_settings)
503+
await self._send_raw_message(
504+
RealtimeModelSendRawMessage(
505+
message={
506+
"type": "session.update",
507+
"other_data": {
508+
"session": session_config.model_dump(exclude_unset=True, exclude_none=True)
509+
},
510+
}
511+
)
512+
)
513+
514+
def _get_session_config(
515+
self, model_settings: RealtimeSessionModelSettings
516+
) -> OpenAISessionObject:
517+
"""Get the session config."""
518+
return OpenAISessionObject(
519+
instructions=model_settings.get("instructions", None),
520+
model=(
521+
model_settings.get("model_name", self.model) # type: ignore
522+
or DEFAULT_MODEL_SETTINGS.get("model_name")
523+
),
524+
voice=model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice")),
525+
modalities=model_settings.get("modalities", DEFAULT_MODEL_SETTINGS.get("modalities")),
526+
input_audio_format=model_settings.get(
527+
"input_audio_format",
528+
DEFAULT_MODEL_SETTINGS.get("input_audio_format"), # type: ignore
529+
),
530+
output_audio_format=model_settings.get(
531+
"output_audio_format",
532+
DEFAULT_MODEL_SETTINGS.get("output_audio_format"), # type: ignore
533+
),
534+
input_audio_transcription=model_settings.get(
535+
"input_audio_transcription",
536+
DEFAULT_MODEL_SETTINGS.get("input_audio_transcription"), # type: ignore
537+
),
538+
turn_detection=model_settings.get(
539+
"turn_detection",
540+
DEFAULT_MODEL_SETTINGS.get("turn_detection"), # type: ignore
541+
),
542+
tool_choice=model_settings.get(
543+
"tool_choice",
544+
DEFAULT_MODEL_SETTINGS.get("tool_choice"), # type: ignore
545+
),
546+
tools=self._tools_to_session_tools(model_settings.get("tools", [])),
547+
)
548+
549+
def _tools_to_session_tools(self, tools: list[Tool]) -> list[OpenAISessionTool]:
550+
converted_tools: list[OpenAISessionTool] = []
551+
for tool in tools:
552+
if not isinstance(tool, FunctionTool):
553+
raise UserError(f"Tool {tool.name} is unsupported. Must be a function tool.")
554+
converted_tools.append(
555+
OpenAISessionTool(
556+
name=tool.name,
557+
description=tool.description,
558+
parameters=tool.params_json_schema,
559+
type="function",
560+
)
561+
)
562+
return converted_tools

0 commit comments

Comments
 (0)