Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/strands/experimental/hooks/multiagent/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
is used—hooks read from the orchestrator directly.
"""

import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from typing_extensions import override

from ....hooks import BaseHookEvent
from ....types.interrupt import _Interruptible

if TYPE_CHECKING:
from ....multiagent.base import MultiAgentBase
Expand All @@ -28,7 +32,7 @@ class MultiAgentInitializedEvent(BaseHookEvent):


@dataclass
class BeforeNodeCallEvent(BaseHookEvent):
class BeforeNodeCallEvent(BaseHookEvent, _Interruptible):
"""Event triggered before individual node execution starts.

Attributes:
Expand All @@ -48,6 +52,20 @@ class BeforeNodeCallEvent(BaseHookEvent):
def _can_write(self, name: str) -> bool:
return name in ["cancel_node"]

@override
def _interrupt_id(self, name: str) -> str:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shoot, need to add a check in Graph to temporarily error on use of interrupts. This PR just adds the functionality to Swarm. Graph will be a fast follow. Still, please feel free to review the rest of the PR in the meanwhile.

"""Unique id for the interrupt.

Args:
name: User defined name for the interrupt.

Returns:
Interrupt id.
"""
node_id = uuid.uuid5(uuid.NAMESPACE_OID, self.node_id)
call_id = uuid.uuid5(uuid.NAMESPACE_OID, name)
return f"v1:before_node_call:{node_id}:{call_id}"


@dataclass
class AfterNodeCallEvent(BaseHookEvent):
Expand Down
40 changes: 27 additions & 13 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .._async import run_async
from ..agent import AgentResult
from ..interrupt import Interrupt
from ..types.event_loop import Metrics, Usage
from ..types.multiagent import MultiAgentInput
from ..types.traces import AttributeValue
Expand All @@ -20,22 +21,26 @@


class Status(Enum):
"""Execution status for both graphs and nodes."""
"""Execution status for both graphs and nodes.

Attributes:
PENDING: Task has not started execution yet.
EXECUTING: Task is currently running.
COMPLETED: Task finished successfully.
FAILED: Task encountered an error and could not complete.
INTERRUPTED: Task was interrupted by user.
"""

PENDING = "pending"
EXECUTING = "executing"
COMPLETED = "completed"
FAILED = "failed"
INTERRUPTED = "interrupted"


@dataclass
class NodeResult:
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.

The status field represents the semantic outcome of the node's work:
- COMPLETED: The node's task was successfully accomplished
- FAILED: The node's task failed or produced an error
"""
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results."""

# Core result data - single AgentResult, nested MultiAgentResult, or Exception
result: Union[AgentResult, "MultiAgentResult", Exception]
Expand All @@ -48,6 +53,7 @@ class NodeResult:
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
execution_count: int = 0
interrupts: list[Interrupt] = field(default_factory=list)

def get_agent_results(self) -> list[AgentResult]:
"""Get all AgentResult objects from this node, flattened if nested."""
Expand Down Expand Up @@ -79,6 +85,7 @@ def to_dict(self) -> dict[str, Any]:
"accumulated_usage": self.accumulated_usage,
"accumulated_metrics": self.accumulated_metrics,
"execution_count": self.execution_count,
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
}

@classmethod
Expand All @@ -101,31 +108,32 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult":
usage = _parse_usage(data.get("accumulated_usage", {}))
metrics = _parse_metrics(data.get("accumulated_metrics", {}))

interrupts = []
for interrupt_data in data.get("interrupts", []):
interrupts.append(Interrupt(**interrupt_data))

return cls(
result=result,
execution_time=int(data.get("execution_time", 0)),
status=Status(data.get("status", "pending")),
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=int(data.get("execution_count", 0)),
interrupts=interrupts,
)


@dataclass
class MultiAgentResult:
"""Result from multi-agent execution with accumulated metrics.

The status field represents the outcome of the MultiAgentBase execution:
- COMPLETED: The execution was successfully accomplished
- FAILED: The execution failed or produced an error
"""
"""Result from multi-agent execution with accumulated metrics."""

status: Status = Status.PENDING
results: dict[str, NodeResult] = field(default_factory=lambda: {})
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
execution_count: int = 0
execution_time: int = 0
interrupts: list[Interrupt] = field(default_factory=list)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
Expand All @@ -137,13 +145,18 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
usage = _parse_usage(data.get("accumulated_usage", {}))
metrics = _parse_metrics(data.get("accumulated_metrics", {}))

interrupts = []
for interrupt_data in data.get("interrupts", []):
interrupts.append(Interrupt(**interrupt_data))

multiagent_result = cls(
status=Status(data["status"]),
results=results,
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=int(data.get("execution_count", 0)),
execution_time=int(data.get("execution_time", 0)),
interrupts=interrupts,
)
return multiagent_result

Expand All @@ -157,6 +170,7 @@ def to_dict(self) -> dict[str, Any]:
"accumulated_metrics": self.accumulated_metrics,
"execution_count": self.execution_count,
"execution_time": self.execution_time,
"interrupts": [interrupt.to_dict() for interrupt in self.interrupts],
}


Expand Down
4 changes: 2 additions & 2 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,7 +979,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
if isinstance(self.state.task, str):
return [ContentBlock(text=self.state.task)]
else:
return self.state.task
return cast(list[ContentBlock], self.state.task)

# Combine task with dependency outputs
node_input = []
Expand All @@ -990,7 +990,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
else:
# Add task content blocks with a prefix
node_input.append(ContentBlock(text="Original Task:"))
node_input.extend(self.state.task)
node_input.extend(cast(list[ContentBlock], self.state.task))

# Add dependency outputs
node_input.append(ContentBlock(text="\nInputs from previous nodes:"))
Expand Down
Loading
Loading