Skip to content

Commit 6aa49ec

Browse files
committed
hooks - before node call - cancel node
1 parent 95ac650 commit 6aa49ec

File tree

4 files changed

+81
-29
lines changed

4 files changed

+81
-29
lines changed

src/strands/experimental/hooks/multiagent/events.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,18 @@ class BeforeNodeCallEvent(BaseHookEvent):
3535
source: The multi-agent orchestrator instance
3636
node_id: ID of the node about to execute
3737
invocation_state: Configuration that user passes in
38+
cancel_node: A user defined message that when set, will cancel the node execution with status FAILED.
39+
The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the
40+
node using a default cancel message.
3841
"""
3942

4043
source: "MultiAgentBase"
4144
node_id: str
4245
invocation_state: dict[str, Any] | None = None
46+
cancel_node: bool | str = False
47+
48+
def _can_write(self, name: str) -> bool:
49+
return name in ["cancel_node"]
4350

4451

4552
@dataclass

src/strands/multiagent/graph.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..telemetry import get_tracer
3939
from ..types._events import (
4040
MultiAgentHandoffEvent,
41+
MultiAgentNodeCancelEvent,
4142
MultiAgentNodeStartEvent,
4243
MultiAgentNodeStopEvent,
4344
MultiAgentNodeStreamEvent,
@@ -776,8 +777,6 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
776777

777778
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
778779
"""Execute a single node and yield TypedEvent objects."""
779-
await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state))
780-
781780
# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
782781
if self.reset_on_revisit and node in self.state.completed_nodes:
783782
logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id)
@@ -795,6 +794,18 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
795794

796795
start_time = time.time()
797796
try:
797+
before_event, _ = await self.hooks.invoke_callbacks_async(
798+
BeforeNodeCallEvent(self, node.node_id, invocation_state)
799+
)
800+
801+
if before_event.cancel_node:
802+
cancel_message = (
803+
before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user"
804+
)
805+
logger.debug("reason=<%s> | cancelling execution", cancel_message)
806+
yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
807+
raise RuntimeError(cancel_message)
808+
798809
# Build node input from satisfied dependencies
799810
node_input = self._build_node_input(node)
800811

src/strands/multiagent/swarm.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..tools.decorator import tool
3939
from ..types._events import (
4040
MultiAgentHandoffEvent,
41+
MultiAgentNodeCancelEvent,
4142
MultiAgentNodeStartEvent,
4243
MultiAgentNodeStopEvent,
4344
MultiAgentNodeStreamEvent,
@@ -680,9 +681,21 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
680681

681682
# TODO: Implement cancellation token to stop _execute_node from continuing
682683
try:
683-
await self.hooks.invoke_callbacks_async(
684+
before_event, _ = await self.hooks.invoke_callbacks_async(
684685
BeforeNodeCallEvent(self, current_node.node_id, invocation_state)
685686
)
687+
688+
if before_event.cancel_node:
689+
cancel_message = (
690+
before_event.cancel_node
691+
if isinstance(before_event.cancel_node, str)
692+
else "node cancelled by user"
693+
)
694+
logger.debug("reason=<%s> | cancelling execution", cancel_message)
695+
yield MultiAgentNodeCancelEvent(current_node.node_id, cancel_message)
696+
self.state.completion_status = Status.FAILED
697+
break
698+
686699
node_stream = self._stream_with_timeout(
687700
self._execute_node(current_node, self.state.task, invocation_state),
688701
self.node_timeout,
@@ -692,40 +705,42 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
692705
yield event
693706

694707
self.state.node_history.append(current_node)
708+
709+
except Exception:
710+
logger.exception("node=<%s> | node execution failed", current_node.node_id)
711+
self.state.completion_status = Status.FAILED
712+
break
713+
714+
finally:
695715
await self.hooks.invoke_callbacks_async(
696716
AfterNodeCallEvent(self, current_node.node_id, invocation_state)
697717
)
698718

699-
logger.debug("node=<%s> | node execution completed", current_node.node_id)
719+
logger.debug("node=<%s> | node execution completed", current_node.node_id)
700720

701-
# Check if handoff requested during execution
702-
if self.state.handoff_node:
703-
previous_node = current_node
704-
current_node = self.state.handoff_node
721+
# Check if handoff requested during execution
722+
if self.state.handoff_node:
723+
previous_node = current_node
724+
current_node = self.state.handoff_node
705725

706-
self.state.handoff_node = None
707-
self.state.current_node = current_node
726+
self.state.handoff_node = None
727+
self.state.current_node = current_node
708728

709-
handoff_event = MultiAgentHandoffEvent(
710-
from_node_ids=[previous_node.node_id],
711-
to_node_ids=[current_node.node_id],
712-
message=self.state.handoff_message or "Agent handoff occurred",
713-
)
714-
yield handoff_event
715-
logger.debug(
716-
"from_node=<%s>, to_node=<%s> | handoff detected",
717-
previous_node.node_id,
718-
current_node.node_id,
719-
)
720-
721-
else:
722-
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
723-
self.state.completion_status = Status.COMPLETED
724-
break
729+
handoff_event = MultiAgentHandoffEvent(
730+
from_node_ids=[previous_node.node_id],
731+
to_node_ids=[current_node.node_id],
732+
message=self.state.handoff_message or "Agent handoff occurred",
733+
)
734+
yield handoff_event
735+
logger.debug(
736+
"from_node=<%s>, to_node=<%s> | handoff detected",
737+
previous_node.node_id,
738+
current_node.node_id,
739+
)
725740

726-
except Exception:
727-
logger.exception("node=<%s> | node execution failed", current_node.node_id)
728-
self.state.completion_status = Status.FAILED
741+
else:
742+
logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id)
743+
self.state.completion_status = Status.COMPLETED
729744
break
730745

731746
except Exception:

src/strands/types/_events.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,22 @@ def __init__(self, node_id: str, agent_event: dict[str, Any]) -> None:
524524
"event": agent_event, # Nest agent event to avoid field conflicts
525525
}
526526
)
527+
528+
529+
class MultiAgentNodeCancelEvent(TypedEvent):
530+
"""Event emitted when a user cancels node execution from their BeforeNodeCallEvent hook."""
531+
532+
def __init__(self, node_id: str, message: str) -> None:
533+
"""Initialize with cancel message.
534+
535+
Args:
536+
node_id: Unique identifier for the node.
537+
message: The node cancellation message.
538+
"""
539+
super().__init__(
540+
{
541+
"type": "multiagent_node_cancel",
542+
"node_id": node_id,
543+
"message": message,
544+
}
545+
)

0 commit comments

Comments
 (0)