Skip to content

Commit c4181ec

Browse files
authored
cancellation - agent loop (strands-agents#63)
1 parent 6298c86 commit c4181ec

File tree

2 files changed

+104
-111
lines changed

2 files changed

+104
-111
lines changed

src/strands/experimental/bidi/agent/agent.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .... import _identifier
2121
from ....agent.state import AgentState
2222
from ....hooks import HookProvider, HookRegistry
23+
from ....interrupt import _InterruptState
2324
from ....tools.caller import _ToolCaller
2425
from ....tools.executors import ConcurrentToolExecutor
2526
from ....tools.executors._executor import ToolExecutor
@@ -146,6 +147,9 @@ def __init__(
146147
# Emit initialization event
147148
self.hooks.invoke_callbacks(BidiAgentInitializedEvent(agent=self))
148149

150+
# TODO: Determine if full support is required
151+
self._interrupt_state = _InterruptState()
152+
149153
self._started = False
150154

151155
@property
@@ -270,6 +274,10 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None:
270274
This allows passing custom data (user_id, session_id, database connections, etc.)
271275
that tools can access via their invocation_state parameter.
272276
277+
Raises:
278+
RuntimeError:
279+
If agent already started.
280+
273281
Example:
274282
```python
275283
await agent.start(invocation_state={

src/strands/experimental/bidi/agent/loop.py

Lines changed: 96 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import asyncio
77
import logging
8-
from typing import TYPE_CHECKING, Any, AsyncIterable, Awaitable
8+
from typing import TYPE_CHECKING, Any, AsyncIterable
99

1010
from ....types._events import ToolInterruptEvent, ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent
1111
from ....types.content import Message
@@ -18,6 +18,7 @@
1818
from ...hooks.events import (
1919
BidiInterruptionEvent as BidiInterruptionHookEvent,
2020
)
21+
from .._async import _TaskPool, stop_all
2122
from ..types.events import BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent
2223

2324
if TYPE_CHECKING:
@@ -31,21 +32,14 @@ class _BidiAgentLoop:
3132
3233
Attributes:
3334
_agent: BidiAgent instance to loop.
35+
_started: Flag if agent loop has started.
36+
_task_pool: Track active async tasks created in loop.
3437
_event_queue: Queue model and tool call events for receiver.
35-
_stop_event: Sentinel to mark end of loop.
36-
_tasks: Track active async tasks created in loop.
37-
_active: Flag if agent loop is started.
3838
_invocation_state: Optional context to pass to tools during execution.
3939
This allows passing custom data (user_id, session_id, database connections, etc.)
4040
that tools can access via their invocation_state parameter.
4141
"""
4242

43-
_event_queue: asyncio.Queue
44-
_stop_event: object
45-
_tasks: set
46-
_active: bool
47-
_invocation_state: dict[str, Any]
48-
4943
def __init__(self, agent: "BidiAgent") -> None:
5044
"""Initialize members of the agent loop.
5145
@@ -55,8 +49,10 @@ def __init__(self, agent: "BidiAgent") -> None:
5549
agent: Bidirectional agent to loop over.
5650
"""
5751
self._agent = agent
58-
self._active = False
59-
self._invocation_state = {}
52+
self._started = False
53+
self._task_pool = _TaskPool()
54+
self._event_queue: asyncio.Queue
55+
self._invocation_state: dict[str, Any]
6056

6157
async def start(self, invocation_state: dict[str, Any] | None = None) -> None:
6258
"""Start the agent loop.
@@ -67,19 +63,15 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None:
6763
invocation_state: Optional context to pass to tools during execution.
6864
This allows passing custom data (user_id, session_id, database connections, etc.)
6965
that tools can access via their invocation_state parameter.
66+
67+
Raises:
68+
RuntimeError:
69+
If loop already started.
7070
"""
71-
if self.active:
72-
return
71+
if self._started:
72+
raise RuntimeError("loop already started | call stop before starting again")
7373

7474
logger.debug("agent loop starting")
75-
76-
self._invocation_state = invocation_state or {}
77-
78-
self._event_queue = asyncio.Queue(maxsize=1)
79-
self._stop_event = object()
80-
self._tasks = set()
81-
82-
# Emit before invocation event
8375
await self._agent.hooks.invoke_callbacks_async(BidiBeforeInvocationEvent(agent=self._agent))
8476

8577
await self._agent.model.start(
@@ -88,99 +80,88 @@ async def start(self, invocation_state: dict[str, Any] | None = None) -> None:
8880
messages=self._agent.messages,
8981
)
9082

91-
self._create_task(self._run_model())
83+
self._event_queue = asyncio.Queue(maxsize=1)
84+
85+
self._task_pool = _TaskPool()
86+
self._task_pool.create(self._run_model())
9287

93-
self._active = True
88+
self._invocation_state = invocation_state or {}
89+
self._started = True
9490

9591
async def stop(self) -> None:
9692
"""Stop the agent loop."""
97-
if not self.active:
98-
return
99-
10093
logger.debug("agent loop stopping")
10194

95+
self._started = False
10296
self._invocation_state = {}
10397

104-
try:
105-
# Cancel all tasks
106-
for task in self._tasks:
107-
task.cancel()
108-
109-
# Wait briefly for tasks to finish their current operations
110-
await asyncio.gather(*self._tasks, return_exceptions=True)
98+
async def stop_tasks() -> None:
99+
await self._task_pool.cancel()
111100

112-
# Stop the model
101+
async def stop_model() -> None:
113102
await self._agent.model.stop()
114103

115-
# Clean up the event queue
116-
if not self._event_queue.empty():
117-
self._event_queue.get_nowait()
118-
self._event_queue.put_nowait(self._stop_event)
119-
120-
self._active = False
121-
104+
try:
105+
await stop_all(stop_tasks, stop_model)
122106
finally:
123-
# Emit after invocation event (reverse order for cleanup)
124107
await self._agent.hooks.invoke_callbacks_async(BidiAfterInvocationEvent(agent=self._agent))
125108

126109
async def receive(self) -> AsyncIterable[BidiOutputEvent]:
127-
"""Receive model and tool call events."""
110+
"""Receive model and tool call events.
111+
112+
Raises:
113+
RuntimeError: If start has not been called.
114+
"""
115+
if not self._started:
116+
raise RuntimeError("loop not started | call start before receiving")
117+
128118
while True:
129119
event = await self._event_queue.get()
130-
if event is self._stop_event:
131-
break
120+
if isinstance(event, Exception):
121+
raise event
132122

133123
yield event
134124

135-
@property
136-
def active(self) -> bool:
137-
"""True if agent loop started, False otherwise."""
138-
return self._active
139-
140-
def _create_task(self, coro: Awaitable[None]) -> None:
141-
"""Utilitly to create async task.
142-
143-
Adds a clean up callback to run after task completes.
144-
"""
145-
task: asyncio.Task[None] = asyncio.create_task(coro) # type: ignore
146-
task.add_done_callback(lambda task: self._tasks.remove(task))
147-
148-
self._tasks.add(task)
149-
150125
async def _run_model(self) -> None:
151126
"""Task for running the model.
152127
153128
Events are streamed through the event queue.
154129
"""
155130
logger.debug("model task starting")
156131

157-
async for event in self._agent.model.receive(): # type: ignore
158-
await self._event_queue.put(event)
159-
160-
if isinstance(event, BidiTranscriptStreamEvent):
161-
if event["is_final"]:
162-
message: Message = {"role": event["role"], "content": [{"text": event["text"]}]}
163-
self._agent.messages.append(message)
132+
try:
133+
async for event in self._agent.model.receive(): # type: ignore
134+
await self._event_queue.put(event)
135+
136+
if isinstance(event, BidiTranscriptStreamEvent):
137+
if event["is_final"]:
138+
message: Message = {"role": event["role"], "content": [{"text": event["text"]}]}
139+
self._agent.messages.append(message)
140+
await self._agent.hooks.invoke_callbacks_async(
141+
BidiMessageAddedEvent(agent=self._agent, message=message)
142+
)
143+
144+
elif isinstance(event, ToolUseStreamEvent):
145+
tool_use = event["current_tool_use"]
146+
self._task_pool.create(self._run_tool(tool_use))
147+
148+
tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]}
149+
self._agent.messages.append(tool_message)
164150
await self._agent.hooks.invoke_callbacks_async(
165-
BidiMessageAddedEvent(agent=self._agent, message=message)
151+
BidiMessageAddedEvent(agent=self._agent, message=tool_message)
166152
)
167153

168-
elif isinstance(event, ToolUseStreamEvent):
169-
tool_use = event["current_tool_use"]
170-
self._create_task(self._run_tool(tool_use))
171-
172-
tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]}
173-
self._agent.messages.append(tool_message)
174-
175-
elif isinstance(event, BidiInterruptionEvent):
176-
# Emit interruption hook event
177-
await self._agent.hooks.invoke_callbacks_async(
178-
BidiInterruptionHookEvent(
179-
agent=self._agent,
180-
reason=event["reason"],
181-
interrupted_response_id=event.get("interrupted_response_id"),
154+
elif isinstance(event, BidiInterruptionEvent):
155+
await self._agent.hooks.invoke_callbacks_async(
156+
BidiInterruptionHookEvent(
157+
agent=self._agent,
158+
reason=event["reason"],
159+
interrupted_response_id=event.get("interrupted_response_id"),
160+
)
182161
)
183-
)
162+
163+
except Exception as error:
164+
await self._event_queue.put(error)
184165

185166
async def _run_tool(self, tool_use: ToolUse) -> None:
186167
"""Task for running tool requested by the model using the tool executor."""
@@ -196,30 +177,34 @@ async def _run_tool(self, tool_use: ToolUse) -> None:
196177
"system_prompt": self._agent.system_prompt,
197178
}
198179

199-
tool_events = self._agent.tool_executor._stream(
200-
self._agent,
201-
tool_use,
202-
tool_results,
203-
invocation_state,
204-
structured_output_context=None,
205-
)
206-
207-
async for event in tool_events:
208-
if isinstance(event, ToolInterruptEvent):
209-
raise RuntimeError(
210-
"Tool interruption is not yet supported in BidiAgent. "
211-
"ToolInterruptEvent received but cannot be handled in bidirectional streaming context."
212-
)
213-
await self._event_queue.put(event)
214-
if isinstance(event, ToolResultEvent):
215-
result = event.tool_result
216-
217-
await self._agent.model.send(ToolResultEvent(result))
218-
219-
message: Message = {
220-
"role": "user",
221-
"content": [{"toolResult": result}],
222-
}
223-
self._agent.messages.append(message)
224-
await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message))
225-
await self._event_queue.put(ToolResultMessageEvent(message))
180+
try:
181+
tool_events = self._agent.tool_executor._stream(
182+
self._agent,
183+
tool_use,
184+
tool_results,
185+
invocation_state,
186+
structured_output_context=None,
187+
)
188+
189+
async for event in tool_events:
190+
if isinstance(event, ToolInterruptEvent):
191+
self._agent._interrupt_state.deactivate()
192+
interrupt_names = [interrupt.name for interrupt in event.interrupts]
193+
raise RuntimeError(f"interrupts={interrupt_names} | tool interrupts are not supported in bidi")
194+
195+
await self._event_queue.put(event)
196+
if isinstance(event, ToolResultEvent):
197+
result = event.tool_result
198+
199+
await self._agent.model.send(ToolResultEvent(result))
200+
201+
message: Message = {
202+
"role": "user",
203+
"content": [{"toolResult": result}],
204+
}
205+
self._agent.messages.append(message)
206+
await self._agent.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self._agent, message=message))
207+
await self._event_queue.put(ToolResultMessageEvent(message))
208+
209+
except Exception as error:
210+
await self._event_queue.put(error)

0 commit comments

Comments
 (0)