Skip to content

Commit 5bea4fb

Browse files
committed
hooks - before node call - cancel node
1 parent aaf9715 commit 5bea4fb

File tree

7 files changed

+243
-32
lines changed

7 files changed

+243
-32
lines changed

src/strands/experimental/hooks/multiagent/events.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,18 @@ class BeforeNodeCallEvent(BaseHookEvent):
3535
source: The multi-agent orchestrator instance
3636
node_id: ID of the node about to execute
3737
invocation_state: Configuration that user passes in
38+
cancel_node: A user defined message that when set, will cancel the node execution with status FAILED.
39+
The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the
40+
node using a default cancel message.
3841
"""
3942

4043
source: "MultiAgentBase"
4144
node_id: str
4245
invocation_state: dict[str, Any] | None = None
46+
cancel_node: bool | str = False
47+
48+
def _can_write(self, name: str) -> bool:
49+
return name in ["cancel_node"]
4350

4451

4552
@dataclass

src/strands/multiagent/graph.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..telemetry import get_tracer
3939
from ..types._events import (
4040
MultiAgentHandoffEvent,
41+
MultiAgentNodeCancelEvent,
4142
MultiAgentNodeStartEvent,
4243
MultiAgentNodeStopEvent,
4344
MultiAgentNodeStreamEvent,
@@ -777,8 +778,6 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
777778

778779
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
779780
"""Execute a single node and yield TypedEvent objects."""
780-
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))
781-
782781
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
783782
if self.reset_on_revisit and node in self.state.completed_nodes:
784783
logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id)
@@ -794,8 +793,20 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
794793
)
795794
yield start_event
796795

796+
before_event, _ = await self.hooks.invoke_callbacks_async(
797+
BeforeNodeCallEvent(self, node.node_id, invocation_state)
798+
)
799+
797800
start_time = time.time()
798801
try:
802+
if before_event.cancel_node:
803+
cancel_message = (
804+
before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user"
805+
)
806+
logger.debug("reason=<%s> | cancelling execution", cancel_message)
807+
yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
808+
raise RuntimeError(cancel_message)
809+
799810
# Build node input from satisfied dependencies
800811
node_input = self._build_node_input(node)
801812

src/strands/multiagent/swarm.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..tools.decorator import tool
3939
from ..types._events import (
4040
MultiAgentHandoffEvent,
41+
MultiAgentNodeCancelEvent,
4142
MultiAgentNodeStartEvent,
4243
MultiAgentNodeStopEvent,
4344
MultiAgentNodeStreamEvent,
@@ -679,11 +680,23 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
679680
len(self.state.node_history) + 1,
680681
)
681682

683+
before_event, _ = await self.hooks.invoke_callbacks_async(
684+
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
685+
)
686+
682687
# TODO: Implement cancellation token to stop _execute_node from continuing
683688
try:
684-
await self.hooks.invoke_callbacks_async(
685-
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
686-
)
689+
if before_event.cancel_node:
690+
cancel_message = (
691+
before_event.cancel_node
692+
if isinstance(before_event.cancel_node, str)
693+
else "node cancelled by user"
694+
)
695+
logger.debug("reason=<%s> | cancelling execution", cancel_message)
696+
yield MultiAgentNodeCancelEvent(current_node.node_id, cancel_message)
697+
self.state.completion_status = Status.FAILED
698+
break
699+
687700
node_stream = self._stream_with_timeout(
688701
self._execute_node(current_node, self.state.task, invocation_state),
689702
self.node_timeout,
@@ -693,40 +706,42 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
693706
yield event
694707

695708
self.state.node_history.append(current_node)
709+
710+
except Exception:
711+
logger.exception("node=<%s> | node execution failed", current_node.node_id)
712+
self.state.completion_status = Status.FAILED
713+
break
714+
715+
finally:
696716
await self.hooks.invoke_callbacks_async(
697717
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
698718
)
699719

700-
logger.debug("node=<%s> | node execution completed", current_node.node_id)
701-
702-
# Check if handoff requested during execution
703-
if self.state.handoff_node:
704-
previous_node = current_node
705-
current_node = self.state.handoff_node
720+
logger.debug("node=<%s> | node execution completed", current_node.node_id)
706721

707-
self.state.handoff_node = None
708-
self.state.current_node = current_node
722+
# Check if handoff requested during execution
723+
if self.state.handoff_node:
724+
previous_node = current_node
725+
current_node = self.state.handoff_node
709726

710-
handoff_event = MultiAgentHandoffEvent(
711-
from_node_ids=[previous_node.node_id],
712-
to_node_ids=[current_node.node_id],
713-
message=self.state.handoff_message or "Agent handoff occurred",
714-
)
715-
yield handoff_event
716-
logger.debug(
717-
"from_node=<%s>, to_node=<%s> | handoff detected",
718-
previous_node.node_id,
719-
current_node.node_id,
720-
)
727+
self.state.handoff_node = None
728+
self.state.current_node = current_node
721729

722-
else:
723-
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
724-
self.state.completion_status = Status.COMPLETED
725-
break
730+
handoff_event = MultiAgentHandoffEvent(
731+
from_node_ids=[previous_node.node_id],
732+
to_node_ids=[current_node.node_id],
733+
message=self.state.handoff_message or "Agent handoff occurred",
734+
)
735+
yield handoff_event
736+
logger.debug(
737+
"from_node=<%s>, to_node=<%s> | handoff detected",
738+
previous_node.node_id,
739+
current_node.node_id,
740+
)
726741

727-
except Exception:
728-
logger.exception("node=<%s> | node execution failed", current_node.node_id)
729-
self.state.completion_status = Status.FAILED
742+
else:
743+
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
744+
self.state.completion_status = Status.COMPLETED
730745
break
731746

732747
except Exception:

src/strands/types/_events.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,22 @@ def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None:
524524
"event": agent_event, # Nest agent event to avoid field conflicts
525525
}
526526
)
527+
528+
529+
class MultiAgentNodeCancelEvent(TypedEvent):
530+
"""Event emitted when a user cancels node execution from their BeforeNodeCallEvent hook."""
531+
532+
def __init__(self, node_id: str, message: str) -> None:
533+
"""Initialize with cancel message.
534+
535+
Args:
536+
node_id: Unique identifier for the node.
537+
message: The node cancellation message.
538+
"""
539+
super().__init__(
540+
{
541+
"type": "multiagent_node_cancel",
542+
"node_id": node_id,
543+
"message": message,
544+
}
545+
)

tests/strands/multiagent/test_graph.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
from strands.agent import Agent, AgentResult
88
from strands.agent.state import AgentState
9+
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
910
from strands.hooks import AgentInitializedEvent
1011
from strands.hooks.registry import HookProvider, HookRegistry
1112
from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult
1213
from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status
1314
from strands.session.file_session_manager import FileSessionManager
1415
from strands.session.session_manager import SessionManager
16+
from strands.types._events import MultiAgentNodeCancelEvent
1517

1618

1719
def create_mock_agent(name, response_text="Default response", metrics=None, agent_id=None):
@@ -2033,3 +2035,36 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span):
20332035
assert final_state["status"] == "completed"
20342036
assert len(final_state["completed_nodes"]) == 1
20352037
assert "test_node" in final_state["node_results"]
2038+
2039+
2040+
@pytest.mark.parametrize(
2041+
("cancel_node", "cancel_message"),
2042+
[(True, "node cancelled by user"), ("custom cancel message", "custom cancel message")],
2043+
)
2044+
@pytest.mark.asyncio
2045+
async def test_graph_cancel_node(cancel_node, cancel_message):
2046+
def cancel_callback(event):
2047+
event.cancel_node = cancel_node
2048+
return event
2049+
2050+
agent = create_mock_agent("test_agent", "Should not execute")
2051+
builder = GraphBuilder()
2052+
builder.add_node(agent, "test_agent")
2053+
builder.set_entry_point("test_agent")
2054+
graph = builder.build()
2055+
graph.hooks.add_callback(BeforeNodeCallEvent, cancel_callback)
2056+
2057+
stream = graph.stream_async("test task")
2058+
2059+
tru_cancel_event = None
2060+
with pytest.raises(RuntimeError, match=cancel_message):
2061+
async for event in stream:
2062+
if event.get("type") == "multiagent_node_cancel":
2063+
tru_cancel_event = event
2064+
2065+
exp_cancel_event = MultiAgentNodeCancelEvent(node_id="test_agent", message=cancel_message)
2066+
assert tru_cancel_event == exp_cancel_event
2067+
2068+
tru_status = graph.state.status
2069+
exp_status = Status.FAILED
2070+
assert tru_status == exp_status

tests/strands/multiagent/test_swarm.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
22
import time
3-
from unittest.mock import MagicMock, Mock, patch
3+
from unittest.mock import ANY, MagicMock, Mock, patch
44

55
import pytest
66

77
from strands.agent import Agent, AgentResult
88
from strands.agent.state import AgentState
9+
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
910
from strands.hooks.registry import HookRegistry
1011
from strands.multiagent.base import Status
1112
from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState
@@ -1176,3 +1177,38 @@ async def handoff_stream(*args, **kwargs):
11761177
tru_node_order = [node.node_id for node in result.node_history]
11771178
exp_node_order = ["first", "second"]
11781179
assert tru_node_order == exp_node_order
1180+
1181+
1182+
@pytest.mark.parametrize(
1183+
("cancel_node", "cancel_message"),
1184+
[(True, "node cancelled by user"), ("custom cancel message", "custom cancel message")],
1185+
)
1186+
@pytest.mark.asyncio
1187+
async def test_swarm_cancel_node(cancel_node, cancel_message, alist):
1188+
def cancel_callback(event):
1189+
event.cancel_node = cancel_node
1190+
return event
1191+
1192+
agent = create_mock_agent("test_agent", "Should not execute")
1193+
swarm = Swarm([agent])
1194+
swarm.hooks.add_callback(BeforeNodeCallEvent, cancel_callback)
1195+
1196+
stream = swarm.stream_async("test task")
1197+
1198+
tru_events = await alist(stream)
1199+
exp_events = [
1200+
{
1201+
"message": cancel_message,
1202+
"node_id": "test_agent",
1203+
"type": "multiagent_node_cancel",
1204+
},
1205+
{
1206+
"result": ANY,
1207+
"type": "multiagent_result",
1208+
},
1209+
]
1210+
assert tru_events == exp_events
1211+
1212+
tru_status = swarm.state.completion_status
1213+
exp_status = Status.FAILED
1214+
assert tru_status == exp_status
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import pytest
2+
3+
from strands import Agent
4+
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
5+
from strands.hooks import HookProvider
6+
from strands.multiagent import GraphBuilder, Swarm
7+
from strands.multiagent.base import Status
8+
from strands.types._events import MultiAgentNodeCancelEvent
9+
10+
11+
@pytest.fixture
12+
def cancel_hook():
13+
class Hook(HookProvider):
14+
def register_hooks(self, registry):
15+
registry.add_callback(BeforeNodeCallEvent, self.cancel)
16+
17+
def cancel(self, event):
18+
if event.node_id == "weather":
19+
event.cancel_node = "test cancel"
20+
21+
return Hook()
22+
23+
24+
@pytest.fixture
25+
def info_agent():
26+
return Agent(name="info")
27+
28+
29+
@pytest.fixture
30+
def weather_agent():
31+
return Agent(name="weather")
32+
33+
34+
@pytest.fixture
35+
def swarm(cancel_hook, info_agent, weather_agent):
36+
return Swarm([info_agent, weather_agent], hooks=[cancel_hook])
37+
38+
39+
@pytest.fixture
40+
def graph(cancel_hook, info_agent, weather_agent):
41+
builder = GraphBuilder()
42+
builder.add_node(info_agent, "info")
43+
builder.add_node(weather_agent, "weather")
44+
builder.add_edge("info", "weather")
45+
builder.set_entry_point("info")
46+
builder.set_hook_providers([cancel_hook])
47+
48+
return builder.build()
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_swarm_cancel_node(swarm):
53+
tru_cancel_event = None
54+
async for event in swarm.stream_async("What is the weather"):
55+
if event.get("type") == "multiagent_node_cancel":
56+
tru_cancel_event = event
57+
58+
multiagent_result = event["result"]
59+
60+
exp_cancel_event = MultiAgentNodeCancelEvent(node_id="weather", message="test cancel")
61+
assert tru_cancel_event == exp_cancel_event
62+
63+
tru_status = multiagent_result.status
64+
exp_status = Status.FAILED
65+
assert tru_status == exp_status
66+
67+
assert len(multiagent_result.node_history) == 1
68+
tru_node_id = multiagent_result.node_history[0].node_id
69+
exp_node_id = "info"
70+
assert tru_node_id == exp_node_id
71+
72+
73+
@pytest.mark.asyncio
74+
async def test_graph_cancel_node(graph):
75+
tru_cancel_event = None
76+
with pytest.raises(RuntimeError, match="test cancel"):
77+
async for event in graph.stream_async("What is the weather"):
78+
if event.get("type") == "multiagent_node_cancel":
79+
tru_cancel_event = event
80+
81+
exp_cancel_event = MultiAgentNodeCancelEvent(node_id="weather", message="test cancel")
82+
assert tru_cancel_event == exp_cancel_event
83+
84+
state = graph.state
85+
86+
tru_status = state.status
87+
exp_status = Status.FAILED
88+
assert tru_status == exp_status

0 commit comments

Comments
 (0)