8
8
from datetime import datetime
9
9
from typing import Any , Callable , Literal
10
10
11
+ import pydantic
11
12
import websockets
12
13
from openai .types .beta .realtime .conversation_item import ConversationItem
13
14
from openai .types .beta .realtime .realtime_server_event import (
14
15
RealtimeServerEvent as OpenAIRealtimeServerEvent ,
15
16
)
16
17
from openai .types .beta .realtime .response_audio_delta_event import ResponseAudioDeltaEvent
18
+ from openai .types .beta .realtime .session_update_event import (
19
+ Session as OpenAISessionObject ,
20
+ SessionTool as OpenAISessionTool ,
21
+ )
17
22
from pydantic import TypeAdapter
18
23
from typing_extensions import assert_never
19
24
from websockets .asyncio .client import ClientConnection
20
25
26
+ from agents .tool import FunctionTool , Tool
21
27
from agents .util ._types import MaybeAwaitable
22
28
23
29
from ..exceptions import UserError
56
62
RealtimeModelSendUserInput ,
57
63
)
58
64
65
+ DEFAULT_MODEL_SETTINGS : RealtimeSessionModelSettings = {
66
+ "voice" : "ash" ,
67
+ "modalities" : ["text" , "audio" ],
68
+ "input_audio_format" : "pcm16" ,
69
+ "output_audio_format" : "pcm16" ,
70
+ "input_audio_transcription" : {
71
+ "model" : "gpt-4o-mini-transcribe" ,
72
+ },
73
+ "turn_detection" : {"type" : "semantic_vad" },
74
+ }
75
+
59
76
60
77
async def get_api_key (key : str | Callable [[], MaybeAwaitable [str ]] | None ) -> str | None :
61
78
if isinstance (key , str ):
@@ -110,6 +127,7 @@ async def connect(self, options: RealtimeModelConfig) -> None:
110
127
}
111
128
self ._websocket = await websockets .connect (url , additional_headers = headers )
112
129
self ._websocket_task = asyncio .create_task (self ._listen_for_messages ())
130
+ await self ._update_session_config (model_settings )
113
131
114
132
async def _send_tracing_config (
115
133
self , tracing_config : RealtimeModelTracingConfig | Literal ["auto" ] | None
@@ -127,11 +145,13 @@ async def _send_tracing_config(
127
145
128
146
def add_listener (self , listener : RealtimeModelListener ) -> None :
129
147
"""Add a listener to the model."""
130
- self ._listeners .append (listener )
148
+ if listener not in self ._listeners :
149
+ self ._listeners .append (listener )
131
150
132
151
def remove_listener (self , listener : RealtimeModelListener ) -> None :
133
152
"""Remove a listener from the model."""
134
- self ._listeners .remove (listener )
153
+ if listener in self ._listeners :
154
+ self ._listeners .remove (listener )
135
155
136
156
async def _emit_event (self , event : RealtimeModelEvent ) -> None :
137
157
"""Emit an event to the listeners."""
@@ -195,78 +215,55 @@ async def _send_raw_message(self, event: RealtimeModelSendRawMessage) -> None:
195
215
"""Send a raw message to the model."""
196
216
assert self ._websocket is not None , "Not connected"
197
217
198
- try :
199
- converted_event = {
200
- "type" : event .message ["type" ],
201
- }
218
+ converted_event = {
219
+ "type" : event .message ["type" ],
220
+ }
202
221
203
- converted_event .update (event .message .get ("other_data" , {}))
222
+ converted_event .update (event .message .get ("other_data" , {}))
204
223
205
- await self ._websocket .send (json .dumps (converted_event ))
206
- except Exception as e :
207
- await self ._emit_event (
208
- RealtimeModelExceptionEvent (
209
- exception = e ,
210
- context = f"Failed to send event: { event .message .get ('type' , 'unknown' )} " ,
211
- )
212
- )
224
+ await self ._websocket .send (json .dumps (converted_event ))
213
225
214
226
async def _send_user_input (self , event : RealtimeModelSendUserInput ) -> None :
215
- """Send a user input to the model."""
216
- try :
217
- message = (
218
- event .user_input
219
- if isinstance (event .user_input , dict )
220
- else {
221
- "type" : "message" ,
222
- "role" : "user" ,
223
- "content" : [{"type" : "input_text" , "text" : event .user_input }],
224
- }
225
- )
226
- other_data = {
227
- "item" : message ,
227
+ message = (
228
+ event .user_input
229
+ if isinstance (event .user_input , dict )
230
+ else {
231
+ "type" : "message" ,
232
+ "role" : "user" ,
233
+ "content" : [{"type" : "input_text" , "text" : event .user_input }],
228
234
}
235
+ )
236
+ other_data = {
237
+ "item" : message ,
238
+ }
229
239
230
- await self ._send_raw_message (
231
- RealtimeModelSendRawMessage (
232
- message = {"type" : "conversation.item.create" , "other_data" : other_data }
233
- )
234
- )
235
- await self ._send_raw_message (
236
- RealtimeModelSendRawMessage (message = {"type" : "response.create" })
237
- )
238
- except Exception as e :
239
- await self ._emit_event (
240
- RealtimeModelExceptionEvent (exception = e , context = "Failed to send message" )
240
+ await self ._send_raw_message (
241
+ RealtimeModelSendRawMessage (
242
+ message = {"type" : "conversation.item.create" , "other_data" : other_data }
241
243
)
244
+ )
245
+ await self ._send_raw_message (
246
+ RealtimeModelSendRawMessage (message = {"type" : "response.create" })
247
+ )
242
248
243
249
async def _send_audio (self , event : RealtimeModelSendAudio ) -> None :
244
- """Send audio to the model."""
245
- assert self ._websocket is not None , "Not connected"
246
-
247
- try :
248
- base64_audio = base64 .b64encode (event .audio ).decode ("utf-8" )
249
- await self ._send_raw_message (
250
- RealtimeModelSendRawMessage (
251
- message = {
252
- "type" : "input_audio_buffer.append" ,
253
- "other_data" : {
254
- "audio" : base64_audio ,
255
- },
256
- }
257
- )
250
+ base64_audio = base64 .b64encode (event .audio ).decode ("utf-8" )
251
+ await self ._send_raw_message (
252
+ RealtimeModelSendRawMessage (
253
+ message = {
254
+ "type" : "input_audio_buffer.append" ,
255
+ "other_data" : {
256
+ "audio" : base64_audio ,
257
+ },
258
+ }
258
259
)
259
- if event .commit :
260
- await self ._send_raw_message (
261
- RealtimeModelSendRawMessage (message = {"type" : "input_audio_buffer.commit" })
262
- )
263
- except Exception as e :
264
- await self ._emit_event (
265
- RealtimeModelExceptionEvent (exception = e , context = "Failed to send audio" )
260
+ )
261
+ if event .commit :
262
+ await self ._send_raw_message (
263
+ RealtimeModelSendRawMessage (message = {"type" : "input_audio_buffer.commit" })
266
264
)
267
265
268
266
async def _send_tool_output (self , event : RealtimeModelSendToolOutput ) -> None :
269
- """Send tool output to the model."""
270
267
await self ._send_raw_message (
271
268
RealtimeModelSendRawMessage (
272
269
message = {
@@ -299,7 +296,6 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
299
296
)
300
297
301
298
async def _send_interrupt (self , event : RealtimeModelSendInterrupt ) -> None :
302
- """Send an interrupt to the model."""
303
299
if not self ._current_item_id or not self ._audio_start_time :
304
300
return
305
301
@@ -418,8 +414,17 @@ async def _handle_ws_event(self, event: dict[str, Any]):
418
414
parsed : OpenAIRealtimeServerEvent = TypeAdapter (
419
415
OpenAIRealtimeServerEvent
420
416
).validate_python (event )
417
+ except pydantic .ValidationError as e :
418
+ logger .error (f"Failed to validate server event: { event } " , exc_info = True )
419
+ await self ._emit_event (
420
+ RealtimeModelErrorEvent (
421
+ error = e ,
422
+ )
423
+ )
424
+ return
421
425
except Exception as e :
422
426
event_type = event .get ("type" , "unknown" ) if isinstance (event , dict ) else "unknown"
427
+ logger .error (f"Failed to validate server event: { event } " , exc_info = True )
423
428
await self ._emit_event (
424
429
RealtimeModelExceptionEvent (
425
430
exception = e ,
@@ -492,3 +497,66 @@ async def _handle_ws_event(self, event: dict[str, Any]):
492
497
or parsed .type == "response.output_item.done"
493
498
):
494
499
await self ._handle_output_item (parsed .item )
500
+
501
+ async def _update_session_config (self , model_settings : RealtimeSessionModelSettings ) -> None :
502
+ session_config = self ._get_session_config (model_settings )
503
+ await self ._send_raw_message (
504
+ RealtimeModelSendRawMessage (
505
+ message = {
506
+ "type" : "session.update" ,
507
+ "other_data" : {
508
+ "session" : session_config .model_dump (exclude_unset = True , exclude_none = True )
509
+ },
510
+ }
511
+ )
512
+ )
513
+
514
+ def _get_session_config (
515
+ self , model_settings : RealtimeSessionModelSettings
516
+ ) -> OpenAISessionObject :
517
+ """Get the session config."""
518
+ return OpenAISessionObject (
519
+ instructions = model_settings .get ("instructions" , None ),
520
+ model = (
521
+ model_settings .get ("model_name" , self .model ) # type: ignore
522
+ or DEFAULT_MODEL_SETTINGS .get ("model_name" )
523
+ ),
524
+ voice = model_settings .get ("voice" , DEFAULT_MODEL_SETTINGS .get ("voice" )),
525
+ modalities = model_settings .get ("modalities" , DEFAULT_MODEL_SETTINGS .get ("modalities" )),
526
+ input_audio_format = model_settings .get (
527
+ "input_audio_format" ,
528
+ DEFAULT_MODEL_SETTINGS .get ("input_audio_format" ), # type: ignore
529
+ ),
530
+ output_audio_format = model_settings .get (
531
+ "output_audio_format" ,
532
+ DEFAULT_MODEL_SETTINGS .get ("output_audio_format" ), # type: ignore
533
+ ),
534
+ input_audio_transcription = model_settings .get (
535
+ "input_audio_transcription" ,
536
+ DEFAULT_MODEL_SETTINGS .get ("input_audio_transcription" ), # type: ignore
537
+ ),
538
+ turn_detection = model_settings .get (
539
+ "turn_detection" ,
540
+ DEFAULT_MODEL_SETTINGS .get ("turn_detection" ), # type: ignore
541
+ ),
542
+ tool_choice = model_settings .get (
543
+ "tool_choice" ,
544
+ DEFAULT_MODEL_SETTINGS .get ("tool_choice" ), # type: ignore
545
+ ),
546
+ tools = self ._tools_to_session_tools (model_settings .get ("tools" , [])),
547
+ )
548
+
549
+ def _tools_to_session_tools (self , tools : list [Tool ]) -> list [OpenAISessionTool ]:
550
+ converted_tools : list [OpenAISessionTool ] = []
551
+ for tool in tools :
552
+ if not isinstance (tool , FunctionTool ):
553
+ raise UserError (f"Tool { tool .name } is unsupported. Must be a function tool." )
554
+ converted_tools .append (
555
+ OpenAISessionTool (
556
+ name = tool .name ,
557
+ description = tool .description ,
558
+ parameters = tool .params_json_schema ,
559
+ type = "function" ,
560
+ )
561
+ )
562
+ return converted_tools
0 commit comments