feat(multi-agent): introduce Graph multi-agent orchestrator #336

Merged · 6 commits · Jul 11, 2025
3 changes: 2 additions & 1 deletion .gitignore
@@ -8,4 +8,5 @@ __pycache__*
.ruff_cache
*.bak
.vscode
dist
dist
repl_state
83 changes: 54 additions & 29 deletions pyproject.toml
@@ -58,6 +58,8 @@ dev = [
"pre-commit>=3.2.0,<4.2.0",
"pytest>=8.0.0,<9.0.0",
"pytest-asyncio>=0.26.0,<0.27.0",
"pytest-cov>=4.1.0,<5.0.0",
"pytest-xdist>=3.0.0,<4.0.0",
"ruff>=0.4.4,<0.5.0",
]
docs = [
@@ -94,13 +96,59 @@ a2a = [
"fastapi>=0.115.12",
"starlette>=0.46.2",
]
all = [
# anthropic
"anthropic>=0.21.0,<1.0.0",

# dev
"commitizen>=4.4.0,<5.0.0",
"hatch>=1.0.0,<2.0.0",
"moto>=5.1.0,<6.0.0",
"mypy>=1.15.0,<2.0.0",
"pre-commit>=3.2.0,<4.2.0",
"pytest>=8.0.0,<9.0.0",
"pytest-asyncio>=0.26.0,<0.27.0",
"pytest-cov>=4.1.0,<5.0.0",
"pytest-xdist>=3.0.0,<4.0.0",
"ruff>=0.4.4,<0.5.0",

# docs
"sphinx>=5.0.0,<6.0.0",
"sphinx-rtd-theme>=1.0.0,<2.0.0",
"sphinx-autodoc-typehints>=1.12.0,<2.0.0",

# litellm
"litellm>=1.72.6,<1.73.0",

# llama
"llama-api-client>=0.1.0,<1.0.0",

# mistral
"mistralai>=1.8.2",

# ollama
"ollama>=0.4.8,<1.0.0",

# openai
"openai>=1.68.0,<2.0.0",

# otel
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",

# a2a
"a2a-sdk[sql]>=0.2.11",
"uvicorn>=0.34.2",
"httpx>=0.28.1",
"fastapi>=0.115.12",
"starlette>=0.46.2",
]

[tool.hatch.version]
# Tells Hatch to use your version control system (git) to determine the version.
source = "vcs"

[tool.hatch.envs.hatch-static-analysis]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
@@ -116,15 +164,14 @@ format-fix = [
]
lint-check = [
"ruff check",
# excluding due to A2A and OTEL http exporter dependency conflict
"mypy -p src --exclude src/strands/multiagent"
"mypy -p src"
]
lint-fix = [
"ruff check --fix"
]

[tool.hatch.envs.hatch-test]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
@@ -140,35 +187,17 @@ extra-args = [

[tool.hatch.envs.dev]
dev-mode = true
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer"]

[tool.hatch.envs.a2a]
dev-mode = true
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"]

[tool.hatch.envs.a2a.scripts]
run = [
"pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a {args}"
]
run-cov = [
"pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a --cov --cov-config=pyproject.toml {args}"
]
lint-check = [
"ruff check",
"mypy -p src/strands/multiagent/a2a"
]
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a"]

[[tool.hatch.envs.hatch-test.matrix]]
python = ["3.13", "3.12", "3.11", "3.10"]

[tool.hatch.envs.hatch-test.scripts]
run = [
# excluding due to A2A and OTEL http exporter dependency conflict
"pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/strands/multiagent/a2a"
"pytest{env:HATCH_TEST_ARGS:} {args}"
]
run-cov = [
# excluding due to A2A and OTEL http exporter dependency conflict
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/strands/multiagent/a2a"
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}"
]

cov-combine = []
@@ -203,10 +232,6 @@ prepare = [
"hatch run test-lint",
"hatch test --all"
]
test-a2a = [
# required to run manually due to A2A and OTEL http exporter dependency conflict
"hatch -e a2a run run {args}"
]

[tool.mypy]
python_version = "3.10"
9 changes: 7 additions & 2 deletions src/strands/agent/agent.py
@@ -15,6 +15,7 @@
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
from uuid import uuid4

from opentelemetry import trace
from pydantic import BaseModel
@@ -199,6 +200,7 @@ def __init__(
load_tools_from_directory: bool = True,
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
*,
agent_id: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
state: Optional[Union[AgentState, dict]] = None,
@@ -232,6 +234,8 @@ def __init__(
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
Defaults to True.
trace_attributes: Custom trace attributes to apply to the agent's trace span.
agent_id: Optional ID for the agent, useful for multi-agent scenarios.
If None, a UUID is generated.
name: name of the Agent
Defaults to None.
description: description of what the Agent does
@@ -243,6 +247,9 @@
self.messages = messages if messages is not None else []

self.system_prompt = system_prompt
self.agent_id = agent_id or str(uuid4())
self.name = name or _DEFAULT_AGENT_NAME
self.description = description

# If not provided, create a new PrintingCallbackHandler instance
# If explicitly set to None, use null_callback_handler
@@ -298,8 +305,6 @@ def __init__(
self.state = AgentState()

self.tool_caller = Agent.ToolCaller(self)
self.name = name or _DEFAULT_AGENT_NAME
self.description = description

self._hooks = HookRegistry()
# Register built-in hook providers (like ConversationManager) here
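For context on the new keyword-only parameters, here is a minimal sketch of constructing an agent with an explicit `agent_id` for multi-agent use. The top-level import path for `Agent` is assumed (not shown in this diff), and model/tool configuration is left at its defaults.

```python
from strands import Agent  # top-level import path assumed, not shown in this diff

# An explicit agent_id makes the node addressable in multi-agent results;
# when omitted, __init__ now falls back to str(uuid4()).
researcher = Agent(
    agent_id="researcher",
    name="Researcher",
    description="Gathers background information for downstream nodes",
)
assert researcher.agent_id == "researcher"
```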
@@ -75,7 +75,7 @@ def apply_management(self, agent: "Agent") -> None:

if len(messages) <= self.window_size:
logger.debug(
"window_size=<%s>, message_count=<%s> | skipping context reduction", len(messages), self.window_size
"message_count=<%s>, window_size=<%s> | skipping context reduction", len(messages), self.window_size
)
return
self.reduce_context(agent)
10 changes: 9 additions & 1 deletion src/strands/multiagent/__init__.py
@@ -9,5 +9,13 @@
"""

from . import a2a
from .base import MultiAgentBase, MultiAgentResult
from .graph import GraphBuilder, GraphResult

__all__ = ["a2a"]
__all__ = [
"a2a",
"GraphBuilder",
"GraphResult",
"MultiAgentBase",
"MultiAgentResult",
]
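As a rough usage sketch of the new exports: the graph.py changes are not part of this excerpt, so the `GraphBuilder` methods below (add_node, add_edge, build) are assumptions for illustration only; the result fields come from base.py.

```python
from strands import Agent  # top-level import path assumed
from strands.multiagent import GraphBuilder, MultiAgentResult

builder = GraphBuilder()
# add_node/add_edge/build are hypothetical names; graph.py is not shown here.
builder.add_node(Agent(agent_id="plan", name="Planner"), "plan")
builder.add_node(Agent(agent_id="write", name="Writer"), "write")
builder.add_edge("plan", "write")  # the writer depends on the planner's output

graph = builder.build()
result: MultiAgentResult = graph.execute("Draft a short report on solar adoption")
for node_id, node_result in result.results.items():
    print(node_id, node_result.status)
```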
87 changes: 87 additions & 0 deletions src/strands/multiagent/base.py
@@ -0,0 +1,87 @@
"""Multi-Agent Base Class.

Provides minimal foundation for multi-agent patterns (Swarm, Graph).
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Union

from ..agent import AgentResult
from ..types.event_loop import Metrics, Usage


class Status(Enum):
"""Execution status for both graphs and nodes."""

PENDING = "pending"
EXECUTING = "executing"
COMPLETED = "completed"
FAILED = "failed"


@dataclass
class NodeResult:
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.

The status field represents the semantic outcome of the node's work:
- COMPLETED: The node's task was successfully accomplished
- FAILED: The node's task failed or produced an error
"""

# Core result data - single AgentResult, nested MultiAgentResult, or Exception
result: Union[AgentResult, "MultiAgentResult", Exception]

# Execution metadata
execution_time: int = 0
status: Status = Status.PENDING

# Accumulated metrics from this node and all children
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
execution_count: int = 0

def get_agent_results(self) -> list[AgentResult]:
"""Get all AgentResult objects from this node, flattened if nested."""
if isinstance(self.result, Exception):
return [] # No agent results for exceptions
elif isinstance(self.result, AgentResult):
return [self.result]
else:
# Flatten nested results from MultiAgentResult
flattened = []
for nested_node_result in self.result.results.values():
flattened.extend(nested_node_result.get_agent_results())
return flattened


@dataclass
class MultiAgentResult:
"""Result from multi-agent execution with accumulated metrics."""

results: dict[str, NodeResult]
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
execution_count: int = 0
execution_time: int = 0


class MultiAgentBase(ABC):
"""Base class for multi-agent helpers.

This class integrates with existing Strands Agent instances and provides
multi-agent orchestration capabilities.
"""

@abstractmethod
# TODO: for task - multi-modal input (Message), list of messages
async def execute_async(self, task: str) -> MultiAgentResult:
"""Execute task asynchronously."""
raise NotImplementedError("execute_async not implemented")

@abstractmethod
# TODO: for task - multi-modal input (Message), list of messages
def execute(self, task: str) -> MultiAgentResult:
"""Execute task synchronously."""
raise NotImplementedError("execute not implemented")