1616from ....types ._events import ToolResultEvent , ToolUseStreamEvent
1717from ....types .content import Messages
1818from ....types .tools import ToolResult , ToolSpec , ToolUse
19+ from .._async import stop_all
1920from ..types .events import (
2021 BidiAudioInputEvent ,
2122 BidiAudioStreamEvent ,
22- BidiConnectionCloseEvent ,
2323 BidiConnectionStartEvent ,
24- BidiErrorEvent ,
25- BidiImageInputEvent ,
2624 BidiInputEvent ,
2725 BidiInterruptionEvent ,
2826 BidiOutputEvent ,
@@ -70,6 +68,8 @@ class BidiOpenAIRealtimeModel(BidiModel):
7068 function calling, and event conversion to Strands format.
7169 """
7270
71+ _websocket : ClientConnection
72+
7373 def __init__ (
7474 self ,
7575 model : str = DEFAULT_MODEL ,
@@ -104,9 +104,7 @@ def __init__(
104104 )
105105
106106 # Connection state (initialized in start())
107- self .websocket : ClientConnection
108- self .connection_id : str
109- self ._active : bool = False
107+ self ._connection_id : str | None = None
110108
111109 self ._function_call_buffer : dict [str , Any ] = {}
112110
@@ -127,45 +125,35 @@ async def start(
127125 messages: Conversation history to initialize with.
128126 **kwargs: Additional configuration options.
129127 """
130- if self ._active :
131- raise RuntimeError ("Connection already active. Close the existing connection before creating a new one. " )
128+ if self ._connection_id :
129+ raise RuntimeError ("model already started | call stop before starting again " )
132130
133131 logger .info ("openai realtime connection starting" )
134132
135- try :
136- # Initialize connection state
137- self .connection_id = str (uuid .uuid4 ())
138- self ._active = True
139- self ._function_call_buffer = {}
140-
141- # Establish WebSocket connection
142- url = f"{ OPENAI_REALTIME_URL } ?model={ self .model } "
133+ # Initialize connection state
134+ self ._connection_id = str (uuid .uuid4 ())
143135
144- headers = [("Authorization" , f"Bearer { self .api_key } " )]
145- if self .organization :
146- headers .append (("OpenAI-Organization" , self .organization ))
147- if self .project :
148- headers .append (("OpenAI-Project" , self .project ))
136+ self ._function_call_buffer = {}
149137
150- self . websocket = await websockets . connect ( url , additional_headers = headers )
151- logger . info ( "connection_id=<%s> | websocket connected successfully" , self .connection_id )
138+ # Establish WebSocket connection
139+ url = f" { OPENAI_REALTIME_URL } ?model= { self .model } "
152140
153- # Configure session
154- session_config = self ._build_session_config (system_prompt , tools )
155- await self ._send_event ({"type" : "session.update" , "session" : session_config })
141+ headers = [("Authorization" , f"Bearer { self .api_key } " )]
142+ if self .organization :
143+ headers .append (("OpenAI-Organization" , self .organization ))
144+ if self .project :
145+ headers .append (("OpenAI-Project" , self .project ))
156146
157- # Add conversation history if provided
158- if messages :
159- await self ._add_conversation_history (messages )
147+ self ._websocket = await websockets .connect (url , additional_headers = headers )
148+ logger .info ("connection_id=<%s> | websocket connected successfully" , self ._connection_id )
160149
161- except Exception as e :
162- self ._active = False
163- logger .error ("error=<%s> | openai connection failed" , e )
164- raise
150+ # Configure session
151+ session_config = self ._build_session_config (system_prompt , tools )
152+ await self ._send_event ({"type" : "session.update" , "session" : session_config })
165153
166- def _require_active ( self ) -> bool :
167- """Check if session is active."""
168- return self ._active
154+ # Add conversation history if provided
155+ if messages :
156+ await self ._add_conversation_history ( messages )
169157
170158 def _create_text_event (self , text : str , role : str , is_final : bool = True ) -> BidiTranscriptStreamEvent :
171159 """Create standardized transcript event.
@@ -275,27 +263,16 @@ async def _add_conversation_history(self, messages: Messages) -> None:
275263
276264 async def receive (self ) -> AsyncIterable [BidiOutputEvent ]: # type: ignore
277265 """Receive OpenAI events and convert to Strands TypedEvent format."""
278- # Emit connection start event
279- yield BidiConnectionStartEvent (connection_id = self .connection_id , model = self .model )
280-
281- try :
282- while self ._active :
283- async for message in self .websocket :
284- if not self ._active :
285- break # type: ignore
266+ if not self ._connection_id :
267+ raise RuntimeError ("model not started | call start before receiving" )
286268
287- openai_event = json . loads ( message )
269+ yield BidiConnectionStartEvent ( connection_id = self . _connection_id , model = self . model )
288270
289- for event in self ._convert_openai_event ( openai_event ) or [] :
290- yield event
271+ async for message in self ._websocket :
272+ openai_event = json . loads ( message )
291273
292- except Exception as e :
293- logger .error ("error=<%s> | error receiving openai realtime event" , e )
294- yield BidiErrorEvent (error = e )
295- finally :
296- # Emit connection close event
297- yield BidiConnectionCloseEvent (connection_id = self .connection_id , reason = "complete" )
298- self ._active = False
274+ for event in self ._convert_openai_event (openai_event ) or []:
275+ yield event
299276
300277 def _convert_openai_event (self , openai_event : dict [str , Any ]) -> list [BidiOutputEvent ] | None :
301278 """Convert OpenAI events to Strands TypedEvent format."""
@@ -557,26 +534,24 @@ async def send(
557534
558535 Args:
559536 content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent).
537+
538+ Raises:
539+ ValueError: If content type not supported (e.g., image content).
560540 """
561- if not self ._require_active ():
562- return
563-
564- try :
565- # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first
566- if isinstance (content , BidiTextInputEvent ):
567- await self ._send_text_content (content .text )
568- elif isinstance (content , BidiAudioInputEvent ):
569- await self ._send_audio_content (content )
570- elif isinstance (content , BidiImageInputEvent ):
571- # BidiImageInputEvent - not supported by OpenAI Realtime yet
572- logger .warning ("Image input not supported by OpenAI Realtime API" )
573- elif isinstance (content , ToolResultEvent ):
574- tool_result = content .get ("tool_result" )
575- if tool_result :
576- await self ._send_tool_result (tool_result )
577- except Exception as e :
578- logger .error ("error=<%s> | error sending content to openai" , e )
579- raise # Propagate exception for debugging in experimental code
541+ if not self ._connection_id :
542+ raise RuntimeError ("model not started | call start before sending" )
543+
544+ # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first
545+ if isinstance (content , BidiTextInputEvent ):
546+ await self ._send_text_content (content .text )
547+ elif isinstance (content , BidiAudioInputEvent ):
548+ await self ._send_audio_content (content )
549+ elif isinstance (content , ToolResultEvent ):
550+ tool_result = content .get ("tool_result" )
551+ if tool_result :
552+ await self ._send_tool_result (tool_result )
553+ else :
554+ raise ValueError (f"content_type={ type (content )} | content not supported" )
580555
581556 async def _send_audio_content (self , audio_input : BidiAudioInputEvent ) -> None :
582557 """Internal: Send audio content to OpenAI for processing."""
@@ -599,7 +574,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None:
599574
600575 logger .debug ("tool_use_id=<%s> | sending openai tool result" , tool_use_id )
601576
602- # Extract result content
577+ # TODO: We need to extract all content and content types
603578 result_data : dict [Any , Any ] | str = {}
604579 if "content" in tool_result :
605580 # Extract text from content blocks
@@ -616,25 +591,23 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None:
616591
617592 async def stop (self ) -> None :
618593 """Close session and cleanup resources."""
619- if not self ._active :
620- return
621-
622594 logger .debug ("openai realtime connection cleanup starting" )
623- self ._active = False
624595
625- try :
626- await self .websocket .close ()
627- except Exception as e :
628- logger .warning ("error=<%s> | error closing openai realtime websocket" , e )
596+ async def stop_websocket () -> None :
597+ if not hasattr (self , "_websocket" ):
598+ return
599+
600+ await self ._websocket .close ()
601+
602+ async def stop_connection () -> None :
603+ self ._connection_id = None
604+
605+ await stop_all (stop_websocket , stop_connection )
629606
630607 logger .debug ("openai realtime connection closed" )
631608
632609 async def _send_event (self , event : dict [str , Any ]) -> None :
633610 """Send event to OpenAI via WebSocket."""
634- try :
635- message = json .dumps (event )
636- await self .websocket .send (message )
637- logger .debug ("event_type=<%s> | openai event sent" , event .get ("type" ))
638- except Exception as e :
639- logger .error ("error=<%s> | error sending openai event" , e )
640- raise
611+ message = json .dumps (event )
612+ await self ._websocket .send (message )
613+ logger .debug ("event_type=<%s> | openai event sent" , event .get ("type" ))
0 commit comments