Skip to content

Commit 19db55c

Browse files
authored
feat(multi-agent): introduce Graph multi-agent orchestrator (#336)
1 parent 89d261e commit 19db55c

File tree

11 files changed

+1371
-36
lines changed

11 files changed

+1371
-36
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ __pycache__*
88
.ruff_cache
99
*.bak
1010
.vscode
11-
dist
11+
dist
12+
repl_state

pyproject.toml

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ dev = [
5858
"pre-commit>=3.2.0,<4.2.0",
5959
"pytest>=8.0.0,<9.0.0",
6060
"pytest-asyncio>=0.26.0,<0.27.0",
61+
"pytest-cov>=4.1.0,<5.0.0",
62+
"pytest-xdist>=3.0.0,<4.0.0",
6163
"ruff>=0.4.4,<0.5.0",
6264
]
6365
docs = [
@@ -94,13 +96,59 @@ a2a = [
9496
"fastapi>=0.115.12",
9597
"starlette>=0.46.2",
9698
]
99+
all = [
100+
# anthropic
101+
"anthropic>=0.21.0,<1.0.0",
102+
103+
# dev
104+
"commitizen>=4.4.0,<5.0.0",
105+
"hatch>=1.0.0,<2.0.0",
106+
"moto>=5.1.0,<6.0.0",
107+
"mypy>=1.15.0,<2.0.0",
108+
"pre-commit>=3.2.0,<4.2.0",
109+
"pytest>=8.0.0,<9.0.0",
110+
"pytest-asyncio>=0.26.0,<0.27.0",
111+
"pytest-cov>=4.1.0,<5.0.0",
112+
"pytest-xdist>=3.0.0,<4.0.0",
113+
"ruff>=0.4.4,<0.5.0",
114+
115+
# docs
116+
"sphinx>=5.0.0,<6.0.0",
117+
"sphinx-rtd-theme>=1.0.0,<2.0.0",
118+
"sphinx-autodoc-typehints>=1.12.0,<2.0.0",
119+
120+
# litellm
121+
"litellm>=1.72.6,<1.73.0",
122+
123+
# llama
124+
"llama-api-client>=0.1.0,<1.0.0",
125+
126+
# mistral
127+
"mistralai>=1.8.2",
128+
129+
# ollama
130+
"ollama>=0.4.8,<1.0.0",
131+
132+
# openai
133+
"openai>=1.68.0,<2.0.0",
134+
135+
# otel
136+
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
137+
138+
# a2a
139+
"a2a-sdk[sql]>=0.2.11",
140+
"uvicorn>=0.34.2",
141+
"httpx>=0.28.1",
142+
"fastapi>=0.115.12",
143+
"starlette>=0.46.2",
144+
]
97145

98146
[tool.hatch.version]
99147
# Tells Hatch to use your version control system (git) to determine the version.
100148
source = "vcs"
101149

102150
[tool.hatch.envs.hatch-static-analysis]
103-
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
151+
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
104152
dependencies = [
105153
"mypy>=1.15.0,<2.0.0",
106154
"ruff>=0.11.6,<0.12.0",
@@ -116,15 +164,14 @@ format-fix = [
116164
]
117165
lint-check = [
118166
"ruff check",
119-
# excluding due to A2A and OTEL http exporter dependency conflict
120-
"mypy -p src --exclude src/strands/multiagent"
167+
"mypy -p src"
121168
]
122169
lint-fix = [
123170
"ruff check --fix"
124171
]
125172

126173
[tool.hatch.envs.hatch-test]
127-
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
174+
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
128175
extra-dependencies = [
129176
"moto>=5.1.0,<6.0.0",
130177
"pytest>=8.0.0,<9.0.0",
@@ -140,35 +187,17 @@ extra-args = [
140187

141188
[tool.hatch.envs.dev]
142189
dev-mode = true
143-
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer"]
144-
145-
[tool.hatch.envs.a2a]
146-
dev-mode = true
147-
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"]
148-
149-
[tool.hatch.envs.a2a.scripts]
150-
run = [
151-
"pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a {args}"
152-
]
153-
run-cov = [
154-
"pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a --cov --cov-config=pyproject.toml {args}"
155-
]
156-
lint-check = [
157-
"ruff check",
158-
"mypy -p src/strands/multiagent/a2a"
159-
]
190+
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a"]
160191

161192
[[tool.hatch.envs.hatch-test.matrix]]
162193
python = ["3.13", "3.12", "3.11", "3.10"]
163194

164195
[tool.hatch.envs.hatch-test.scripts]
165196
run = [
166-
# excluding due to A2A and OTEL http exporter dependency conflict
167-
"pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/strands/multiagent/a2a"
197+
"pytest{env:HATCH_TEST_ARGS:} {args}"
168198
]
169199
run-cov = [
170-
# excluding due to A2A and OTEL http exporter dependency conflict
171-
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/strands/multiagent/a2a"
200+
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}"
172201
]
173202

174203
cov-combine = []
@@ -203,10 +232,6 @@ prepare = [
203232
"hatch run test-lint",
204233
"hatch test --all"
205234
]
206-
test-a2a = [
207-
# required to run manually due to A2A and OTEL http exporter dependency conflict
208-
"hatch -e a2a run run {args}"
209-
]
210235

211236
[tool.mypy]
212237
python_version = "3.10"

src/strands/agent/agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import random
1616
from concurrent.futures import ThreadPoolExecutor
1717
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
18+
from uuid import uuid4
1819

1920
from opentelemetry import trace
2021
from pydantic import BaseModel
@@ -200,6 +201,7 @@ def __init__(
200201
load_tools_from_directory: bool = True,
201202
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
202203
*,
204+
agent_id: Optional[str] = None,
203205
name: Optional[str] = None,
204206
description: Optional[str] = None,
205207
state: Optional[Union[AgentState, dict]] = None,
@@ -234,6 +236,8 @@ def __init__(
234236
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
235237
Defaults to True.
236238
trace_attributes: Custom trace attributes to apply to the agent's trace span.
239+
agent_id: Optional ID for the agent, useful for multi-agent scenarios.
240+
If None, a UUID is generated.
237241
name: name of the Agent
238242
Defaults to None.
239243
description: description of what the Agent does
@@ -247,6 +251,9 @@ def __init__(
247251
self.messages = messages if messages is not None else []
248252

249253
self.system_prompt = system_prompt
254+
self.agent_id = agent_id or str(uuid4())
255+
self.name = name or _DEFAULT_AGENT_NAME
256+
self.description = description
250257

251258
# If not provided, create a new PrintingCallbackHandler instance
252259
# If explicitly set to None, use null_callback_handler
@@ -302,8 +309,6 @@ def __init__(
302309
self.state = AgentState()
303310

304311
self.tool_caller = Agent.ToolCaller(self)
305-
self.name = name or _DEFAULT_AGENT_NAME
306-
self.description = description
307312

308313
self.hooks = HookRegistry()
309314
if hooks:

src/strands/agent/conversation_manager/sliding_window_conversation_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def apply_management(self, agent: "Agent") -> None:
7575

7676
if len(messages) <= self.window_size:
7777
logger.debug(
78-
"window_size=<%s>, message_count=<%s> | skipping context reduction", len(messages), self.window_size
78+
"message_count=<%s>, window_size=<%s> | skipping context reduction", len(messages), self.window_size
7979
)
8080
return
8181
self.reduce_context(agent)

src/strands/multiagent/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,13 @@
99
"""
1010

1111
from . import a2a
12+
from .base import MultiAgentBase, MultiAgentResult
13+
from .graph import GraphBuilder, GraphResult
1214

13-
__all__ = ["a2a"]
15+
__all__ = [
16+
"a2a",
17+
"GraphBuilder",
18+
"GraphResult",
19+
"MultiAgentBase",
20+
"MultiAgentResult",
21+
]

src/strands/multiagent/base.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""Multi-Agent Base Class.
2+
3+
Provides minimal foundation for multi-agent patterns (Swarm, Graph).
4+
"""
5+
6+
from abc import ABC, abstractmethod
7+
from dataclasses import dataclass, field
8+
from enum import Enum
9+
from typing import Union
10+
11+
from ..agent import AgentResult
12+
from ..types.event_loop import Metrics, Usage
13+
14+
15+
class Status(Enum):
16+
"""Execution status for both graphs and nodes."""
17+
18+
PENDING = "pending"
19+
EXECUTING = "executing"
20+
COMPLETED = "completed"
21+
FAILED = "failed"
22+
23+
24+
@dataclass
25+
class NodeResult:
26+
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.
27+
28+
The status field represents the semantic outcome of the node's work:
29+
- COMPLETED: The node's task was successfully accomplished
30+
- FAILED: The node's task failed or produced an error
31+
"""
32+
33+
# Core result data - single AgentResult, nested MultiAgentResult, or Exception
34+
result: Union[AgentResult, "MultiAgentResult", Exception]
35+
36+
# Execution metadata
37+
execution_time: int = 0
38+
status: Status = Status.PENDING
39+
40+
# Accumulated metrics from this node and all children
41+
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
42+
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
43+
execution_count: int = 0
44+
45+
def get_agent_results(self) -> list[AgentResult]:
46+
"""Get all AgentResult objects from this node, flattened if nested."""
47+
if isinstance(self.result, Exception):
48+
return [] # No agent results for exceptions
49+
elif isinstance(self.result, AgentResult):
50+
return [self.result]
51+
else:
52+
# Flatten nested results from MultiAgentResult
53+
flattened = []
54+
for nested_node_result in self.result.results.values():
55+
flattened.extend(nested_node_result.get_agent_results())
56+
return flattened
57+
58+
59+
@dataclass
60+
class MultiAgentResult:
61+
"""Result from multi-agent execution with accumulated metrics."""
62+
63+
results: dict[str, NodeResult]
64+
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
65+
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
66+
execution_count: int = 0
67+
execution_time: int = 0
68+
69+
70+
class MultiAgentBase(ABC):
71+
"""Base class for multi-agent helpers.
72+
73+
This class integrates with existing Strands Agent instances and provides
74+
multi-agent orchestration capabilities.
75+
"""
76+
77+
@abstractmethod
78+
# TODO: for task - multi-modal input (Message), list of messages
79+
async def execute_async(self, task: str) -> MultiAgentResult:
80+
"""Execute task asynchronously."""
81+
raise NotImplementedError("execute_async not implemented")
82+
83+
@abstractmethod
84+
# TODO: for task - multi-modal input (Message), list of messages
85+
def execute(self, task: str) -> MultiAgentResult:
86+
"""Execute task synchronously."""
87+
raise NotImplementedError("execute not implemented")

0 commit comments

Comments
 (0)