
Commit adae77b

multi modal input
1 parent 460adc9

10 files changed: +162 additions, -42 deletions

src/strands/agent/agent.py

Lines changed: 25 additions & 18 deletions
@@ -353,14 +353,14 @@ def __del__(self) -> None:
         self.thread_pool.shutdown(wait=False)
         logger.debug("thread pool executor shutdown complete")

-    def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
+    def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
         """Process a natural language prompt through the agent's event loop.

         This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
         the conversation history, processes it through the model, executes any tool calls, and returns the final result.

         Args:
-            prompt: The natural language prompt from the user.
+            prompt: User input as text or list of ContentBlock objects for multi-modal content.
             **kwargs: Additional parameters to pass through the event loop.

         Returns:
@@ -379,14 +379,14 @@ def execute() -> AgentResult:
            future = executor.submit(execute)
            return future.result()

-    async def invoke_async(self, prompt: str, **kwargs: Any) -> AgentResult:
+    async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
         """Process a natural language prompt through the agent's event loop.

         This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
         the conversation history, processes it through the model, executes any tool calls, and returns the final result.

         Args:
-            prompt: The natural language prompt from the user.
+            prompt: User input as text or list of ContentBlock objects for multi-modal content.
             **kwargs: Additional parameters to pass through the event loop.

         Returns:
@@ -465,7 +465,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: Optional[
         finally:
             self._hooks.invoke_callbacks(EndRequestEvent(agent=self))

-    async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
+    async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.

         This method provides an asynchronous interface for streaming agent events, allowing
@@ -474,7 +474,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         async environments.

         Args:
-            prompt: The natural language prompt from the user.
+            prompt: User input as text or list of ContentBlock objects for multi-modal content.
             **kwargs: Additional parameters to pass to the event loop.

         Returns:
@@ -497,10 +497,13 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
         """
         callback_handler = kwargs.get("callback_handler", self.callback_handler)

-        self._start_agent_trace_span(prompt)
+        content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
+        message: Message = {"role": "user", "content": content}
+
+        self._start_agent_trace_span(message)

         try:
-            events = self._run_loop(prompt, kwargs)
+            events = self._run_loop(message, kwargs)
             async for event in events:
                 if "callback" in event:
                     callback_handler(**event["callback"])
@@ -516,18 +519,22 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
             self._end_agent_trace_span(error=e)
             raise

-    async def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
-        """Execute the agent's event loop with the given prompt and parameters."""
+    async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
+        """Execute the agent's event loop with the given message and parameters.
+
+        Args:
+            message: The user message to add to the conversation.
+            kwargs: Additional parameters to pass to the event loop.
+
+        Yields:
+            Events from the event loop cycle.
+        """
         self._hooks.invoke_callbacks(StartRequestEvent(agent=self))

         try:
-            # Extract key parameters
             yield {"callback": {"init_event_loop": True, **kwargs}}

-            # Set up the user message with optional knowledge base retrieval
-            message_content: list[ContentBlock] = [{"text": prompt}]
-            new_message: Message = {"role": "user", "content": message_content}
-            self.messages.append(new_message)
+            self.messages.append(message)

             # Execute the event loop cycle with retry logic for context limits
             events = self._execute_event_loop_cycle(kwargs)
@@ -622,16 +629,16 @@ def _record_tool_execution(
         messages.append(tool_result_msg)
         messages.append(assistant_msg)

-    def _start_agent_trace_span(self, prompt: str) -> None:
+    def _start_agent_trace_span(self, message: Message) -> None:
         """Starts a trace span for the agent.

         Args:
-            prompt: The natural language prompt from the user.
+            message: The user message.
         """
         model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None

         self.trace_span = self.tracer.start_agent_span(
-            prompt=prompt,
+            message=message,
            agent_name=self.name,
            model_id=model_id,
            tools=self.tool_names,
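With this change, Agent.__call__, invoke_async, and stream_async accept either a plain string or a list of ContentBlock dicts. A minimal usage sketch (the default model setup and the image path are assumptions for illustration, not part of this commit):

from strands import Agent

agent = Agent()  # assumes a default model and credentials are already configured

# Plain text still works; it is wrapped into [{"text": ...}] internally.
agent("hello!")

# Multi-modal input: mix text and image content blocks in a single user message.
with open("tests-integ/yellow.png", "rb") as fp:  # illustrative path; any PNG works
    image_bytes = fp.read()

result = agent(
    [
        {"text": "what is in this image"},
        {"image": {"format": "png", "source": {"bytes": image_bytes}}},
    ]
)
print(result.message["content"][0]["text"])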

src/strands/telemetry/tracer.py

Lines changed: 3 additions & 3 deletions
@@ -405,7 +405,7 @@ def end_event_loop_cycle_span(

     def start_agent_span(
         self,
-        prompt: str,
+        message: Message,
         agent_name: str,
         model_id: Optional[str] = None,
         tools: Optional[list] = None,
@@ -415,7 +415,7 @@ def start_agent_span(
         """Start a new span for an agent invocation.

         Args:
-            prompt: The user prompt being sent to the agent.
+            message: The user message being sent to the agent.
             agent_name: Name of the agent.
             model_id: Optional model identifier.
             tools: Optional list of tools being used.
@@ -452,7 +452,7 @@ def start_agent_span(
             span,
             "gen_ai.user.message",
             event_attributes={
-                "content": prompt,
+                "content": serialize(message["content"]),
             },
         )
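The tracer now receives the full user message rather than a bare prompt string, and serializes the content blocks onto the gen_ai.user.message span event. A sketch of the updated call, mirroring the keyword arguments asserted in tests/strands/agent/test_agent.py below (the literal values are placeholders, not part of this commit):

# agent.tracer is the Tracer instance shown above; placeholder values throughout.
span = agent.tracer.start_agent_span(
    message={"role": "user", "content": [{"text": "test prompt"}]},
    agent_name=agent.name,
    model_id="example-model-id",
    tools=agent.tool_names,
    system_prompt=agent.system_prompt,
    custom_trace_attributes=agent.trace_attributes,
)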

tests-integ/conftest.py

Lines changed: 10 additions & 0 deletions
@@ -1,5 +1,15 @@
 import pytest

+## Data
+
+
+@pytest.fixture
+def yellow_img(pytestconfig):
+    path = pytestconfig.rootdir / "tests-integ/yellow.png"
+    with open(path, "rb") as fp:
+        return fp.read()
+
+
 ## Async

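The fixture reads tests-integ/yellow.png from the repository root (likely the image file listed later as "File renamed without changes."). If a comparable fixture needs to be regenerated locally, a solid-yellow PNG can be produced with Pillow; this is a hypothetical helper, not part of the commit:

# Hypothetical helper (assumes Pillow is installed); not part of this commit.
from PIL import Image

Image.new("RGB", (64, 64), color=(255, 255, 0)).save("tests-integ/yellow.png")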

tests-integ/test_model_anthropic.py

Lines changed: 19 additions & 0 deletions
@@ -61,3 +61,22 @@ class Weather(BaseModel):
     assert isinstance(result, Weather)
     assert result.time == "12:00"
     assert result.weather == "sunny"
+
+
+@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY environment variable missing")
+def test_multi_modal_input(agent, yellow_img):
+    content = [
+        {"text": "what is in this image"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    result = agent(content)
+    text = result.message["content"][0]["text"].lower()
+
+    assert "yellow" in text

tests-integ/test_model_bedrock.py

Lines changed: 18 additions & 0 deletions
@@ -151,3 +151,21 @@ class Weather(BaseModel):
     assert isinstance(result, Weather)
     assert result.time == "12:00"
     assert result.weather == "sunny"
+
+
+def test_multi_modal_input(streaming_agent, yellow_img):
+    content = [
+        {"text": "what is in this image"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    result = streaming_agent(content)
+    text = result.message["content"][0]["text"].lower()
+
+    assert "yellow" in text

tests-integ/test_model_litellm.py

Lines changed: 18 additions & 0 deletions
@@ -47,3 +47,21 @@ class Weather(BaseModel):
     assert isinstance(result, Weather)
     assert result.time == "12:00"
     assert result.weather == "sunny"
+
+
+def test_multi_modal_input(agent, yellow_img):
+    content = [
+        {"text": "what is in this image"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    result = agent(content)
+    text = result.message["content"][0]["text"].lower()
+
+    assert "yellow" in text

tests-integ/test_model_openai.py

Lines changed: 20 additions & 5 deletions
@@ -96,19 +96,34 @@ async def test_agent_structured_output_async(agent, weather):
     assert tru_weather == exp_weather


-def test_tool_returning_images(model, test_image_path):
+def test_multi_modal_input(agent, yellow_img):
+    content = [
+        {"text": "what is in this image"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    result = agent(content)
+    text = result.message["content"][0]["text"].lower()
+
+    assert "yellow" in text
+
+
+def test_tool_returning_images(model, yellow_img):
     @tool
     def tool_with_image_return():
-        with open(test_image_path, "rb") as image_file:
-            encoded_image = image_file.read()
-
         return {
             "status": "success",
             "content": [
                 {
                     "image": {
                         "format": "png",
-                        "source": {"bytes": encoded_image},
+                        "source": {"bytes": yellow_img},
                     }
                 },
             ],
File renamed without changes.

tests/strands/agent/test_agent.py

Lines changed: 45 additions & 12 deletions
@@ -1017,6 +1017,39 @@ async def test_event_loop(*args, **kwargs):
     mock_callback.assert_has_calls(exp_calls)


+@pytest.mark.asyncio
+async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, alist):
+    mock_model.mock_converse.return_value = agenerator(
+        [
+            {"contentBlockDelta": {"delta": {"text": "I see text and an image"}}},
+            {"contentBlockStop": {}},
+            {"messageStop": {"stopReason": "end_turn"}},
+        ]
+    )
+
+    prompt = [
+        {"text": "This is a description of the image:"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": b"\x89PNG\r\n\x1a\n",
+                },
+            }
+        },
+    ]
+
+    stream = agent.stream_async(prompt)
+    await alist(stream)
+
+    tru_message = agent.messages
+    exp_message = [
+        {"content": prompt, "role": "user"},
+        {"content": [{"text": "I see text and an image"}], "role": "assistant"},
+    ]
+    assert tru_message == exp_message
+
+
 @pytest.mark.asyncio
 async def test_stream_async_passes_kwargs(agent, mock_model, mock_event_loop_cycle, agenerator, alist):
     mock_model.mock_converse.side_effect = [
@@ -1150,12 +1183,12 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model

     # Verify span was created
     mock_tracer.start_agent_span.assert_called_once_with(
-        prompt="test prompt",
         agent_name="Strands Agents",
+        custom_trace_attributes=agent.trace_attributes,
+        message={"content": [{"text": "test prompt"}], "role": "user"},
         model_id=unittest.mock.ANY,
-        tools=agent.tool_names,
         system_prompt=agent.system_prompt,
-        custom_trace_attributes=agent.trace_attributes,
+        tools=agent.tool_names,
     )

     # Verify span was ended with the result
@@ -1184,12 +1217,12 @@ async def test_event_loop(*args, **kwargs):

     # Verify span was created
     mock_tracer.start_agent_span.assert_called_once_with(
-        prompt="test prompt",
+        custom_trace_attributes=agent.trace_attributes,
         agent_name="Strands Agents",
+        message={"content": [{"text": "test prompt"}], "role": "user"},
         model_id=unittest.mock.ANY,
-        tools=agent.tool_names,
         system_prompt=agent.system_prompt,
-        custom_trace_attributes=agent.trace_attributes,
+        tools=agent.tool_names,
     )

     expected_response = AgentResult(
@@ -1222,12 +1255,12 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod

     # Verify span was created
     mock_tracer.start_agent_span.assert_called_once_with(
-        prompt="test prompt",
+        custom_trace_attributes=agent.trace_attributes,
         agent_name="Strands Agents",
+        message={"content": [{"text": "test prompt"}], "role": "user"},
         model_id=unittest.mock.ANY,
-        tools=agent.tool_names,
         system_prompt=agent.system_prompt,
-        custom_trace_attributes=agent.trace_attributes,
+        tools=agent.tool_names,
     )

     # Verify span was ended with the exception
@@ -1258,12 +1291,12 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr

     # Verify span was created
     mock_tracer.start_agent_span.assert_called_once_with(
-        prompt="test prompt",
         agent_name="Strands Agents",
+        custom_trace_attributes=agent.trace_attributes,
+        message={"content": [{"text": "test prompt"}], "role": "user"},
         model_id=unittest.mock.ANY,
-        tools=agent.tool_names,
         system_prompt=agent.system_prompt,
-        custom_trace_attributes=agent.trace_attributes,
+        tools=agent.tool_names,
     )

     # Verify span was ended with the exception
