55
66import asyncio
77import logging
8- from typing import TYPE_CHECKING , Any , AsyncIterable , Awaitable
8+ from typing import TYPE_CHECKING , Any , AsyncIterable
99
1010from ....types ._events import ToolInterruptEvent , ToolResultEvent , ToolResultMessageEvent , ToolUseStreamEvent
1111from ....types .content import Message
1818from ...hooks .events import (
1919 BidiInterruptionEvent as BidiInterruptionHookEvent ,
2020)
21+ from .._async import _TaskPool , stop_all
2122from ..types .events import BidiInterruptionEvent , BidiOutputEvent , BidiTranscriptStreamEvent
2223
2324if TYPE_CHECKING :
@@ -31,21 +32,14 @@ class _BidiAgentLoop:
3132
3233 Attributes:
3334 _agent: BidiAgent instance to loop.
35+ _started: Flag if agent loop has started.
36+ _task_pool: Track active async tasks created in loop.
3437 _event_queue: Queue model and tool call events for receiver.
35- _stop_event: Sentinel to mark end of loop.
36- _tasks: Track active async tasks created in loop.
37- _active: Flag if agent loop is started.
3838 _invocation_state: Optional context to pass to tools during execution.
3939 This allows passing custom data (user_id, session_id, database connections, etc.)
4040 that tools can access via their invocation_state parameter.
4141 """
4242
43- _event_queue : asyncio .Queue
44- _stop_event : object
45- _tasks : set
46- _active : bool
47- _invocation_state : dict [str , Any ]
48-
4943 def __init__ (self , agent : "BidiAgent" ) -> None :
5044 """Initialize members of the agent loop.
5145
@@ -55,8 +49,10 @@ def __init__(self, agent: "BidiAgent") -> None:
5549 agent: Bidirectional agent to loop over.
5650 """
5751 self ._agent = agent
58- self ._active = False
59- self ._invocation_state = {}
52+ self ._started = False
53+ self ._task_pool = _TaskPool ()
54+ self ._event_queue : asyncio .Queue
55+ self ._invocation_state : dict [str , Any ]
6056
6157 async def start (self , invocation_state : dict [str , Any ] | None = None ) -> None :
6258 """Start the agent loop.
@@ -67,19 +63,15 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None:
6763 invocation_state: Optional context to pass to tools during execution.
6864 This allows passing custom data (user_id, session_id, database connections, etc.)
6965 that tools can access via their invocation_state parameter.
66+
67+ Raises:
68+ RuntimeError:
69+ If loop already started.
7070 """
71- if self .active :
72- return
71+ if self ._started :
72+ raise RuntimeError ( "loop already started | call stop before starting again" )
7373
7474 logger .debug ("agent loop starting" )
75-
76- self ._invocation_state = invocation_state or {}
77-
78- self ._event_queue = asyncio .Queue (maxsize = 1 )
79- self ._stop_event = object ()
80- self ._tasks = set ()
81-
82- # Emit before invocation event
8375 await self ._agent .hooks .invoke_callbacks_async (BidiBeforeInvocationEvent (agent = self ._agent ))
8476
8577 await self ._agent .model .start (
@@ -88,99 +80,88 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None:
8880 messages = self ._agent .messages ,
8981 )
9082
91- self ._create_task (self ._run_model ())
83+ self ._event_queue = asyncio .Queue (maxsize = 1 )
84+
85+ self ._task_pool = _TaskPool ()
86+ self ._task_pool .create (self ._run_model ())
9287
93- self ._active = True
88+ self ._invocation_state = invocation_state or {}
89+ self ._started = True
9490
9591 async def stop (self ) -> None :
9692 """Stop the agent loop."""
97- if not self .active :
98- return
99-
10093 logger .debug ("agent loop stopping" )
10194
95+ self ._started = False
10296 self ._invocation_state = {}
10397
104- try :
105- # Cancel all tasks
106- for task in self ._tasks :
107- task .cancel ()
108-
109- # Wait briefly for tasks to finish their current operations
110- await asyncio .gather (* self ._tasks , return_exceptions = True )
98+ async def stop_tasks () -> None :
99+ await self ._task_pool .cancel ()
111100
112- # Stop the model
101+ async def stop_model () -> None :
113102 await self ._agent .model .stop ()
114103
115- # Clean up the event queue
116- if not self ._event_queue .empty ():
117- self ._event_queue .get_nowait ()
118- self ._event_queue .put_nowait (self ._stop_event )
119-
120- self ._active = False
121-
104+ try :
105+ await stop_all (stop_tasks , stop_model )
122106 finally :
123- # Emit after invocation event (reverse order for cleanup)
124107 await self ._agent .hooks .invoke_callbacks_async (BidiAfterInvocationEvent (agent = self ._agent ))
125108
126109 async def receive (self ) -> AsyncIterable [BidiOutputEvent ]:
127- """Receive model and tool call events."""
110+ """Receive model and tool call events.
111+
112+ Raises:
113+ RuntimeError: If start has not been called.
114+ """
115+ if not self ._started :
116+ raise RuntimeError ("loop not started | call start before receiving" )
117+
128118 while True :
129119 event = await self ._event_queue .get ()
130- if event is self . _stop_event :
131- break
120+ if isinstance ( event , Exception ) :
121+ raise event
132122
133123 yield event
134124
135- @property
136- def active (self ) -> bool :
137- """True if agent loop started, False otherwise."""
138- return self ._active
139-
140- def _create_task (self , coro : Awaitable [None ]) -> None :
141- """Utilitly to create async task.
142-
143- Adds a clean up callback to run after task completes.
144- """
145- task : asyncio .Task [None ] = asyncio .create_task (coro ) # type: ignore
146- task .add_done_callback (lambda task : self ._tasks .remove (task ))
147-
148- self ._tasks .add (task )
149-
150125 async def _run_model (self ) -> None :
151126 """Task for running the model.
152127
153128 Events are streamed through the event queue.
154129 """
155130 logger .debug ("model task starting" )
156131
157- async for event in self ._agent .model .receive (): # type: ignore
158- await self ._event_queue .put (event )
159-
160- if isinstance (event , BidiTranscriptStreamEvent ):
161- if event ["is_final" ]:
162- message : Message = {"role" : event ["role" ], "content" : [{"text" : event ["text" ]}]}
163- self ._agent .messages .append (message )
132+ try :
133+ async for event in self ._agent .model .receive (): # type: ignore
134+ await self ._event_queue .put (event )
135+
136+ if isinstance (event , BidiTranscriptStreamEvent ):
137+ if event ["is_final" ]:
138+ message : Message = {"role" : event ["role" ], "content" : [{"text" : event ["text" ]}]}
139+ self ._agent .messages .append (message )
140+ await self ._agent .hooks .invoke_callbacks_async (
141+ BidiMessageAddedEvent (agent = self ._agent , message = message )
142+ )
143+
144+ elif isinstance (event , ToolUseStreamEvent ):
145+ tool_use = event ["current_tool_use" ]
146+ self ._task_pool .create (self ._run_tool (tool_use ))
147+
148+ tool_message : Message = {"role" : "assistant" , "content" : [{"toolUse" : tool_use }]}
149+ self ._agent .messages .append (tool_message )
164150 await self ._agent .hooks .invoke_callbacks_async (
165- BidiMessageAddedEvent (agent = self ._agent , message = message )
151+ BidiMessageAddedEvent (agent = self ._agent , message = tool_message )
166152 )
167153
168- elif isinstance (event , ToolUseStreamEvent ):
169- tool_use = event ["current_tool_use" ]
170- self ._create_task (self ._run_tool (tool_use ))
171-
172- tool_message : Message = {"role" : "assistant" , "content" : [{"toolUse" : tool_use }]}
173- self ._agent .messages .append (tool_message )
174-
175- elif isinstance (event , BidiInterruptionEvent ):
176- # Emit interruption hook event
177- await self ._agent .hooks .invoke_callbacks_async (
178- BidiInterruptionHookEvent (
179- agent = self ._agent ,
180- reason = event ["reason" ],
181- interrupted_response_id = event .get ("interrupted_response_id" ),
154+ elif isinstance (event , BidiInterruptionEvent ):
155+ await self ._agent .hooks .invoke_callbacks_async (
156+ BidiInterruptionHookEvent (
157+ agent = self ._agent ,
158+ reason = event ["reason" ],
159+ interrupted_response_id = event .get ("interrupted_response_id" ),
160+ )
182161 )
183- )
162+
163+ except Exception as error :
164+ await self ._event_queue .put (error )
184165
185166 async def _run_tool (self , tool_use : ToolUse ) -> None :
186167 """Task for running tool requested by the model using the tool executor."""
@@ -196,30 +177,34 @@ async def _run_tool(self, tool_use: ToolUse) -> None:
196177 "system_prompt" : self ._agent .system_prompt ,
197178 }
198179
199- tool_events = self ._agent .tool_executor ._stream (
200- self ._agent ,
201- tool_use ,
202- tool_results ,
203- invocation_state ,
204- structured_output_context = None ,
205- )
206-
207- async for event in tool_events :
208- if isinstance (event , ToolInterruptEvent ):
209- raise RuntimeError (
210- "Tool interruption is not yet supported in BidiAgent. "
211- "ToolInterruptEvent received but cannot be handled in bidirectional streaming context."
212- )
213- await self ._event_queue .put (event )
214- if isinstance (event , ToolResultEvent ):
215- result = event .tool_result
216-
217- await self ._agent .model .send (ToolResultEvent (result ))
218-
219- message : Message = {
220- "role" : "user" ,
221- "content" : [{"toolResult" : result }],
222- }
223- self ._agent .messages .append (message )
224- await self ._agent .hooks .invoke_callbacks_async (BidiMessageAddedEvent (agent = self ._agent , message = message ))
225- await self ._event_queue .put (ToolResultMessageEvent (message ))
180+ try :
181+ tool_events = self ._agent .tool_executor ._stream (
182+ self ._agent ,
183+ tool_use ,
184+ tool_results ,
185+ invocation_state ,
186+ structured_output_context = None ,
187+ )
188+
189+ async for event in tool_events :
190+ if isinstance (event , ToolInterruptEvent ):
191+ self ._agent ._interrupt_state .deactivate ()
192+ interrupt_names = [interrupt .name for interrupt in event .interrupts ]
193+ raise RuntimeError (f"interrupts={ interrupt_names } | tool interrupts are not supported in bidi" )
194+
195+ await self ._event_queue .put (event )
196+ if isinstance (event , ToolResultEvent ):
197+ result = event .tool_result
198+
199+ await self ._agent .model .send (ToolResultEvent (result ))
200+
201+ message : Message = {
202+ "role" : "user" ,
203+ "content" : [{"toolResult" : result }],
204+ }
205+ self ._agent .messages .append (message )
206+ await self ._agent .hooks .invoke_callbacks_async (BidiMessageAddedEvent (agent = self ._agent , message = message ))
207+ await self ._event_queue .put (ToolResultMessageEvent (message ))
208+
209+ except Exception as error :
210+ await self ._event_queue .put (error )
0 commit comments