Skip to content

Commit 1b5f9a6

Browse files
authored
cancellation - openai (strands-agents#73)
1 parent a308eee commit 1b5f9a6

File tree

4 files changed

+77
-123
lines changed

4 files changed

+77
-123
lines changed

src/strands/experimental/bidi/models/novasonic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ async def start(
147147
RuntimeError: If user calls start again without first stopping.
148148
"""
149149
if self._connection_id:
150-
raise RuntimeError("call stop before starting again")
150+
raise RuntimeError("model already started | call stop before starting again")
151151

152152
logger.debug("nova connection starting")
153153

@@ -233,7 +233,7 @@ async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore[override]
233233
RuntimeError: If start has not been called.
234234
"""
235235
if not self._connection_id:
236-
raise RuntimeError("must call start")
236+
raise RuntimeError("model not started | call start before receiving")
237237

238238
logger.debug("nova event stream starting")
239239
yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id)
@@ -260,7 +260,7 @@ async def send(self, content: BidiInputEvent | ToolResultEvent) -> None:
260260
ValueError: If content type not supported (e.g., image content).
261261
"""
262262
if not self._connection_id:
263-
raise RuntimeError("must call start")
263+
raise RuntimeError("model not started | call start before sending")
264264

265265
if isinstance(content, BidiTextInputEvent):
266266
await self._send_text_content(content.text)
@@ -271,7 +271,7 @@ async def send(self, content: BidiInputEvent | ToolResultEvent) -> None:
271271
if tool_result:
272272
await self._send_tool_result(tool_result)
273273
else:
274-
raise ValueError(f"content_type={type(content)} | content not supported by nova sonic")
274+
raise ValueError(f"content_type={type(content)} | content not supported")
275275

276276
async def _start_audio_connection(self) -> None:
277277
"""Internal: Start audio input connection (call once before sending audio chunks)."""

src/strands/experimental/bidi/models/openai.py

Lines changed: 62 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616
from ....types._events import ToolResultEvent, ToolUseStreamEvent
1717
from ....types.content import Messages
1818
from ....types.tools import ToolResult, ToolSpec, ToolUse
19+
from .._async import stop_all
1920
from ..types.events import (
2021
BidiAudioInputEvent,
2122
BidiAudioStreamEvent,
22-
BidiConnectionCloseEvent,
2323
BidiConnectionStartEvent,
24-
BidiErrorEvent,
25-
BidiImageInputEvent,
2624
BidiInputEvent,
2725
BidiInterruptionEvent,
2826
BidiOutputEvent,
@@ -70,6 +68,8 @@ class BidiOpenAIRealtimeModel(BidiModel):
7068
function calling, and event conversion to Strands format.
7169
"""
7270

71+
_websocket: ClientConnection
72+
7373
def __init__(
7474
self,
7575
model: str = DEFAULT_MODEL,
@@ -104,9 +104,7 @@ def __init__(
104104
)
105105

106106
# Connection state (initialized in start())
107-
self.websocket: ClientConnection
108-
self.connection_id: str
109-
self._active: bool = False
107+
self._connection_id: str | None = None
110108

111109
self._function_call_buffer: dict[str, Any] = {}
112110

@@ -127,45 +125,35 @@ async def start(
127125
messages: Conversation history to initialize with.
128126
**kwargs: Additional configuration options.
129127
"""
130-
if self._active:
131-
raise RuntimeError("Connection already active. Close the existing connection before creating a new one.")
128+
if self._connection_id:
129+
raise RuntimeError("model already started | call stop before starting again")
132130

133131
logger.info("openai realtime connection starting")
134132

135-
try:
136-
# Initialize connection state
137-
self.connection_id = str(uuid.uuid4())
138-
self._active = True
139-
self._function_call_buffer = {}
140-
141-
# Establish WebSocket connection
142-
url = f"{OPENAI_REALTIME_URL}?model={self.model}"
133+
# Initialize connection state
134+
self._connection_id = str(uuid.uuid4())
143135

144-
headers = [("Authorization", f"Bearer {self.api_key}")]
145-
if self.organization:
146-
headers.append(("OpenAI-Organization", self.organization))
147-
if self.project:
148-
headers.append(("OpenAI-Project", self.project))
136+
self._function_call_buffer = {}
149137

150-
self.websocket = await websockets.connect(url, additional_headers=headers)
151-
logger.info("connection_id=<%s> | websocket connected successfully", self.connection_id)
138+
# Establish WebSocket connection
139+
url = f"{OPENAI_REALTIME_URL}?model={self.model}"
152140

153-
# Configure session
154-
session_config = self._build_session_config(system_prompt, tools)
155-
await self._send_event({"type": "session.update", "session": session_config})
141+
headers = [("Authorization", f"Bearer {self.api_key}")]
142+
if self.organization:
143+
headers.append(("OpenAI-Organization", self.organization))
144+
if self.project:
145+
headers.append(("OpenAI-Project", self.project))
156146

157-
# Add conversation history if provided
158-
if messages:
159-
await self._add_conversation_history(messages)
147+
self._websocket = await websockets.connect(url, additional_headers=headers)
148+
logger.info("connection_id=<%s> | websocket connected successfully", self._connection_id)
160149

161-
except Exception as e:
162-
self._active = False
163-
logger.error("error=<%s> | openai connection failed", e)
164-
raise
150+
# Configure session
151+
session_config = self._build_session_config(system_prompt, tools)
152+
await self._send_event({"type": "session.update", "session": session_config})
165153

166-
def _require_active(self) -> bool:
167-
"""Check if session is active."""
168-
return self._active
154+
# Add conversation history if provided
155+
if messages:
156+
await self._add_conversation_history(messages)
169157

170158
def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent:
171159
"""Create standardized transcript event.
@@ -275,27 +263,16 @@ async def _add_conversation_history(self, messages: Messages) -> None:
275263

276264
async def receive(self) -> AsyncIterable[BidiOutputEvent]: # type: ignore
277265
"""Receive OpenAI events and convert to Strands TypedEvent format."""
278-
# Emit connection start event
279-
yield BidiConnectionStartEvent(connection_id=self.connection_id, model=self.model)
280-
281-
try:
282-
while self._active:
283-
async for message in self.websocket:
284-
if not self._active:
285-
break # type: ignore
266+
if not self._connection_id:
267+
raise RuntimeError("model not started | call start before receiving")
286268

287-
openai_event = json.loads(message)
269+
yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model)
288270

289-
for event in self._convert_openai_event(openai_event) or []:
290-
yield event
271+
async for message in self._websocket:
272+
openai_event = json.loads(message)
291273

292-
except Exception as e:
293-
logger.error("error=<%s> | error receiving openai realtime event", e)
294-
yield BidiErrorEvent(error=e)
295-
finally:
296-
# Emit connection close event
297-
yield BidiConnectionCloseEvent(connection_id=self.connection_id, reason="complete")
298-
self._active = False
274+
for event in self._convert_openai_event(openai_event) or []:
275+
yield event
299276

300277
def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutputEvent] | None:
301278
"""Convert OpenAI events to Strands TypedEvent format."""
@@ -557,26 +534,24 @@ async def send(
557534
558535
Args:
559536
content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent).
537+
538+
Raises:
539+
ValueError: If content type not supported (e.g., image content).
560540
"""
561-
if not self._require_active():
562-
return
563-
564-
try:
565-
# Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first
566-
if isinstance(content, BidiTextInputEvent):
567-
await self._send_text_content(content.text)
568-
elif isinstance(content, BidiAudioInputEvent):
569-
await self._send_audio_content(content)
570-
elif isinstance(content, BidiImageInputEvent):
571-
# BidiImageInputEvent - not supported by OpenAI Realtime yet
572-
logger.warning("Image input not supported by OpenAI Realtime API")
573-
elif isinstance(content, ToolResultEvent):
574-
tool_result = content.get("tool_result")
575-
if tool_result:
576-
await self._send_tool_result(tool_result)
577-
except Exception as e:
578-
logger.error("error=<%s> | error sending content to openai", e)
579-
raise # Propagate exception for debugging in experimental code
541+
if not self._connection_id:
542+
raise RuntimeError("model not started | call start before sending")
543+
544+
# Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first
545+
if isinstance(content, BidiTextInputEvent):
546+
await self._send_text_content(content.text)
547+
elif isinstance(content, BidiAudioInputEvent):
548+
await self._send_audio_content(content)
549+
elif isinstance(content, ToolResultEvent):
550+
tool_result = content.get("tool_result")
551+
if tool_result:
552+
await self._send_tool_result(tool_result)
553+
else:
554+
raise ValueError(f"content_type={type(content)} | content not supported")
580555

581556
async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None:
582557
"""Internal: Send audio content to OpenAI for processing."""
@@ -599,7 +574,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None:
599574

600575
logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id)
601576

602-
# Extract result content
577+
# TODO: We need to extract all content and content types
603578
result_data: dict[Any, Any] | str = {}
604579
if "content" in tool_result:
605580
# Extract text from content blocks
@@ -616,25 +591,23 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None:
616591

617592
async def stop(self) -> None:
618593
"""Close session and cleanup resources."""
619-
if not self._active:
620-
return
621-
622594
logger.debug("openai realtime connection cleanup starting")
623-
self._active = False
624595

625-
try:
626-
await self.websocket.close()
627-
except Exception as e:
628-
logger.warning("error=<%s> | error closing openai realtime websocket", e)
596+
async def stop_websocket() -> None:
597+
if not hasattr(self, "_websocket"):
598+
return
599+
600+
await self._websocket.close()
601+
602+
async def stop_connection() -> None:
603+
self._connection_id = None
604+
605+
await stop_all(stop_websocket, stop_connection)
629606

630607
logger.debug("openai realtime connection closed")
631608

632609
async def _send_event(self, event: dict[str, Any]) -> None:
633610
"""Send event to OpenAI via WebSocket."""
634-
try:
635-
message = json.dumps(event)
636-
await self.websocket.send(message)
637-
logger.debug("event_type=<%s> | openai event sent", event.get("type"))
638-
except Exception as e:
639-
logger.error("error=<%s> | error sending openai event", e)
640-
raise
611+
message = json.dumps(event)
612+
await self._websocket.send(message)
613+
logger.debug("event_type=<%s> | openai event sent", event.get("type"))

tests/strands/experimental/bidi/models/test_novasonic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ async def test_send_edge_cases(nova_model):
166166
mime_type="image/jpeg",
167167
)
168168

169-
with pytest.raises(ValueError, match=r"content not supported by nova sonic"):
169+
with pytest.raises(ValueError, match=r"content not supported"):
170170
await nova_model.send(image_event)
171171

172172
await nova_model.stop()

tests/strands/experimental/bidi/models/test_openai_realtime.py renamed to tests/strands/experimental/bidi/models/test_openai.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,6 @@ def test_model_initialization(api_key, model_name):
9393
model_default = BidiOpenAIRealtimeModel(api_key="test-key")
9494
assert model_default.model == "gpt-realtime"
9595
assert model_default.api_key == "test-key"
96-
assert model_default._active is False
97-
assert model_default.websocket is None
9896

9997
# Test with custom model
10098
model_custom = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
@@ -129,14 +127,12 @@ async def test_connection_lifecycle(mock_websockets_connect, model, system_promp
129127

130128
# Test basic connection
131129
await model.start()
132-
assert model._active is True
133-
assert model.connection_id is not None
134-
assert model.websocket == mock_ws
130+
assert model._connection_id is not None
131+
assert model._websocket == mock_ws
135132
mock_connect.assert_called_once()
136133

137134
# Test close
138135
await model.stop()
139-
assert model._active is False
140136
mock_ws.close.assert_called_once()
141137

142138
# Test connection with system prompt
@@ -202,20 +198,20 @@ async def async_connect(*args, **kwargs):
202198
# Test double connection
203199
model2 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
204200
await model2.start()
205-
with pytest.raises(RuntimeError, match="Connection already active"):
201+
with pytest.raises(RuntimeError, match=r"call stop before starting again"):
206202
await model2.start()
207203
await model2.stop()
208204

209205
# Test close when not connected
210206
model3 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
211207
await model3.stop() # Should not raise
212208

213-
# Test close error handling (should not raise, just log)
209+
# Test close error
214210
model4 = BidiOpenAIRealtimeModel(model=model_name, api_key=api_key)
215211
await model4.start()
216212
mock_ws.close.side_effect = Exception("Close failed")
217-
await model4.stop() # Should not raise
218-
assert model4._active is False
213+
with pytest.raises(ExceptionGroup): # noqa: F821
214+
await model4.stop()
219215

220216

221217
# Send Method Tests
@@ -279,7 +275,8 @@ async def test_send_edge_cases(mock_websockets_connect, model):
279275

280276
# Test send when inactive
281277
text_input = BidiTextInputEvent(text="Hello", role="user")
282-
await model.send(text_input)
278+
with pytest.raises(RuntimeError, match=r"call start before sending"):
279+
await model.send(text_input)
283280
mock_ws.send.assert_not_called()
284281

285282
# Test image input (not supported, base64 encoded, no encoding parameter)
@@ -289,15 +286,8 @@ async def test_send_edge_cases(mock_websockets_connect, model):
289286
image=image_b64,
290287
mime_type="image/jpeg",
291288
)
292-
with unittest.mock.patch("strands.experimental.bidi.models.openai.logger") as mock_logger:
289+
with pytest.raises(ValueError, match=r"content not supported"):
293290
await model.send(image_input)
294-
mock_logger.warning.assert_called_with("Image input not supported by OpenAI Realtime API")
295-
296-
# Test unknown content type
297-
unknown_content = {"unknown_field": "value"}
298-
with unittest.mock.patch("strands.experimental.bidi.models.openai.logger") as mock_logger:
299-
await model.send(unknown_content)
300-
assert mock_logger.warning.called
301291

302292
await model.stop()
303293

@@ -318,7 +308,7 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model):
318308

319309
# First event should be connection start (new TypedEvent format)
320310
assert first_event.get("type") == "bidi_connection_start"
321-
assert first_event.get("connection_id") == model.connection_id
311+
assert first_event.get("connection_id") == model._connection_id
322312
assert first_event.get("model") == model.model
323313

324314
# Close to trigger session end
@@ -332,9 +322,6 @@ async def test_receive_lifecycle_events(mock_websockets_connect, model):
332322
except StopAsyncIteration:
333323
pass
334324

335-
# Last event should be connection close (new TypedEvent format)
336-
assert events[-1].get("type") == "bidi_connection_close"
337-
338325

339326
@pytest.mark.asyncio
340327
async def test_event_conversion(mock_websockets_connect, model):
@@ -463,12 +450,6 @@ def test_tool_conversion(model, tool_spec):
463450

464451
def test_helper_methods(model):
465452
"""Test various helper methods."""
466-
# Test _require_active
467-
assert model._require_active() is False
468-
model._active = True
469-
assert model._require_active() is True
470-
model._active = False
471-
472453
# Test _create_text_event (now returns BidiTranscriptStreamEvent)
473454
text_event = model._create_text_event("Hello", "user")
474455
assert isinstance(text_event, BidiTranscriptStreamEvent)

0 commit comments

Comments
 (0)