Skip to content

Commit b800ee9

Browse files
committed
hooks - before node call - cancel node
1 parent 95ac650 commit b800ee9

File tree

7 files changed

+236
-31
lines changed

7 files changed

+236
-31
lines changed

src/strands/experimental/hooks/multiagent/events.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,18 @@ class BeforeNodeCallEvent(BaseHookEvent):
3535
source: The multi-agent orchestrator instance
3636
node_id: ID of the node about to execute
3737
invocation_state: Configuration that user passes in
38+
cancel_node: A user defined message that when set, will cancel the node execution with status FAILED.
39+
The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the
40+
node using a default cancel message.
3841
"""
3942

4043
source: "MultiAgentBase"
4144
node_id: str
4245
invocation_state: dict[str, Any] | None = None
46+
cancel_node: bool | str = False
47+
48+
def _can_write(self, name: str) -> bool:
49+
return name in ["cancel_node"]
4350

4451

4552
@dataclass

src/strands/multiagent/graph.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..telemetry import get_tracer
3939
from ..types._events import (
4040
MultiAgentHandoffEvent,
41+
MultiAgentNodeCancelEvent,
4142
MultiAgentNodeStartEvent,
4243
MultiAgentNodeStopEvent,
4344
MultiAgentNodeStreamEvent,
@@ -776,8 +777,6 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
776777

777778
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
778779
"""Execute a single node and yield TypedEvent objects."""
779-
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))
780-
781780
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
782781
if self.reset_on_revisit and node in self.state.completed_nodes:
783782
logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id)
@@ -795,6 +794,18 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
795794

796795
start_time = time.time()
797796
try:
797+
before_event, _ = await self.hooks.invoke_callbacks_async(
798+
BeforeNodeCallEvent(self, node.node_id, invocation_state)
799+
)
800+
801+
if before_event.cancel_node:
802+
cancel_message = (
803+
before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user"
804+
)
805+
logger.debug("reason=<%s> | cancelling execution", cancel_message)
806+
yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
807+
raise RuntimeError(cancel_message)
808+
798809
# Build node input from satisfied dependencies
799810
node_input = self._build_node_input(node)
800811

src/strands/multiagent/swarm.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..tools.decorator import tool
3939
from ..types._events import (
4040
MultiAgentHandoffEvent,
41+
MultiAgentNodeCancelEvent,
4142
MultiAgentNodeStartEvent,
4243
MultiAgentNodeStopEvent,
4344
MultiAgentNodeStreamEvent,
@@ -680,9 +681,21 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
680681

681682
# TODO: Implement cancellation token to stop _execute_node from continuing
682683
try:
683-
await self.hooks.invoke_callbacks_async(
684+
before_event, _ = await self.hooks.invoke_callbacks_async(
684685
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
685686
)
687+
688+
if before_event.cancel_node:
689+
cancel_message = (
690+
before_event.cancel_node
691+
if isinstance(before_event.cancel_node, str)
692+
else "node cancelled by user"
693+
)
694+
logger.debug("reason=<%s> | cancelling execution", cancel_message)
695+
yield MultiAgentNodeCancelEvent(current_node.node_id, cancel_message)
696+
self.state.completion_status = Status.FAILED
697+
break
698+
686699
node_stream = self._stream_with_timeout(
687700
self._execute_node(current_node, self.state.task, invocation_state),
688701
self.node_timeout,
@@ -692,40 +705,42 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
692705
yield event
693706

694707
self.state.node_history.append(current_node)
708+
709+
except Exception:
710+
logger.exception("node=<%s> | node execution failed", current_node.node_id)
711+
self.state.completion_status = Status.FAILED
712+
break
713+
714+
finally:
695715
await self.hooks.invoke_callbacks_async(
696716
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
697717
)
698718

699-
logger.debug("node=<%s> | node execution completed", current_node.node_id)
719+
logger.debug("node=<%s> | node execution completed", current_node.node_id)
700720

701-
# Check if handoff requested during execution
702-
if self.state.handoff_node:
703-
previous_node = current_node
704-
current_node = self.state.handoff_node
721+
# Check if handoff requested during execution
722+
if self.state.handoff_node:
723+
previous_node = current_node
724+
current_node = self.state.handoff_node
705725

706-
self.state.handoff_node = None
707-
self.state.current_node = current_node
726+
self.state.handoff_node = None
727+
self.state.current_node = current_node
708728

709-
handoff_event = MultiAgentHandoffEvent(
710-
from_node_ids=[previous_node.node_id],
711-
to_node_ids=[current_node.node_id],
712-
message=self.state.handoff_message or "Agent handoff occurred",
713-
)
714-
yield handoff_event
715-
logger.debug(
716-
"from_node=<%s>, to_node=<%s> | handoff detected",
717-
previous_node.node_id,
718-
current_node.node_id,
719-
)
720-
721-
else:
722-
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
723-
self.state.completion_status = Status.COMPLETED
724-
break
729+
handoff_event = MultiAgentHandoffEvent(
730+
from_node_ids=[previous_node.node_id],
731+
to_node_ids=[current_node.node_id],
732+
message=self.state.handoff_message or "Agent handoff occurred",
733+
)
734+
yield handoff_event
735+
logger.debug(
736+
"from_node=<%s>, to_node=<%s> | handoff detected",
737+
previous_node.node_id,
738+
current_node.node_id,
739+
)
725740

726-
except Exception:
727-
logger.exception("node=<%s> | node execution failed", current_node.node_id)
728-
self.state.completion_status = Status.FAILED
741+
else:
742+
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
743+
self.state.completion_status = Status.COMPLETED
729744
break
730745

731746
except Exception:

src/strands/types/_events.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,22 @@ def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None:
524524
"event": agent_event, # Nest agent event to avoid field conflicts
525525
}
526526
)
527+
528+
529+
class MultiAgentNodeCancelEvent(TypedEvent):
530+
"""Event emitted when a user cancels node execution from their BeforeNodeCallEvent hook."""
531+
532+
def __init__(self, node_id: str, message: str) -> None:
533+
"""Initialize with cancel message.
534+
535+
Args:
536+
node_id: Unique identifier for the node.
537+
message: The node cancellation message.
538+
"""
539+
super().__init__(
540+
{
541+
"type": "multiagent_node_cancel",
542+
"node_id": node_id,
543+
"message": message,
544+
}
545+
)

tests/strands/multiagent/test_graph.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
22
import time
3-
from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
3+
from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch
44

55
import pytest
66

77
from strands.agent import Agent, AgentResult
88
from strands.agent.state import AgentState
9+
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
910
from strands.hooks import AgentInitializedEvent
1011
from strands.hooks.registry import HookProvider, HookRegistry
1112
from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult
@@ -2033,3 +2034,29 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span):
20332034
assert final_state["status"] == "completed"
20342035
assert len(final_state["completed_nodes"]) == 1
20352036
assert "test_node" in final_state["node_results"]
2037+
2038+
2039+
@pytest.mark.parametrize(
2040+
("cancel_node", "cancel_message"),
2041+
[(True, "node cancelled by user"), ("custom cancel message", "custom cancel message")],
2042+
)
2043+
@pytest.mark.asyncio
2044+
async def test_graph_cancel_node(cancel_node, cancel_message):
2045+
def cancel_callback(event):
2046+
event.cancel_node = cancel_node
2047+
return event
2048+
2049+
agent = create_mock_agent("test_agent", "Should not execute")
2050+
builder = GraphBuilder()
2051+
builder.add_node(agent, "test_agent")
2052+
builder.set_entry_point("test_agent")
2053+
graph = builder.build()
2054+
graph.hooks.add_callback(BeforeNodeCallEvent, cancel_callback)
2055+
2056+
with pytest.raises(RuntimeError, match=cancel_message):
2057+
async for event in graph.stream_async("test task"):
2058+
if event.get("type") == "multiagent_node_cancel":
2059+
assert event["message"] == cancel_message
2060+
assert event["node_id"] == "test_agent"
2061+
2062+
assert graph.state.status == Status.FAILED

tests/strands/multiagent/test_swarm.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
22
import time
3-
from unittest.mock import MagicMock, Mock, patch
3+
from unittest.mock import ANY, MagicMock, Mock, patch
44

55
import pytest
66

77
from strands.agent import Agent, AgentResult
88
from strands.agent.state import AgentState
9+
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
910
from strands.hooks.registry import HookRegistry
1011
from strands.multiagent.base import Status
1112
from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState
@@ -1176,3 +1177,39 @@ async def handoff_stream(*args, **kwargs):
11761177
tru_node_order = [node.node_id for node in result.node_history]
11771178
exp_node_order = ["first", "second"]
11781179
assert tru_node_order == exp_node_order
1180+
1181+
@pytest.mark.parametrize(
1182+
("cancel_node", "cancel_message"),
1183+
[(True, "node cancelled by user"), ("custom cancel message", "custom cancel message")],
1184+
)
1185+
@pytest.mark.asyncio
1186+
async def test_swarm_cancel_node(cancel_node, cancel_message, alist):
1187+
def cancel_callback(event):
1188+
event.cancel_node = cancel_node
1189+
return event
1190+
1191+
agent = create_mock_agent("test_agent", "Should not execute")
1192+
swarm = Swarm([agent])
1193+
swarm.hooks.add_callback(BeforeNodeCallEvent, cancel_callback)
1194+
1195+
stream = swarm.stream_async("test task")
1196+
1197+
print(swarm.state.results)
1198+
1199+
tru_events = await alist(stream)
1200+
exp_events = [
1201+
{
1202+
"message": cancel_message,
1203+
"node_id": "test_agent",
1204+
"type": "multiagent_node_cancel",
1205+
},
1206+
{
1207+
"result": ANY,
1208+
"type": "multiagent_result",
1209+
},
1210+
]
1211+
assert tru_events == exp_events
1212+
1213+
tru_status = swarm.state.completion_status
1214+
exp_status = Status.FAILED
1215+
assert tru_status == exp_status
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import pytest
2+
3+
from strands import Agent
4+
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
5+
from strands.hooks import HookProvider
6+
from strands.multiagent import GraphBuilder, Swarm
7+
from strands.multiagent.base import Status
8+
from strands.types._events import MultiAgentNodeCancelEvent
9+
10+
11+
@pytest.fixture
12+
def cancel_hook():
13+
class Hook(HookProvider):
14+
def register_hooks(self, registry):
15+
registry.add_callback(BeforeNodeCallEvent, self.cancel)
16+
17+
def cancel(self, event):
18+
if event.node_id == "weather":
19+
event.cancel_node = "test cancel"
20+
21+
return Hook()
22+
23+
24+
@pytest.fixture
25+
def info_agent():
26+
return Agent(name="info")
27+
28+
29+
@pytest.fixture
30+
def weather_agent():
31+
return Agent(name="weather")
32+
33+
34+
@pytest.fixture
35+
def swarm(cancel_hook, info_agent, weather_agent):
36+
return Swarm([info_agent, weather_agent], hooks=[cancel_hook])
37+
38+
39+
@pytest.fixture
40+
def graph(cancel_hook, info_agent, weather_agent):
41+
builder = GraphBuilder()
42+
builder.add_node(info_agent, "info")
43+
builder.add_node(weather_agent, "weather")
44+
builder.add_edge("info", "weather")
45+
builder.set_entry_point("info")
46+
builder.set_hook_providers([cancel_hook])
47+
48+
return builder.build()
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_swarm_cancel_node(swarm):
53+
tru_cancel_event = None
54+
async for event in swarm.stream_async("What is the weather"):
55+
if event.get("type") == "multiagent_node_cancel":
56+
tru_cancel_event = event
57+
58+
multiagent_result = event["result"]
59+
60+
exp_cancel_event = MultiAgentNodeCancelEvent(node_id="weather", message="test cancel")
61+
assert tru_cancel_event == exp_cancel_event
62+
63+
tru_status = multiagent_result.status
64+
exp_status = Status.FAILED
65+
assert tru_status == exp_status
66+
67+
assert len(multiagent_result.node_history) == 1
68+
tru_node_id = multiagent_result.node_history[0].node_id
69+
exp_node_id = "info"
70+
assert tru_node_id == exp_node_id
71+
72+
73+
@pytest.mark.asyncio
74+
async def test_graph_cancel_node(graph):
75+
tru_cancel_event = None
76+
77+
with pytest.raises(RuntimeError, match=f"test cancel"):
78+
async for event in graph.stream_async("What is the weather"):
79+
if event.get("type") == "multiagent_node_cancel":
80+
tru_cancel_event = event
81+
82+
exp_cancel_event = MultiAgentNodeCancelEvent(node_id="weather", message="test cancel")
83+
assert tru_cancel_event == exp_cancel_event
84+
85+
state = graph.state
86+
87+
tru_status = state.status
88+
exp_status = Status.FAILED
89+
assert tru_status == exp_status

0 commit comments

Comments
 (0)