Skip to content

feat: multimodal support in AmazonBedrockChatGenerator #307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@ on:
- "pyproject.toml"
- ".github/workflows/tests.yml"


permissions:
id-token: write
contents: read
env:
PYTHON_VERSION: "3.9"
HATCH_VERSION: "1.14.1"
PYTHONUNBUFFERED: "1"
FORCE_COLOR: "1"
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

AWS_REGION: "us-east-1"
jobs:
linting:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -112,5 +116,10 @@ jobs:
- name: Install Hatch
run: pip install hatch==${{ env.HATCH_VERSION }}

- name: AWS authentication
uses: aws-actions/configure-aws-credentials@b47578312673ae6fa5b5096b330d9fbac3d116df
with:
aws-region: ${{ env.AWS_REGION }}
role-to-assume: ${{ secrets.AWS_CI_ROLE_ARN }}
- name: Run
run: hatch run test:integration
1 change: 1 addition & 0 deletions docs/pydoc/config/generators_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ loaders:
modules:
[
"haystack_experimental.components.generators.chat.openai",
"haystack_experimental.components.generators.chat.amazon_bedrock",
]
ignore_when_discovered: ["__init__"]
processors:
Expand Down
2 changes: 1 addition & 1 deletion haystack_experimental/components/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

_import_structure = {
"openai": ["OpenAIChatGenerator"],
"amazon_bedrock": ["AmazonBedrockChatGenerator"],
}

if TYPE_CHECKING:
from .amazon_bedrock import AmazonBedrockChatGenerator
from .openai import OpenAIChatGenerator

else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
149 changes: 149 additions & 0 deletions haystack_experimental/components/generators/chat/amazon_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import base64
from typing import Any, Dict, List, Optional, Tuple, Union

from haystack import component
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.lazy_imports import LazyImport
from haystack.tools import Tool, Toolset
from haystack.utils.auth import Secret

from haystack_experimental.dataclasses.chat_message import ChatMessage, ChatRole, ImageContent, TextContent

with LazyImport("Run 'pip install amazon-bedrock-haystack'") as bedrock_integration_import:
import haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator as original_chat_generator
import haystack_integrations.components.generators.amazon_bedrock.chat.utils as original_utils
from haystack_integrations.components.generators.amazon_bedrock.chat.utils import (
_format_tool_call_message,
_format_tool_result_message,
_repair_tool_result_messages,
)

# NOTE: The following implementation ensures that:
# - we reuse existing code where possible
# - people can use haystack-experimental without installing amazon-bedrock-haystack.
#
#
# If amazon-bedrock-haystack is installed: all works correctly.
#
# If amazon-bedrock-haystack is not installed:
# - haystack-experimental package works fine (no import errors).
# - AmazonBedrockChatGenerator fails with ImportError at init (due to bedrock_integration_import.check()).

if not bedrock_integration_import.is_successful():
    # Fallback stub, defined only when amazon-bedrock-haystack is NOT installed.
    # It keeps `AmazonBedrockChatGenerator` importable (so haystack-experimental
    # works without the integration) but fails with a helpful ImportError as soon
    # as the component is instantiated.
    @component
    class AmazonBedrockChatGenerator:
        """
        Experimental version of AmazonBedrockChatGenerator that allows multimodal chat messages.
        """
        def __init__(  # pylint: disable=too-many-positional-arguments
            self,
            model: str,
            aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False),  # noqa: B008
            aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
                ["AWS_SECRET_ACCESS_KEY"], strict=False
            ),
            aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False),  # noqa: B008
            aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False),  # noqa: B008
            aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False),  # noqa: B008
            generation_kwargs: Optional[Dict[str, Any]] = None,
            stop_words: Optional[List[str]] = None,
            streaming_callback: Optional[StreamingCallbackT] = None,
            boto3_config: Optional[Dict[str, Any]] = None,
            tools: Optional[Union[List[Tool], Toolset]] = None,
        ) -> None:
            # LazyImport.check() re-raises the original ImportError together with
            # the "Run 'pip install amazon-bedrock-haystack'" hint.
            bedrock_integration_import.check()  # this always fails

        @component.output_types(replies=List[ChatMessage])
        def run(
            self,
            messages: List[ChatMessage],
            streaming_callback: Optional[StreamingCallbackT] = None,
            generation_kwargs: Optional[Dict[str, Any]] = None,
            tools: Optional[Union[List[Tool], Toolset]] = None,
        ) -> Dict[str, List[ChatMessage]]:
            """
            Executes a synchronous inference call to the Amazon Bedrock model using the Converse API.
            """

            # NOTE: placeholder run method needed to make component happy
            # (unreachable: __init__ always raises before run can be called)
            raise NotImplementedError("Unreachable code")
else:
# see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html for supported formats
IMAGE_SUPPORTED_FORMATS = ["png", "jpeg", "gif", "webp"]

# NOTE: this is the new function needed to support images
def _format_text_image_message(message: ChatMessage) -> Dict[str, Any]:
    """
    Convert a Haystack ChatMessage with text and/or image parts into Bedrock's message dictionary.

    :param message: Haystack ChatMessage.
    :returns: Dictionary representing the message in Bedrock's expected format.
    :raises ValueError: If image content is found in an assistant message or an unsupported image format is used.
    """
    blocks: List[Dict[str, Any]] = []

    for item in message._content:
        if isinstance(item, TextContent):
            blocks.append({"text": item.text})
            continue
        if not isinstance(item, ImageContent):
            # other content kinds (tool calls/results) are handled elsewhere
            continue
        # Bedrock only accepts images in user messages
        if message.is_from(ChatRole.ASSISTANT):
            raise ValueError("Image content is not supported for assistant messages")

        # derive the Bedrock "format" field from the MIME subtype, e.g. "image/png" -> "png"
        fmt = item.mime_type.split("/")[-1] if item.mime_type else None
        if fmt not in IMAGE_SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported image format: {fmt}. "
                f"Bedrock supports the following image formats: {IMAGE_SUPPORTED_FORMATS}"
            )
        blocks.append(
            {"image": {"format": fmt, "source": {"bytes": base64.b64decode(item.base64_image)}}}
        )

    return {"role": message.role.value, "content": blocks}

# NOTE: this is reimplemented in order to call the new _format_text_image_message function
def _format_messages(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Format a list of Haystack ChatMessages to the format expected by Bedrock API.

    Processes and separates system messages from other message types and handles special formatting for tool calls
    and tool results.

    :param messages: List of ChatMessage objects to format for Bedrock API.
    :returns: Tuple containing (system_prompts, non_system_messages) in Bedrock format,
        where system_prompts is a list of system message dictionaries and
        non_system_messages is a list of properly formatted message dictionaries.
    """
    system_prompts: List[Dict[str, Any]] = []
    formatted: List[Dict[str, Any]] = []

    for message in messages:
        if message.is_from(ChatRole.SYSTEM):
            # Assuming system messages can only contain text; they are returned
            # separately because Bedrock takes them outside the conversation list.
            system_prompts.append({"text": message.text})
        elif message.tool_calls:
            formatted.append(_format_tool_call_message(message))
        elif message.tool_call_results:
            formatted.append(_format_tool_result_message(message))
        else:
            formatted.append(_format_text_image_message(message))

    # reconcile tool-call / tool-result pairing before handing the list to Bedrock
    return system_prompts, _repair_tool_result_messages(formatted)

# NOTE: monkey patches needed to use the new ChatMessage dataclass and _format_messages function
# (the integration's modules resolved these names at import time, so they must be
# replaced in-place on those modules before the subclass below is used)
original_utils.ChatMessage = ChatMessage
original_chat_generator.ChatMessage = ChatMessage
original_chat_generator._format_messages = _format_messages

@component
class AmazonBedrockChatGenerator(original_chat_generator.AmazonBedrockChatGenerator):  # type: ignore[no-redef]
    """
    Experimental version of AmazonBedrockChatGenerator that allows multimodal chat messages.

    Behavior is inherited unchanged from the integration's generator; only the
    patched ChatMessage/_format_messages above differ.
    """
    pass
9 changes: 5 additions & 4 deletions haystack_experimental/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
# SPDX-License-Identifier: Apache-2.0

import haystack.components.generators.chat.openai
from haystack_experimental.dataclasses.chat_message import ChatMessage
from haystack import component

# Monkey patch the Haystack ChatMessage class with the experimental one. By doing so, we can use the new
from haystack_experimental.dataclasses.chat_message import ChatMessage

# Monkey patch the Haystack ChatMessage class with the experimental one. By doing so, we can use the new
# `to_openai_dict_format` method, allowing multimodal chat messages.
haystack.components.generators.chat.openai.ChatMessage = ChatMessage

Expand All @@ -16,5 +17,5 @@ class OpenAIChatGenerator(haystack.components.generators.chat.openai.OpenAIChatG
Experimental version of OpenAIChatGenerator that allows multimodal chat messages.
"""
pass


1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ extra-dependencies = [
"arrow>=1.3.0", # Multimodal experiment - ChatPromptBuilder
"pypdfium2", # Multimodal experiment - PDFToImageContent
"pillow", # Multimodal experiment - ImageFileToImageContent, PDFToImageContent
"amazon-bedrock-haystack>=3.6.2", # Multimodal experiment - AmazonBedrockChatGenerator
]

[tool.hatch.envs.test.scripts]
Expand Down
26 changes: 26 additions & 0 deletions test/components/generators/chat/amazon_bedrock/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from unittest.mock import patch

import pytest


@pytest.fixture
def set_env_variables(monkeypatch):
    """Populate the AWS-related environment variables with fake test values."""
    fake_env = {
        "AWS_ACCESS_KEY_ID": "some_fake_id",
        "AWS_SECRET_ACCESS_KEY": "some_fake_key",
        "AWS_SESSION_TOKEN": "some_fake_token",
        "AWS_DEFAULT_REGION": "fake_region",
        "AWS_PROFILE": "some_fake_profile",
    }
    for variable, value in fake_env.items():
        monkeypatch.setenv(variable, value)


# create a fixture with mocked boto3 client and session
@pytest.fixture
def mock_boto3_session():
    """Patch boto3.Session and yield the patch object for call assertions."""
    with patch("boto3.Session") as patched_session:
        yield patched_session


# create a fixture with mocked aioboto3 client and session
@pytest.fixture
def mock_aioboto3_session():
    """Patch aioboto3.Session and yield the patch object for call assertions."""
    with patch("aioboto3.Session") as patched_session:
        yield patched_session
Loading
Loading