Skip to content

Commit 9eb5fdc

Browse files
committed
Realtime: only cancel response if necessary
1 parent 9e1f699 commit 9eb5fdc

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

src/agents/realtime/openai_realtime.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(self) -> None:
140140
self._ongoing_response: bool = False
141141
self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
142142
self._playback_tracker: RealtimePlaybackTracker | None = None
143+
self._created_session: OpenAISessionObject | None = None
143144

144145
async def connect(self, options: RealtimeModelConfig) -> None:
145146
"""Establish a connection to the model and keep it alive."""
@@ -352,7 +353,14 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
352353
int(elapsed_ms),
353354
)
354355
await self._send_raw_message(converted)
355-
await self._cancel_response()
356+
357+
automatic_response_cancellation_enabled = (
358+
self._created_session
359+
and self._created_session.turn_detection
360+
and self._created_session.turn_detection.interrupt_response
361+
)
362+
if not automatic_response_cancellation_enabled:
363+
await self._cancel_response()
356364

357365
self._audio_state_tracker.on_interrupted()
358366
if self._playback_tracker:
@@ -486,6 +494,9 @@ async def _handle_ws_event(self, event: dict[str, Any]):
486494
await self._emit_event(RealtimeModelTurnEndedEvent())
487495
elif parsed.type == "session.created":
488496
await self._send_tracing_config(self._tracing_config)
497+
self._update_created_session(parsed.session) # type: ignore
498+
elif parsed.type == "session.updated":
499+
self._update_created_session(parsed.session) # type: ignore
489500
elif parsed.type == "error":
490501
await self._emit_event(RealtimeModelErrorEvent(error=parsed.error))
491502
elif parsed.type == "conversation.item.deleted":
@@ -535,6 +546,13 @@ async def _handle_ws_event(self, event: dict[str, Any]):
535546
):
536547
await self._handle_output_item(parsed.item)
537548

549+
def _update_created_session(self, session: OpenAISessionObject) -> None:
550+
self._created_session = session
551+
if session.output_audio_format:
552+
self._audio_state_tracker.set_audio_format(session.output_audio_format)
553+
if self._playback_tracker:
554+
self._playback_tracker.set_audio_format(session.output_audio_format)
555+
538556
async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None:
539557
session_config = self._get_session_config(model_settings)
540558
await self._send_raw_message(

0 commit comments

Comments
 (0)