Skip to content

multi modal input #367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,14 @@ def __del__(self) -> None:
self.thread_pool.shutdown(wait=False)
logger.debug("thread pool executor shutdown complete")

def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
the conversation history, processes it through the model, executes any tool calls, and returns the final result.

Args:
prompt: The natural language prompt from the user.
prompt: User input as text or list of ContentBlock objects for multi-modal content.
**kwargs: Additional parameters to pass through the event loop.

Returns:
Expand All @@ -370,14 +370,14 @@ def execute() -> AgentResult:
future = executor.submit(execute)
return future.result()

async def invoke_async(self, prompt: str, **kwargs: Any) -> AgentResult:
async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
the conversation history, processes it through the model, executes any tool calls, and returns the final result.

Args:
prompt: The natural language prompt from the user.
prompt: User input as text or list of ContentBlock objects for multi-modal content.
**kwargs: Additional parameters to pass through the event loop.

Returns:
Expand Down Expand Up @@ -456,7 +456,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: Optional[
finally:
self._hooks.invoke_callbacks(EndRequestEvent(agent=self))

async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.

This method provides an asynchronous interface for streaming agent events, allowing
Expand All @@ -465,7 +465,7 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
async environments.

Args:
prompt: The natural language prompt from the user.
prompt: User input as text or list of ContentBlock objects for multi-modal content.
**kwargs: Additional parameters to pass to the event loop.

Returns:
Expand All @@ -488,10 +488,13 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
"""
callback_handler = kwargs.get("callback_handler", self.callback_handler)

self._start_agent_trace_span(prompt)
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
message: Message = {"role": "user", "content": content}

self._start_agent_trace_span(message)

try:
events = self._run_loop(prompt, kwargs)
events = self._run_loop(message, kwargs)
async for event in events:
if "callback" in event:
callback_handler(**event["callback"])
Expand All @@ -507,18 +510,22 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
self._end_agent_trace_span(error=e)
raise

async def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the agent's event loop with the given prompt and parameters."""
async def _run_loop(self, message: Message, kwargs: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Execute the agent's event loop with the given message and parameters.

Args:
message: The user message to add to the conversation.
kwargs: Additional parameters to pass to the event loop.

Yields:
Events from the event loop cycle.
"""
self._hooks.invoke_callbacks(StartRequestEvent(agent=self))

try:
# Extract key parameters
yield {"callback": {"init_event_loop": True, **kwargs}}

# Set up the user message with optional knowledge base retrieval
message_content: list[ContentBlock] = [{"text": prompt}]
new_message: Message = {"role": "user", "content": message_content}
self.messages.append(new_message)
self.messages.append(message)

# Execute the event loop cycle with retry logic for context limits
events = self._execute_event_loop_cycle(kwargs)
Expand Down Expand Up @@ -613,16 +620,16 @@ def _record_tool_execution(
messages.append(tool_result_msg)
messages.append(assistant_msg)

def _start_agent_trace_span(self, prompt: str) -> None:
def _start_agent_trace_span(self, message: Message) -> None:
"""Starts a trace span for the agent.

Args:
prompt: The natural language prompt from the user.
message: The user message.
"""
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None

self.trace_span = self.tracer.start_agent_span(
prompt=prompt,
message=message,
agent_name=self.name,
model_id=model_id,
tools=self.tool_names,
Expand Down
6 changes: 3 additions & 3 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def end_event_loop_cycle_span(

def start_agent_span(
self,
prompt: str,
message: Message,
agent_name: str,
model_id: Optional[str] = None,
tools: Optional[list] = None,
Expand All @@ -417,7 +417,7 @@ def start_agent_span(
"""Start a new span for an agent invocation.

Args:
prompt: The user prompt being sent to the agent.
message: The user message being sent to the agent.
agent_name: Name of the agent.
model_id: Optional model identifier.
tools: Optional list of tools being used.
Expand Down Expand Up @@ -454,7 +454,7 @@ def start_agent_span(
span,
"gen_ai.user.message",
event_attributes={
"content": prompt,
"content": serialize(message["content"]),
},
)

Expand Down
57 changes: 45 additions & 12 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,39 @@ async def test_event_loop(*args, **kwargs):
mock_callback.assert_has_calls(exp_calls)


@pytest.mark.asyncio
async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, alist):
    """Streaming a list of content blocks (text + image) should append the
    multi-modal user message to the conversation history unchanged."""
    mock_model.mock_converse.return_value = agenerator(
        [
            {"contentBlockDelta": {"delta": {"text": "I see text and an image"}}},
            {"contentBlockStop": {}},
            {"messageStop": {"stopReason": "end_turn"}},
        ]
    )

    multi_modal_prompt = [
        {"text": "This is a description of the image:"},
        {
            "image": {
                "format": "png",
                "source": {"bytes": b"\x89PNG\r\n\x1a\n"},
            }
        },
    ]

    # Drain the stream; we only care about the resulting message history.
    await alist(agent.stream_async(multi_modal_prompt))

    assert agent.messages == [
        {"content": multi_modal_prompt, "role": "user"},
        {"content": [{"text": "I see text and an image"}], "role": "assistant"},
    ]


@pytest.mark.asyncio
async def test_stream_async_passes_kwargs(agent, mock_model, mock_event_loop_cycle, agenerator, alist):
mock_model.mock_converse.side_effect = [
Expand Down Expand Up @@ -1150,12 +1183,12 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model

# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
prompt="test prompt",
agent_name="Strands Agents",
custom_trace_attributes=agent.trace_attributes,
message={"content": [{"text": "test prompt"}], "role": "user"},
model_id=unittest.mock.ANY,
tools=agent.tool_names,
system_prompt=agent.system_prompt,
custom_trace_attributes=agent.trace_attributes,
tools=agent.tool_names,
)

# Verify span was ended with the result
Expand Down Expand Up @@ -1184,12 +1217,12 @@ async def test_event_loop(*args, **kwargs):

# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
prompt="test prompt",
custom_trace_attributes=agent.trace_attributes,
agent_name="Strands Agents",
message={"content": [{"text": "test prompt"}], "role": "user"},
model_id=unittest.mock.ANY,
tools=agent.tool_names,
system_prompt=agent.system_prompt,
custom_trace_attributes=agent.trace_attributes,
tools=agent.tool_names,
)

expected_response = AgentResult(
Expand Down Expand Up @@ -1222,12 +1255,12 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod

# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
prompt="test prompt",
custom_trace_attributes=agent.trace_attributes,
agent_name="Strands Agents",
message={"content": [{"text": "test prompt"}], "role": "user"},
model_id=unittest.mock.ANY,
tools=agent.tool_names,
system_prompt=agent.system_prompt,
custom_trace_attributes=agent.trace_attributes,
tools=agent.tool_names,
)

# Verify span was ended with the exception
Expand Down Expand Up @@ -1258,12 +1291,12 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr

# Verify span was created
mock_tracer.start_agent_span.assert_called_once_with(
prompt="test prompt",
agent_name="Strands Agents",
custom_trace_attributes=agent.trace_attributes,
message={"content": [{"text": "test prompt"}], "role": "user"},
model_id=unittest.mock.ANY,
tools=agent.tool_names,
system_prompt=agent.system_prompt,
custom_trace_attributes=agent.trace_attributes,
tools=agent.tool_names,
)

# Verify span was ended with the exception
Expand Down
8 changes: 4 additions & 4 deletions tests/strands/telemetry/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,17 +276,17 @@ def test_start_agent_span(mock_tracer):
mock_span = mock.MagicMock()
mock_tracer.start_span.return_value = mock_span

prompt = "What's the weather today?"
content = [{"text": "test prompt"}]
model_id = "test-model"
tools = [{"name": "weather_tool"}]
custom_attrs = {"custom_attr": "value"}

span = tracer.start_agent_span(
prompt=prompt,
custom_trace_attributes=custom_attrs,
agent_name="WeatherAgent",
message={"content": content, "role": "user"},
model_id=model_id,
tools=tools,
custom_trace_attributes=custom_attrs,
)

mock_tracer.start_span.assert_called_once()
Expand All @@ -295,7 +295,7 @@ def test_start_agent_span(mock_tracer):
mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent")
mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id)
mock_span.set_attribute.assert_any_call("custom_attr", "value")
mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": prompt})
mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": json.dumps(content)})
assert span is not None


Expand Down
10 changes: 10 additions & 0 deletions tests_integ/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import pytest

## Data


@pytest.fixture
def yellow_img(pytestconfig):
    """Raw PNG bytes of the yellow test image used by multi-modal integ tests."""
    image_path = pytestconfig.rootdir / "tests_integ/yellow.png"
    with open(image_path, "rb") as image_file:
        return image_file.read()


## Async


Expand Down
18 changes: 18 additions & 0 deletions tests_integ/models/test_model_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,21 @@ async def test_agent_structured_output_async(agent, weather):
tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
exp_weather = weather
assert tru_weather == exp_weather


def test_multi_modal_input(agent, yellow_img):
    """The agent should recognize the yellow PNG passed alongside a text prompt."""
    content = [
        {"text": "what is in this image"},
        {
            "image": {
                "format": "png",
                "source": {"bytes": yellow_img},
            }
        },
    ]

    response_text = agent(content).message["content"][0]["text"]

    assert "yellow" in response_text.lower()
18 changes: 18 additions & 0 deletions tests_integ/models/test_model_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,21 @@ class Weather(BaseModel):
assert isinstance(result, Weather)
assert result.time == "12:00"
assert result.weather == "sunny"


def test_multi_modal_input(streaming_agent, yellow_img):
    """The streaming agent should recognize the yellow PNG passed with text."""
    content = [
        {"text": "what is in this image"},
        {
            "image": {
                "format": "png",
                "source": {"bytes": yellow_img},
            }
        },
    ]

    response_text = streaming_agent(content).message["content"][0]["text"]

    assert "yellow" in response_text.lower()
18 changes: 18 additions & 0 deletions tests_integ/models/test_model_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,21 @@ class Weather(BaseModel):
assert isinstance(result, Weather)
assert result.time == "12:00"
assert result.weather == "sunny"


def test_multi_modal_input(agent, yellow_img):
    """The agent should recognize the yellow PNG passed alongside a text prompt."""
    content = [
        {"text": "what is in this image"},
        {
            "image": {
                "format": "png",
                "source": {"bytes": yellow_img},
            }
        },
    ]

    response_text = agent(content).message["content"][0]["text"]

    assert "yellow" in response_text.lower()
Loading