Skip to content

Commit d97fcb5

Browse files
authored
multi modal input (#367)
1 parent f78b03a commit d97fcb5

File tree

11 files changed

+162
-43
lines changed

11 files changed

+162
-43
lines changed

src/strands/agent/agent.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -344,14 +344,14 @@ def __del__(self) -> None:
344344
self.thread_pool.shutdown(wait=False)
345345
logger.debug("thread pool executor shutdown complete")
346346

347-
def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
347+
def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
348348
"""Process a natural language prompt through the agent's event loop.
349349
350350
This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
351351
the conversation history, processes it through the model, executes any tool calls, and returns the final result.
352352
353353
Args:
354-
prompt: The natural language prompt from the user.
354+
prompt: User input as text or list of ContentBlock objects for multi-modal content.
355355
**kwargs: Additional parameters to pass through the event loop.
356356
357357
Returns:
@@ -370,14 +370,14 @@ def execute() -> AgentResult:
370370
future = executor.submit(execute)
371371
return future.result()
372372

373-
async def invoke_async(self, prompt: str, **kwargs: Any) -> AgentResult:
373+
async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
374374
"""Process a natural language prompt through the agent's event loop.
375375
376376
This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
377377
the conversation history, processes it through the model, executes any tool calls, and returns the final result.
378378
379379
Args:
380-
prompt: The natural language prompt from the user.
380+
prompt: User input as text or list of ContentBlock objects for multi-modal content.
381381
**kwargs: Additional parameters to pass through the event loop.
382382
383383
Returns:
@@ -456,7 +456,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: Optional[
456456
finally:
457457
self._hooks.invoke_callbacks(EndRequestEvent(agent=self))
458458

459-
async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
459+
async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
460460
"""Process a natural language prompt and yield events as an async iterator.
461461
462462
This method provides an asynchronous interface for streaming agent events, allowing
@@ -465,7 +465,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
465465
async environments.
466466
467467
Args:
468-
prompt: The natural language prompt from the user.
468+
prompt: User input as text or list of ContentBlock objects for multi-modal content.
469469
**kwargs: Additional parameters to pass to the event loop.
470470
471471
Returns:
@@ -488,10 +488,13 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
488488
"""
489489
callback_handler = kwargs.get("callback_handler", self.callback_handler)
490490

491-
self._start_agent_trace_span(prompt)
491+
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
492+
message: Message = {"role": "user", "content": content}
493+
494+
self._start_agent_trace_span(message)
492495

493496
try:
494-
events = self._run_loop(prompt, kwargs)
497+
events = self._run_loop(message, kwargs)
495498
async for event in events:
496499
if "callback" in event:
497500
callback_handler(**event["callback"])
@@ -507,18 +510,22 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
507510
self._end_agent_trace_span(error=e)
508511
raise
509512

510-
async def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
511-
"""Execute the agent's event loop with the given prompt and parameters."""
513+
async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
514+
"""Execute the agent's event loop with the given message and parameters.
515+
516+
Args:
517+
message: The user message to add to the conversation.
518+
kwargs: Additional parameters to pass to the event loop.
519+
520+
Yields:
521+
Events from the event loop cycle.
522+
"""
512523
self._hooks.invoke_callbacks(StartRequestEvent(agent=self))
513524

514525
try:
515-
# Extract key parameters
516526
yield {"callback": {"init_event_loop": True, **kwargs}}
517527

518-
# Set up the user message with optional knowledge base retrieval
519-
message_content: list[ContentBlock] = [{"text": prompt}]
520-
new_message: Message = {"role": "user", "content": message_content}
521-
self.messages.append(new_message)
528+
self.messages.append(message)
522529

523530
# Execute the event loop cycle with retry logic for context limits
524531
events = self._execute_event_loop_cycle(kwargs)
@@ -613,16 +620,16 @@ def _record_tool_execution(
613620
messages.append(tool_result_msg)
614621
messages.append(assistant_msg)
615622

616-
def _start_agent_trace_span(self, prompt: str) -> None:
623+
def _start_agent_trace_span(self, message: Message) -> None:
617624
"""Starts a trace span for the agent.
618625
619626
Args:
620-
prompt: The natural language prompt from the user.
627+
message: The user message.
621628
"""
622629
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
623630

624631
self.trace_span = self.tracer.start_agent_span(
625-
prompt=prompt,
632+
message=message,
626633
agent_name=self.name,
627634
model_id=model_id,
628635
tools=self.tool_names,

src/strands/telemetry/tracer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def end_event_loop_cycle_span(
407407

408408
def start_agent_span(
409409
self,
410-
prompt: str,
410+
message: Message,
411411
agent_name: str,
412412
model_id: Optional[str] = None,
413413
tools: Optional[list] = None,
@@ -417,7 +417,7 @@ def start_agent_span(
417417
"""Start a new span for an agent invocation.
418418
419419
Args:
420-
prompt: The user prompt being sent to the agent.
420+
message: The user message being sent to the agent.
421421
agent_name: Name of the agent.
422422
model_id: Optional model identifier.
423423
tools: Optional list of tools being used.
@@ -454,7 +454,7 @@ def start_agent_span(
454454
span,
455455
"gen_ai.user.message",
456456
event_attributes={
457-
"content": prompt,
457+
"content": serialize(message["content"]),
458458
},
459459
)
460460

tests/strands/agent/test_agent.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,39 @@ async def test_event_loop(*args, **kwargs):
10171017
mock_callback.assert_has_calls(exp_calls)
10181018

10191019

1020+
@pytest.mark.asyncio
1021+
async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, alist):
1022+
mock_model.mock_converse.return_value = agenerator(
1023+
[
1024+
{"contentBlockDelta": {"delta": {"text": "I see text and an image"}}},
1025+
{"contentBlockStop": {}},
1026+
{"messageStop": {"stopReason": "end_turn"}},
1027+
]
1028+
)
1029+
1030+
prompt = [
1031+
{"text": "This is a description of the image:"},
1032+
{
1033+
"image": {
1034+
"format": "png",
1035+
"source": {
1036+
"bytes": b"\x89PNG\r\n\x1a\n",
1037+
},
1038+
}
1039+
},
1040+
]
1041+
1042+
stream = agent.stream_async(prompt)
1043+
await alist(stream)
1044+
1045+
tru_message = agent.messages
1046+
exp_message = [
1047+
{"content": prompt, "role": "user"},
1048+
{"content": [{"text": "I see text and an image"}], "role": "assistant"},
1049+
]
1050+
assert tru_message == exp_message
1051+
1052+
10201053
@pytest.mark.asyncio
10211054
async def test_stream_async_passes_kwargs(agent, mock_model, mock_event_loop_cycle, agenerator, alist):
10221055
mock_model.mock_converse.side_effect = [
@@ -1150,12 +1183,12 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model
11501183

11511184
# Verify span was created
11521185
mock_tracer.start_agent_span.assert_called_once_with(
1153-
prompt="test prompt",
11541186
agent_name="Strands Agents",
1187+
custom_trace_attributes=agent.trace_attributes,
1188+
message={"content": [{"text": "test prompt"}], "role": "user"},
11551189
model_id=unittest.mock.ANY,
1156-
tools=agent.tool_names,
11571190
system_prompt=agent.system_prompt,
1158-
custom_trace_attributes=agent.trace_attributes,
1191+
tools=agent.tool_names,
11591192
)
11601193

11611194
# Verify span was ended with the result
@@ -1184,12 +1217,12 @@ async def test_event_loop(*args, **kwargs):
11841217

11851218
# Verify span was created
11861219
mock_tracer.start_agent_span.assert_called_once_with(
1187-
prompt="test prompt",
1220+
custom_trace_attributes=agent.trace_attributes,
11881221
agent_name="Strands Agents",
1222+
message={"content": [{"text": "test prompt"}], "role": "user"},
11891223
model_id=unittest.mock.ANY,
1190-
tools=agent.tool_names,
11911224
system_prompt=agent.system_prompt,
1192-
custom_trace_attributes=agent.trace_attributes,
1225+
tools=agent.tool_names,
11931226
)
11941227

11951228
expected_response = AgentResult(
@@ -1222,12 +1255,12 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod
12221255

12231256
# Verify span was created
12241257
mock_tracer.start_agent_span.assert_called_once_with(
1225-
prompt="test prompt",
1258+
custom_trace_attributes=agent.trace_attributes,
12261259
agent_name="Strands Agents",
1260+
message={"content": [{"text": "test prompt"}], "role": "user"},
12271261
model_id=unittest.mock.ANY,
1228-
tools=agent.tool_names,
12291262
system_prompt=agent.system_prompt,
1230-
custom_trace_attributes=agent.trace_attributes,
1263+
tools=agent.tool_names,
12311264
)
12321265

12331266
# Verify span was ended with the exception
@@ -1258,12 +1291,12 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr
12581291

12591292
# Verify span was created
12601293
mock_tracer.start_agent_span.assert_called_once_with(
1261-
prompt="test prompt",
12621294
agent_name="Strands Agents",
1295+
custom_trace_attributes=agent.trace_attributes,
1296+
message={"content": [{"text": "test prompt"}], "role": "user"},
12631297
model_id=unittest.mock.ANY,
1264-
tools=agent.tool_names,
12651298
system_prompt=agent.system_prompt,
1266-
custom_trace_attributes=agent.trace_attributes,
1299+
tools=agent.tool_names,
12671300
)
12681301

12691302
# Verify span was ended with the exception

tests/strands/telemetry/test_tracer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,17 +276,17 @@ def test_start_agent_span(mock_tracer):
276276
mock_span = mock.MagicMock()
277277
mock_tracer.start_span.return_value = mock_span
278278

279-
prompt = "What's the weather today?"
279+
content = [{"text": "test prompt"}]
280280
model_id = "test-model"
281281
tools = [{"name": "weather_tool"}]
282282
custom_attrs = {"custom_attr": "value"}
283283

284284
span = tracer.start_agent_span(
285-
prompt=prompt,
285+
custom_trace_attributes=custom_attrs,
286286
agent_name="WeatherAgent",
287+
message={"content": content, "role": "user"},
287288
model_id=model_id,
288289
tools=tools,
289-
custom_trace_attributes=custom_attrs,
290290
)
291291

292292
mock_tracer.start_span.assert_called_once()
@@ -295,7 +295,7 @@ def test_start_agent_span(mock_tracer):
295295
mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent")
296296
mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id)
297297
mock_span.set_attribute.assert_any_call("custom_attr", "value")
298-
mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": prompt})
298+
mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": json.dumps(content)})
299299
assert span is not None
300300

301301

tests_integ/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
import pytest
22

3+
## Data
4+
5+
6+
@pytest.fixture
7+
def yellow_img(pytestconfig):
8+
path = pytestconfig.rootdir / "tests_integ/yellow.png"
9+
with open(path, "rb") as fp:
10+
return fp.read()
11+
12+
313
## Async
414

515

tests_integ/models/test_model_anthropic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,21 @@ async def test_agent_structured_output_async(agent, weather):
9595
tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
9696
exp_weather = weather
9797
assert tru_weather == exp_weather
98+
99+
100+
def test_multi_modal_input(agent, yellow_img):
101+
content = [
102+
{"text": "what is in this image"},
103+
{
104+
"image": {
105+
"format": "png",
106+
"source": {
107+
"bytes": yellow_img,
108+
},
109+
},
110+
},
111+
]
112+
result = agent(content)
113+
text = result.message["content"][0]["text"].lower()
114+
115+
assert "yellow" in text

tests_integ/models/test_model_bedrock.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,21 @@ class Weather(BaseModel):
151151
assert isinstance(result, Weather)
152152
assert result.time == "12:00"
153153
assert result.weather == "sunny"
154+
155+
156+
def test_multi_modal_input(streaming_agent, yellow_img):
157+
content = [
158+
{"text": "what is in this image"},
159+
{
160+
"image": {
161+
"format": "png",
162+
"source": {
163+
"bytes": yellow_img,
164+
},
165+
},
166+
},
167+
]
168+
result = streaming_agent(content)
169+
text = result.message["content"][0]["text"].lower()
170+
171+
assert "yellow" in text

tests_integ/models/test_model_litellm.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,21 @@ class Weather(BaseModel):
4747
assert isinstance(result, Weather)
4848
assert result.time == "12:00"
4949
assert result.weather == "sunny"
50+
51+
52+
def test_multi_modal_input(agent, yellow_img):
53+
content = [
54+
{"text": "what is in this image"},
55+
{
56+
"image": {
57+
"format": "png",
58+
"source": {
59+
"bytes": yellow_img,
60+
},
61+
},
62+
},
63+
]
64+
result = agent(content)
65+
text = result.message["content"][0]["text"].lower()
66+
67+
assert "yellow" in text

0 commit comments

Comments
 (0)