Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/strands/experimental/hooks/multiagent/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,18 @@ class BeforeNodeCallEvent(BaseHookEvent):
source: The multi-agent orchestrator instance
node_id: ID of the node about to execute
invocation_state: Configuration that user passes in
cancel_node: A user defined message that when set, will cancel the node execution with status FAILED.
The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the
node using a default cancel message.
"""

source: "MultiAgentBase"
node_id: str
invocation_state: dict[str, Any] | None = None
cancel_node: bool | str = False

def _can_write(self, name: str) -> bool:
return name in ["cancel_node"]


@dataclass
Expand Down
15 changes: 13 additions & 2 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..telemetry import get_tracer
from ..types._events import (
MultiAgentHandoffEvent,
MultiAgentNodeCancelEvent,
MultiAgentNodeStartEvent,
MultiAgentNodeStopEvent,
MultiAgentNodeStreamEvent,
Expand Down Expand Up @@ -776,8 +777,6 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[

async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
"""Execute a single node and yield TypedEvent objects."""
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))

# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
if self.reset_on_revisit and node in self.state.completed_nodes:
logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id)
Expand All @@ -793,8 +792,20 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
)
yield start_event

before_event, _ = await self.hooks.invoke_callbacks_async(
BeforeNodeCallEvent(self, node.node_id, invocation_state)
)

start_time = time.time()
try:
if before_event.cancel_node:
cancel_message = (
before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user"
)
logger.debug("reason=<%s> | cancelling execution", cancel_message)
yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
raise RuntimeError(cancel_message)

# Build node input from satisfied dependencies
node_input = self._build_node_input(node)

Expand Down
73 changes: 44 additions & 29 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..tools.decorator import tool
from ..types._events import (
MultiAgentHandoffEvent,
MultiAgentNodeCancelEvent,
MultiAgentNodeStartEvent,
MultiAgentNodeStopEvent,
MultiAgentNodeStreamEvent,
Expand Down Expand Up @@ -678,11 +679,23 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
len(self.state.node_history) + 1,
)

before_event, _ = await self.hooks.invoke_callbacks_async(
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
)

# TODO: Implement cancellation token to stop _execute_node from continuing
try:
await self.hooks.invoke_callbacks_async(
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
)
if before_event.cancel_node:
cancel_message = (
before_event.cancel_node
if isinstance(before_event.cancel_node, str)
else "node cancelled by user"
)
logger.debug("reason=<%s> | cancelling execution", cancel_message)
yield MultiAgentNodeCancelEvent(current_node.node_id, cancel_message)
self.state.completion_status = Status.FAILED
break

node_stream = self._stream_with_timeout(
self._execute_node(current_node, self.state.task, invocation_state),
self.node_timeout,
Expand All @@ -692,40 +705,42 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
yield event

self.state.node_history.append(current_node)

except Exception:
logger.exception("node=<%s> | node execution failed", current_node.node_id)
self.state.completion_status = Status.FAILED
break

finally:
await self.hooks.invoke_callbacks_async(
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
)

logger.debug("node=<%s> | node execution completed", current_node.node_id)

# Check if handoff requested during execution
if self.state.handoff_node:
previous_node = current_node
current_node = self.state.handoff_node
logger.debug("node=<%s> | node execution completed", current_node.node_id)

self.state.handoff_node = None
self.state.current_node = current_node
# Check if handoff requested during execution
if self.state.handoff_node:
previous_node = current_node
current_node = self.state.handoff_node

handoff_event = MultiAgentHandoffEvent(
from_node_ids=[previous_node.node_id],
to_node_ids=[current_node.node_id],
message=self.state.handoff_message or "Agent handoff occurred",
)
yield handoff_event
logger.debug(
"from_node=<%s>, to_node=<%s> | handoff detected",
previous_node.node_id,
current_node.node_id,
)
self.state.handoff_node = None
self.state.current_node = current_node

else:
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
self.state.completion_status = Status.COMPLETED
break
handoff_event = MultiAgentHandoffEvent(
from_node_ids=[previous_node.node_id],
to_node_ids=[current_node.node_id],
message=self.state.handoff_message or "Agent handoff occurred",
)
yield handoff_event
logger.debug(
"from_node=<%s>, to_node=<%s> | handoff detected",
previous_node.node_id,
current_node.node_id,
)

except Exception:
logger.exception("node=<%s> | node execution failed", current_node.node_id)
self.state.completion_status = Status.FAILED
else:
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
self.state.completion_status = Status.COMPLETED
break

except Exception:
Expand Down
19 changes: 19 additions & 0 deletions src/strands/types/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,22 @@ def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None:
"event": agent_event, # Nest agent event to avoid field conflicts
}
)


class MultiAgentNodeCancelEvent(TypedEvent):
"""Event emitted when a user cancels node execution from their BeforeNodeCallEvent hook."""

def __init__(self, node_id: str, message: str) -> None:
"""Initialize with cancel message.

Args:
node_id: Unique identifier for the node.
message: The node cancellation message.
"""
super().__init__(
{
"type": "multiagent_node_cancel",
"node_id": node_id,
"message": message,
}
)
35 changes: 35 additions & 0 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

from strands.agent import Agent, AgentResult
from strands.agent.state import AgentState
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
from strands.hooks import AgentInitializedEvent
from strands.hooks.registry import HookProvider, HookRegistry
from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult
from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status
from strands.session.file_session_manager import FileSessionManager
from strands.session.session_manager import SessionManager
from strands.types._events import MultiAgentNodeCancelEvent


def create_mock_agent(name, response_text="Default response", metrics=None, agent_id=None):
Expand Down Expand Up @@ -2033,3 +2035,36 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span):
assert final_state["status"] == "completed"
assert len(final_state["completed_nodes"]) == 1
assert "test_node" in final_state["node_results"]


@pytest.mark.parametrize(
("cancel_node", "cancel_message"),
[(True, "node cancelled by user"), ("custom cancel message", "custom cancel message")],
)
@pytest.mark.asyncio
async def test_graph_cancel_node(cancel_node, cancel_message):
def cancel_callback(event):
event.cancel_node = cancel_node
return event

agent = create_mock_agent("test_agent", "Should not execute")
builder = GraphBuilder()
builder.add_node(agent, "test_agent")
builder.set_entry_point("test_agent")
graph = builder.build()
graph.hooks.add_callback(BeforeNodeCallEvent, cancel_callback)

stream = graph.stream_async("test task")

tru_cancel_event = None
with pytest.raises(RuntimeError, match=cancel_message):
async for event in stream:
if event.get("type") == "multiagent_node_cancel":
tru_cancel_event = event

exp_cancel_event = MultiAgentNodeCancelEvent(node_id="test_agent", message=cancel_message)
assert tru_cancel_event == exp_cancel_event

tru_status = graph.state.status
exp_status = Status.FAILED
assert tru_status == exp_status
38 changes: 37 additions & 1 deletion tests/strands/multiagent/test_swarm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import time
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import ANY, MagicMock, Mock, patch

import pytest

from strands.agent import Agent, AgentResult
from strands.agent.state import AgentState
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
from strands.hooks.registry import HookRegistry
from strands.multiagent.base import Status
from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState
Expand Down Expand Up @@ -1176,3 +1177,38 @@ async def handoff_stream(*args, **kwargs):
tru_node_order = [node.node_id for node in result.node_history]
exp_node_order = ["first", "second"]
assert tru_node_order == exp_node_order


@pytest.mark.parametrize(
("cancel_node", "cancel_message"),
[(True, "node cancelled by user"), ("custom cancel message", "custom cancel message")],
)
@pytest.mark.asyncio
async def test_swarm_cancel_node(cancel_node, cancel_message, alist):
def cancel_callback(event):
event.cancel_node = cancel_node
return event

agent = create_mock_agent("test_agent", "Should not execute")
swarm = Swarm([agent])
swarm.hooks.add_callback(BeforeNodeCallEvent, cancel_callback)

stream = swarm.stream_async("test task")

tru_events = await alist(stream)
exp_events = [
{
"message": cancel_message,
"node_id": "test_agent",
"type": "multiagent_node_cancel",
},
{
"result": ANY,
"type": "multiagent_result",
},
]
assert tru_events == exp_events

tru_status = swarm.state.completion_status
exp_status = Status.FAILED
assert tru_status == exp_status
88 changes: 88 additions & 0 deletions tests_integ/hooks/multiagent/test_cancel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest

from strands import Agent
from strands.experimental.hooks.multiagent import BeforeNodeCallEvent
from strands.hooks import HookProvider
from strands.multiagent import GraphBuilder, Swarm
from strands.multiagent.base import Status
from strands.types._events import MultiAgentNodeCancelEvent


@pytest.fixture
def cancel_hook():
class Hook(HookProvider):
def register_hooks(self, registry):
registry.add_callback(BeforeNodeCallEvent, self.cancel)

def cancel(self, event):
if event.node_id == "weather":
event.cancel_node = "test cancel"

return Hook()


@pytest.fixture
def info_agent():
return Agent(name="info")


@pytest.fixture
def weather_agent():
return Agent(name="weather")


@pytest.fixture
def swarm(cancel_hook, info_agent, weather_agent):
return Swarm([info_agent, weather_agent], hooks=[cancel_hook])


@pytest.fixture
def graph(cancel_hook, info_agent, weather_agent):
builder = GraphBuilder()
builder.add_node(info_agent, "info")
builder.add_node(weather_agent, "weather")
builder.add_edge("info", "weather")
builder.set_entry_point("info")
builder.set_hook_providers([cancel_hook])

return builder.build()


@pytest.mark.asyncio
async def test_swarm_cancel_node(swarm):
tru_cancel_event = None
async for event in swarm.stream_async("What is the weather"):
if event.get("type") == "multiagent_node_cancel":
tru_cancel_event = event

multiagent_result = event["result"]

exp_cancel_event = MultiAgentNodeCancelEvent(node_id="weather", message="test cancel")
assert tru_cancel_event == exp_cancel_event

tru_status = multiagent_result.status
exp_status = Status.FAILED
assert tru_status == exp_status

assert len(multiagent_result.node_history) == 1
tru_node_id = multiagent_result.node_history[0].node_id
exp_node_id = "info"
assert tru_node_id == exp_node_id


@pytest.mark.asyncio
async def test_graph_cancel_node(graph):
tru_cancel_event = None
with pytest.raises(RuntimeError, match="test cancel"):
async for event in graph.stream_async("What is the weather"):
if event.get("type") == "multiagent_node_cancel":
tru_cancel_event = event

exp_cancel_event = MultiAgentNodeCancelEvent(node_id="weather", message="test cancel")
assert tru_cancel_event == exp_cancel_event

state = graph.state

tru_status = state.status
exp_status = Status.FAILED
assert tru_status == exp_status
Loading