Skip to content

Commit da9153a

Browse files
authored
feat(multiagent): Graph - support multi-modal inputs (#430)
1 parent 812b1d3 commit da9153a

File tree

4 files changed

+85
-21
lines changed

4 files changed

+85
-21
lines changed

src/strands/multiagent/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Union
1010

1111
from ..agent import AgentResult
12+
from ..types.content import ContentBlock
1213
from ..types.event_loop import Metrics, Usage
1314

1415

@@ -75,13 +76,11 @@ class MultiAgentBase(ABC):
7576
"""
7677

7778
@abstractmethod
78-
# TODO: for task - multi-modal input (Message), list of messages
79-
async def execute_async(self, task: str) -> MultiAgentResult:
79+
async def execute_async(self, task: str | list[ContentBlock]) -> MultiAgentResult:
8080
"""Execute task asynchronously."""
8181
raise NotImplementedError("execute_async not implemented")
8282

8383
@abstractmethod
84-
# TODO: for task - multi-modal input (Message), list of messages
85-
def execute(self, task: str) -> MultiAgentResult:
84+
def execute(self, task: str | list[ContentBlock]) -> MultiAgentResult:
8685
"""Execute task synchronously."""
8786
raise NotImplementedError("execute not implemented")

src/strands/multiagent/graph.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any, Callable, Tuple, cast
2323

2424
from ..agent import Agent, AgentResult
25+
from ..types.content import ContentBlock
2526
from ..types.event_loop import Metrics, Usage
2627
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
2728

@@ -42,12 +43,14 @@ class GraphState:
4243
Entry point nodes receive this task as their input if they have no dependencies.
4344
"""
4445

46+
# Task (with default empty string)
47+
task: str | list[ContentBlock] = ""
48+
4549
# Execution state
4650
status: Status = Status.PENDING
4751
completed_nodes: set["GraphNode"] = field(default_factory=set)
4852
failed_nodes: set["GraphNode"] = field(default_factory=set)
4953
execution_order: list["GraphNode"] = field(default_factory=list)
50-
task: str = ""
5154

5255
# Results
5356
results: dict[str, NodeResult] = field(default_factory=dict)
@@ -247,7 +250,7 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
247250
self.entry_points = entry_points
248251
self.state = GraphState()
249252

250-
def execute(self, task: str) -> GraphResult:
253+
def execute(self, task: str | list[ContentBlock]) -> GraphResult:
251254
"""Execute task synchronously."""
252255

253256
def execute() -> GraphResult:
@@ -257,7 +260,7 @@ def execute() -> GraphResult:
257260
future = executor.submit(execute)
258261
return future.result()
259262

260-
async def execute_async(self, task: str) -> GraphResult:
263+
async def execute_async(self, task: str | list[ContentBlock]) -> GraphResult:
261264
"""Execute the graph asynchronously."""
262265
logger.debug("task=<%s> | starting graph execution", task)
263266

@@ -435,8 +438,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
435438
self.state.accumulated_metrics["latencyMs"] += node_result.accumulated_metrics.get("latencyMs", 0)
436439
self.state.execution_count += node_result.execution_count
437440

438-
def _build_node_input(self, node: GraphNode) -> str:
439-
"""Build input text for a node based on dependency outputs."""
441+
def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
442+
"""Build input for a node based on dependency outputs."""
440443
# Get satisfied dependencies
441444
dependency_results = {}
442445
for edge in self.edges:
@@ -449,21 +452,36 @@ def _build_node_input(self, node: GraphNode) -> str:
449452
dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id]
450453

451454
if not dependency_results:
452-
return self.state.task
455+
# No dependencies - return task as ContentBlocks
456+
if isinstance(self.state.task, str):
457+
return [ContentBlock(text=self.state.task)]
458+
else:
459+
return self.state.task
453460

454461
# Combine task with dependency outputs
455-
input_parts = [f"Original Task: {self.state.task}", "\nInputs from previous nodes:"]
462+
node_input = []
463+
464+
# Add original task
465+
if isinstance(self.state.task, str):
466+
node_input.append(ContentBlock(text=f"Original Task: {self.state.task}"))
467+
else:
468+
# Add task content blocks with a prefix
469+
node_input.append(ContentBlock(text="Original Task:"))
470+
node_input.extend(self.state.task)
471+
472+
# Add dependency outputs
473+
node_input.append(ContentBlock(text="\nInputs from previous nodes:"))
456474

457475
for dep_id, node_result in dependency_results.items():
458-
input_parts.append(f"\nFrom {dep_id}:")
476+
node_input.append(ContentBlock(text=f"\nFrom {dep_id}:"))
459477
# Get all agent results from this node (flattened if nested)
460478
agent_results = node_result.get_agent_results()
461479
for result in agent_results:
462480
agent_name = getattr(result, "agent_name", "Agent")
463481
result_text = str(result)
464-
input_parts.append(f" - {agent_name}: {result_text}")
482+
node_input.append(ContentBlock(text=f" - {agent_name}: {result_text}"))
465483

466-
return "\n".join(input_parts)
484+
return node_input
467485

468486
def _build_result(self) -> GraphResult:
469487
"""Build graph result from current state."""

tests/strands/multiagent/test_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,10 @@ async def test_graph_edge_cases():
273273
builder.add_node(entry_agent, "entry_only")
274274
graph = builder.build()
275275

276-
result = await graph.execute_async("Original task")
276+
result = await graph.execute_async([{"text": "Original task"}])
277277

278278
# Verify entry node was called with original task
279-
entry_agent.stream_async.assert_called_once_with("Original task")
279+
entry_agent.stream_async.assert_called_once_with([{"text": "Original task"}])
280280
assert result.status == Status.COMPLETED
281281

282282

tests_integ/test_multiagent_graph.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from strands import Agent, tool
44
from strands.multiagent.graph import GraphBuilder
5+
from strands.types.content import ContentBlock
56

67

78
@tool
@@ -23,7 +24,6 @@ def math_agent():
2324
model="us.amazon.nova-pro-v1:0",
2425
system_prompt="You are a mathematical assistant. Always provide clear, step-by-step calculations.",
2526
tools=[calculate_sum, multiply_numbers],
26-
load_tools_from_directory=False,
2727
)
2828

2929

@@ -33,7 +33,6 @@ def analysis_agent():
3333
return Agent(
3434
model="us.amazon.nova-pro-v1:0",
3535
system_prompt="You are a data analysis expert. Provide insights and interpretations of numerical results.",
36-
load_tools_from_directory=False,
3736
)
3837

3938

@@ -43,7 +42,6 @@ def summary_agent():
4342
return Agent(
4443
model="us.amazon.nova-lite-v1:0",
4544
system_prompt="You are a summarization expert. Create concise, clear summaries of complex information.",
46-
load_tools_from_directory=False,
4745
)
4846

4947

@@ -53,7 +51,16 @@ def validation_agent():
5351
return Agent(
5452
model="us.amazon.nova-pro-v1:0",
5553
system_prompt="You are a validation expert. Check results for accuracy and completeness.",
56-
load_tools_from_directory=False,
54+
)
55+
56+
57+
@pytest.fixture
58+
def image_analysis_agent():
59+
"""Create an agent specialized in image analysis."""
60+
return Agent(
61+
system_prompt=(
62+
"You are an image analysis expert. Describe what you see in images and provide detailed analysis."
63+
)
5764
)
5865

5966

@@ -74,7 +81,7 @@ def nested_computation_graph(math_agent, analysis_agent):
7481

7582

7683
@pytest.mark.asyncio
77-
async def test_graph_execution(math_agent, summary_agent, validation_agent, nested_computation_graph):
84+
async def test_graph_execution_with_string(math_agent, summary_agent, validation_agent, nested_computation_graph):
7885
# Define conditional functions
7986
def should_validate(state):
8087
"""Condition to determine if validation should run."""
@@ -131,3 +138,43 @@ def proceed_to_second_summary(state):
131138
# Verify nested graph execution
132139
nested_result = result.results["computation_subgraph"].result
133140
assert nested_result.status.value == "completed"
141+
142+
143+
@pytest.mark.asyncio
144+
async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img):
145+
"""Test graph execution with multi-modal image input."""
146+
builder = GraphBuilder()
147+
148+
# Add agents to graph
149+
builder.add_node(image_analysis_agent, "image_analyzer")
150+
builder.add_node(summary_agent, "summarizer")
151+
152+
# Connect them sequentially
153+
builder.add_edge("image_analyzer", "summarizer")
154+
builder.set_entry_point("image_analyzer")
155+
156+
graph = builder.build()
157+
158+
# Create content blocks with text and image
159+
content_blocks: list[ContentBlock] = [
160+
{"text": "Analyze this image and describe what you see:"},
161+
{"image": {"format": "png", "source": {"bytes": yellow_img}}},
162+
]
163+
164+
# Execute the graph with multi-modal input
165+
result = await graph.execute_async(content_blocks)
166+
167+
# Verify results
168+
assert result.status.value == "completed"
169+
assert result.total_nodes == 2
170+
assert result.completed_nodes == 2
171+
assert result.failed_nodes == 0
172+
assert len(result.results) == 2
173+
174+
# Verify execution order
175+
execution_order_ids = [node.node_id for node in result.execution_order]
176+
assert execution_order_ids == ["image_analyzer", "summarizer"]
177+
178+
# Verify both nodes completed
179+
assert "image_analyzer" in result.results
180+
assert "summarizer" in result.results

0 commit comments

Comments
 (0)