-
Notifications
You must be signed in to change notification settings - Fork 2.8k
prevent tool cancellation when AgentTask is called inside it #4586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9c7d397
f8ba889
8f08b37
6511a91
fc28e63
e0bcd6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -123,7 +123,8 @@ def __init__(self, agent: Agent, sess: AgentSession) -> None: | |
| # for false interruption handling | ||
| self._paused_speech: SpeechHandle | None = None | ||
| self._false_interruption_timer: asyncio.TimerHandle | None = None | ||
| self._interrupt_paused_speech_task: asyncio.Task[None] | None = None | ||
| self._cancel_speech_pause_task: asyncio.Task[None] | None = None | ||
|
|
||
| self._stt_eos_received: bool = False | ||
|
|
||
| # fired when a speech_task finishes or when a new speech_handle is scheduled | ||
|
|
@@ -754,8 +755,11 @@ async def _close_session(self) -> None: | |
| *(mcp_server.aclose() for mcp_server in self.mcp_servers), return_exceptions=True | ||
| ) | ||
|
|
||
| await self._interrupt_paused_speech(old_task=self._interrupt_paused_speech_task) | ||
| self._interrupt_paused_speech_task = None | ||
| await self._cancel_speech_pause( | ||
| old_task=self._cancel_speech_pause_task, | ||
| interrupt=False, # don't interrupt the paused speech, it's managed by _pause_scheduling_task | ||
| ) | ||
| self._cancel_speech_pause_task = None | ||
|
|
||
| async def aclose(self) -> None: | ||
| # `aclose` must only be called by AgentSession | ||
|
|
@@ -1371,8 +1375,8 @@ def on_final_transcript(self, ev: stt.SpeechEvent, *, speaking: bool | None = No | |
| # schedule a resume timer if interrupted after end_of_speech | ||
| self._start_false_interruption_timer(timeout) | ||
|
|
||
| self._interrupt_paused_speech_task = asyncio.create_task( | ||
| self._interrupt_paused_speech(old_task=self._interrupt_paused_speech_task) | ||
| self._cancel_speech_pause_task = asyncio.create_task( | ||
| self._cancel_speech_pause(old_task=self._cancel_speech_pause_task) | ||
| ) | ||
|
|
||
| def on_preemptive_generation(self, info: _PreemptiveGenerationInfo) -> None: | ||
|
|
@@ -1490,7 +1494,7 @@ async def _user_turn_completed_task( | |
| extra={"user_input": info.new_transcript}, | ||
| ) | ||
| return | ||
| await self._interrupt_paused_speech(self._interrupt_paused_speech_task) | ||
| await self._cancel_speech_pause(self._cancel_speech_pause_task) | ||
|
|
||
| await current_speech.interrupt() | ||
|
|
||
|
|
@@ -2079,20 +2083,16 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None: | |
| ) | ||
|
|
||
| current_span.set_attribute(trace_types.ATTR_SPEECH_INTERRUPTED, speech_handle.interrupted) | ||
| has_speech_message = False | ||
|
|
||
| # add the tools messages that triggers this reply to the chat context | ||
| if _previous_tools_messages: | ||
| self._agent._chat_ctx.insert(_previous_tools_messages) | ||
| self._session._tool_items_added(_previous_tools_messages) | ||
|
|
||
| forwarded_text = text_out.text if text_out else "" | ||
| if speech_handle.interrupted: | ||
| await utils.aio.cancel_and_wait(*tasks) | ||
| await text_tee.aclose() | ||
|
|
||
| forwarded_text = text_out.text if text_out else "" | ||
| if forwarded_text: | ||
| has_speech_message = True | ||
| # if the audio playout was enabled, clear the buffer | ||
| if audio_output is not None: | ||
| audio_output.clear_buffer() | ||
|
|
@@ -2109,55 +2109,39 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None: | |
| else: | ||
| forwarded_text = "" | ||
|
|
||
| if forwarded_text: | ||
| msg = chat_ctx.add_message( | ||
| role="assistant", | ||
| content=forwarded_text, | ||
| id=llm_gen_data.id, | ||
| interrupted=True, | ||
| created_at=reply_started_at, | ||
| metrics=assistant_metrics, | ||
| ) | ||
| self._agent._chat_ctx.insert(msg) | ||
| self._session._conversation_item_added(msg) | ||
| speech_handle._item_added([msg]) | ||
| current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text) | ||
|
|
||
| if self._session.agent_state == "speaking": | ||
| self._session._update_agent_state("listening") | ||
|
|
||
| speech_handle._mark_generation_done() | ||
| await utils.aio.cancel_and_wait(exe_task) | ||
| return | ||
|
Comment on lines
-2113
to
-2131
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was this some duplicated logic?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, we have some duplicated code for interrupted and not interrupted. I merged them in this pr. |
||
|
|
||
| if read_transcript_from_tts and text_out and not text_out.text: | ||
| elif read_transcript_from_tts and text_out and not text_out.text: | ||
| logger.warning( | ||
| "`use_tts_aligned_transcript` is enabled but no agent transcript was returned from tts" | ||
| ) | ||
|
|
||
| if text_out and text_out.text: | ||
| has_speech_message = True | ||
| if forwarded_text: | ||
| msg = chat_ctx.add_message( | ||
| role="assistant", | ||
| content=text_out.text, | ||
| content=forwarded_text, | ||
| id=llm_gen_data.id, | ||
| interrupted=False, | ||
| interrupted=speech_handle.interrupted, | ||
| created_at=reply_started_at, | ||
| metrics=assistant_metrics, | ||
| ) | ||
| self._agent._chat_ctx.insert(msg) | ||
| self._session._conversation_item_added(msg) | ||
| speech_handle._item_added([msg]) | ||
| current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, text_out.text) | ||
| current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text) | ||
|
|
||
| if len(tool_output.output) > 0: | ||
| if not speech_handle.interrupted and len(tool_output.output) > 0: | ||
| self._session._update_agent_state("thinking") | ||
| elif self._session.agent_state == "speaking": | ||
| self._session._update_agent_state("listening") | ||
|
|
||
| await text_tee.aclose() | ||
|
|
||
| speech_handle._mark_generation_done() # mark the playout done before waiting for the tool execution # noqa: E501 | ||
|
|
||
| if speech_handle.interrupted: | ||
| await utils.aio.cancel_and_wait(exe_task) | ||
| return | ||
|
Comment on lines
+2140
to
+2142
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should be removed?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there is a guard for cancellation https://github.com/livekit/agents/blob/livekit-agents@1.3.12/livekit-agents/livekit/agents/voice/generation.py#L648-L658, we will cancel the tool execution task but not the user's function |
||
|
|
||
| # wait for the tool execution to complete | ||
| self._background_speeches.add(speech_handle) | ||
| try: | ||
| await exe_task | ||
|
|
@@ -2229,7 +2213,7 @@ def _tool_execution_completed_cb(out: ToolExecutionOutput) -> None: | |
| ), | ||
| # in case the current reply only generated tools (no speech), re-use the current user_metrics for the next | ||
| # tool response generation | ||
| _previous_user_metrics=user_metrics if not has_speech_message else None, | ||
| _previous_user_metrics=user_metrics if not forwarded_text else None, | ||
| _previous_tools_messages=tool_messages, | ||
| ), | ||
| speech_handle=speech_handle, | ||
|
|
@@ -2580,83 +2564,66 @@ def _create_assistant_message( | |
| msg.metrics = assistant_metrics | ||
| return msg | ||
|
|
||
| msg_gen, text_out, audio_out = ( | ||
| message_outputs[0] if len(message_outputs) > 0 else (None, None, None) | ||
| ) # there should be only one message | ||
|
|
||
| forwarded_text = text_out.text if text_out else "" | ||
| if speech_handle.interrupted: | ||
| await utils.aio.cancel_and_wait(*tasks) | ||
|
|
||
| if len(message_outputs) > 0: | ||
| # there should be only one message | ||
| msg_gen, text_out, audio_out = message_outputs[0] | ||
| forwarded_text = text_out.text if text_out else "" | ||
| if audio_output is not None: | ||
| audio_output.clear_buffer() | ||
| if msg_gen and audio_output is not None: | ||
| audio_output.clear_buffer() | ||
|
|
||
| playback_ev = await audio_output.wait_for_playout() | ||
| playback_position = playback_ev.playback_position | ||
| if ( | ||
| audio_out is not None | ||
| and audio_out.first_frame_fut.done() | ||
| and not audio_out.first_frame_fut.cancelled() | ||
| ): | ||
| # playback_ev is valid only if the first frame was already played | ||
| if playback_ev.synchronized_transcript is not None: | ||
| forwarded_text = playback_ev.synchronized_transcript | ||
| else: | ||
| forwarded_text = "" | ||
| playback_position = 0 | ||
|
|
||
| # truncate server-side message (if supported) | ||
| if self.llm.capabilities.message_truncation: | ||
| msg_modalities = await msg_gen.modalities | ||
| self._rt_session.truncate( | ||
| message_id=msg_gen.message_id, | ||
| modalities=msg_modalities, | ||
| audio_end_ms=int(playback_position * 1000), | ||
| audio_transcript=forwarded_text, | ||
| ) | ||
| playback_ev = await audio_output.wait_for_playout() | ||
| playback_position = playback_ev.playback_position | ||
| if ( | ||
| audio_out is not None | ||
| and audio_out.first_frame_fut.done() | ||
| and not audio_out.first_frame_fut.cancelled() | ||
| ): | ||
| # playback_ev is valid only if the first frame was already played | ||
| if playback_ev.synchronized_transcript is not None: | ||
| forwarded_text = playback_ev.synchronized_transcript | ||
| else: | ||
| forwarded_text = "" | ||
| playback_position = 0 | ||
|
|
||
| msg: llm.ChatMessage | None = None | ||
| if forwarded_text: | ||
| msg = _create_assistant_message( | ||
| # truncate server-side message (if supported) | ||
| if self.llm.capabilities.message_truncation: | ||
| msg_modalities = await msg_gen.modalities | ||
| self._rt_session.truncate( | ||
| message_id=msg_gen.message_id, | ||
| forwarded_text=forwarded_text, | ||
| interrupted=True, | ||
| modalities=msg_modalities, | ||
| audio_end_ms=int(playback_position * 1000), | ||
| audio_transcript=forwarded_text, | ||
| ) | ||
| self._agent._chat_ctx.items.append(msg) | ||
| speech_handle._item_added([msg]) | ||
| self._session._conversation_item_added(msg) | ||
| current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text) | ||
|
|
||
| speech_handle._mark_generation_done() | ||
| await utils.aio.cancel_and_wait(exe_task) | ||
|
|
||
| for tee in tees: | ||
| await tee.aclose() | ||
| return | ||
|
|
||
| if len(message_outputs) > 0: | ||
| # there should be only one message | ||
| msg_gen, text_out, _ = message_outputs[0] | ||
| forwarded_text = text_out.text if text_out else "" | ||
| if forwarded_text: | ||
| msg = _create_assistant_message( | ||
| message_id=msg_gen.message_id, | ||
| forwarded_text=forwarded_text, | ||
| interrupted=False, | ||
| ) | ||
| self._agent._chat_ctx.items.append(msg) | ||
| speech_handle._item_added([msg]) | ||
| self._session._conversation_item_added(msg) | ||
| current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text) | ||
| elif read_transcript_from_tts and text_out and not text_out.text: | ||
| logger.warning( | ||
| "`use_tts_aligned_transcript` is enabled but no agent transcript was returned from tts" | ||
| ) | ||
|
|
||
| elif read_transcript_from_tts and text_out is not None: | ||
| logger.warning( | ||
| "`use_tts_aligned_transcript` is enabled but no agent transcript was returned from tts" | ||
| ) | ||
| if msg_gen and forwarded_text: | ||
| msg = _create_assistant_message( | ||
| message_id=msg_gen.message_id, | ||
| forwarded_text=forwarded_text, | ||
| interrupted=speech_handle.interrupted, | ||
| ) | ||
| self._agent._chat_ctx.items.append(msg) | ||
| speech_handle._item_added([msg]) | ||
| self._session._conversation_item_added(msg) | ||
| current_span.set_attribute(trace_types.ATTR_RESPONSE_TEXT, forwarded_text) | ||
|
|
||
| for tee in tees: | ||
| await tee.aclose() | ||
| speech_handle._mark_generation_done() | ||
|
|
||
| speech_handle._mark_generation_done() # mark the playout done before waiting for the tool execution # noqa: E501 | ||
| if speech_handle.interrupted: | ||
| await utils.aio.cancel_and_wait(exe_task) | ||
| return | ||
|
|
||
| # wait for the tool execution to complete | ||
| tool_output.first_tool_started_fut.add_done_callback( | ||
| lambda _: self._session._update_agent_state("thinking") | ||
| ) | ||
|
|
@@ -2806,7 +2773,9 @@ def _on_false_interruption() -> None: | |
| timeout, _on_false_interruption | ||
| ) | ||
|
|
||
| async def _interrupt_paused_speech(self, old_task: asyncio.Task[None] | None = None) -> None: | ||
| async def _cancel_speech_pause( | ||
| self, old_task: asyncio.Task[None] | None = None, *, interrupt: bool = True | ||
| ) -> None: | ||
| if old_task is not None: | ||
| await old_task | ||
|
|
||
|
|
@@ -2817,8 +2786,14 @@ async def _interrupt_paused_speech(self, old_task: asyncio.Task[None] | None = N | |
| if not self._paused_speech: | ||
| return | ||
|
|
||
| if not self._paused_speech.interrupted and self._paused_speech.allow_interruptions: | ||
| await self._paused_speech.interrupt() # ensure the speech is done | ||
| if ( | ||
| interrupt | ||
| and not self._paused_speech.interrupted | ||
| and self._paused_speech.allow_interruptions | ||
| ): | ||
| self._paused_speech.interrupt() | ||
| # ensure the generation is done | ||
| await self._paused_speech._wait_for_generation() | ||
| self._paused_speech = None | ||
|
|
||
| if self._session.options.resume_false_interruption and self._session.output.audio: | ||
|
Comment on lines
2797
to
2799
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Paused speech state cleared prematurely when allow_interruptions is False When Click to expandScenario
ImpactThe paused speech reference is cleared prematurely while an
(Refers to lines 2797-2800) Recommendation: Consider not clearing Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.