@@ -27,6 +27,7 @@
_parse_completion_response,
_parse_streaming_response,
_parse_streaming_response_async,
_validate_and_format_cache_point,
_validate_guardrail_config,
)

@@ -160,6 +161,7 @@ def __init__(
tools: ToolsType | None = None,
*,
guardrail_config: dict[str, str] | None = None,
tools_cachepoint_config: dict[str, str] | None = None,
) -> None:
"""
Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
@@ -201,6 +203,9 @@
See the
[Guardrails Streaming documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-streaming.html)
for more information.
:param tools_cachepoint_config: Optional configuration to use prompt caching for tools.
The dictionary must match the
[CachePointBlock schema](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html).


:raises ValueError: If the model name is empty or None.
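
For context, a minimal usage sketch of the new parameter (not part of this diff; import paths are the ones assumed for the Haystack Bedrock integration, and the model ID reuses the cheap Nova model from the tests below):

from haystack.tools import Tool
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

def get_weather(city: str) -> str:
    # Placeholder tool implementation used only for illustration
    return f"Sunny in {city}"

weather_tool = Tool(
    name="get_weather",
    description="Returns the weather for a city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

generator = AmazonBedrockChatGenerator(
    model="amazon.nova-micro-v1:0",
    tools=[weather_tool],
    # Must follow the CachePointBlock schema: a required "type" of "default" and an optional "ttl"
    tools_cachepoint_config={"type": "default"},
)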
@@ -225,6 +230,10 @@ def __init__(
_validate_guardrail_config(guardrail_config=guardrail_config, streaming=streaming_callback is not None)
self.guardrail_config = guardrail_config

if tools_cachepoint_config:
_validate_and_format_cache_point(tools_cachepoint_config)
self.tools_cachepoint_config = tools_cachepoint_config

def resolve_secret(secret: Secret | None) -> str | None:
return secret.resolve_value() if secret else None

@@ -310,6 +319,7 @@ def to_dict(self) -> dict[str, Any]:
boto3_config=self.boto3_config,
tools=serialize_tools_or_toolset(self.tools),
guardrail_config=self.guardrail_config,
tools_cachepoint_config=self.tools_cachepoint_config,
)

@classmethod
@@ -389,7 +399,7 @@ def _prepare_request_params(
tool_config = merged_kwargs.pop("toolConfig", None)
if flattened_tools:
# Format Haystack tools to Bedrock format
tool_config = _format_tools(flattened_tools)
tool_config = _format_tools(flattened_tools, tools_cachepoint_config=self.tools_cachepoint_config)

# Any remaining kwargs go to additionalModelRequestFields
additional_fields = merged_kwargs if merged_kwargs else None
@@ -40,7 +40,9 @@


# Haystack to Bedrock util methods
def _format_tools(tools: list[Tool] | None = None) -> dict[str, Any] | None:
def _format_tools(
tools: list[Tool] | None = None, tools_cachepoint_config: dict[str, Any] | None = None
) -> dict[str, Any] | None:
"""
Format Haystack Tool(s) to Amazon Bedrock toolConfig format.

@@ -57,7 +59,10 @@ def _format_tools(tools: list[Tool] | None = None) -> dict[str, Any] | None:
{"toolSpec": {"name": tool.name, "description": tool.description, "inputSchema": {"json": tool.parameters}}}
)

return {"tools": tool_specs} if tool_specs else None
if tools_cachepoint_config:
tool_specs.append({"cachePoint": tools_cachepoint_config})

return {"tools": tool_specs}


def _convert_image_content_to_bedrock_format(image_content: ImageContent) -> dict[str, Any]:
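
For reference, a sketch of the toolConfig this produces when a cache-point config is passed (tool name and schema are illustrative): the cachePoint block is appended after the tool specs.

# Roughly what _format_tools([weather_tool], tools_cachepoint_config={"type": "default"}) returns
expected_tool_config = {
    "tools": [
        {
            "toolSpec": {
                "name": "get_weather",
                "description": "Returns the weather for a city.",
                "inputSchema": {"json": {"type": "object", "properties": {"city": {"type": "string"}}}},
            }
        },
        {"cachePoint": {"type": "default"}},
    ]
}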
@@ -181,20 +186,23 @@ def _repair_tool_result_messages(bedrock_formatted_messages: list[dict[str, Any]
original_idx = None
for tool_call_id in tool_call_ids:
for idx, tool_result in tool_result_messages:
tool_result_contents = [c for c in tool_result["content"] if "toolResult" in c]
tool_result_contents = [c for c in tool_result["content"] if "toolResult" in c or "cachePoint" in c]
for content in tool_result_contents:
if content["toolResult"]["toolUseId"] == tool_call_id:
if "toolResult" in content and content["toolResult"]["toolUseId"] == tool_call_id:
regrouped_tool_result.append(content)
# Keep track of the original index of the last tool result message
original_idx = idx
elif "cachePoint" in content and content not in regrouped_tool_result:
regrouped_tool_result.append(content)

if regrouped_tool_result and original_idx is not None:
repaired_tool_result_prompts.append((original_idx, {"role": "user", "content": regrouped_tool_result}))

# Remove the tool result messages from bedrock_formatted_messages
bedrock_formatted_messages_minus_tool_results: list[tuple[int, Any]] = []
for idx, msg in enumerate(bedrock_formatted_messages):
# Assumes the content of tool result messages only contains 'toolResult': {...} objects (e.g. no 'text')
if msg.get("content") and "toolResult" not in msg["content"][0]:
# Filter out messages that contain toolResult (they are handled by repaired_tool_result_prompts)
if msg.get("content") and not any("toolResult" in c for c in msg["content"]):
bedrock_formatted_messages_minus_tool_results.append((idx, msg))

# Add the repaired tool result messages and sort to maintain the correct order
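
To illustrate the intent of the change above (content blocks are illustrative): a cachePoint block sitting next to a toolResult is now carried into the regrouped user message instead of being dropped.

# A tool result message as formatted for Bedrock; after repair, both content blocks are preserved
tool_result_msg = {
    "role": "user",
    "content": [
        {"toolResult": {"toolUseId": "call_1", "content": [{"text": "42"}]}},
        {"cachePoint": {"type": "default"}},
    ],
}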
@@ -251,6 +259,29 @@ def _format_text_image_message(message: ChatMessage) -> dict[str, Any]:
return {"role": message.role.value, "content": bedrock_content_blocks}


def _validate_and_format_cache_point(cache_point: dict[str, str] | None) -> dict[str, Any] | None:
"""
Validate and format a cache point dictionary.

Schema available at https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html

:param cache_point: Cache point dictionary to validate and format.
:returns: Dictionary in Bedrock cachePoint format or None if no cache point is provided.
:raises ValueError: If cache point is not valid.
"""
if not cache_point:
return None

if "type" not in cache_point or cache_point["type"] != "default":
err_msg = "Cache point must have a 'type' key with value 'default'."
raise ValueError(err_msg)
if not set(cache_point).issubset({"type", "ttl"}):
err_msg = "Cache point can only contain 'type' and 'ttl' keys."
raise ValueError(err_msg)

return {"cachePoint": cache_point}


def _format_messages(messages: list[ChatMessage]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""
Format a list of Haystack ChatMessages to the format expected by Bedrock API.
@@ -267,18 +298,27 @@ def _format_messages(messages: list[ChatMessage]) -> tuple[list[dict[str, Any]],
system_prompts = []
bedrock_formatted_messages = []
for msg in messages:
cache_point = _validate_and_format_cache_point(msg.meta.get("cachePoint"))
if msg.is_from(ChatRole.SYSTEM):
# Assuming system messages can only contain text
# Don't need to track idx since system_messages are handled separately
system_prompts.append({"text": msg.text})
elif msg.tool_calls:
bedrock_formatted_messages.append(_format_tool_call_message(msg))
if cache_point:
system_prompts.append(cache_point)
continue

if msg.tool_calls:
formatted_msg = _format_tool_call_message(msg)
elif msg.tool_call_results:
bedrock_formatted_messages.append(_format_tool_result_message(msg))
formatted_msg = _format_tool_result_message(msg)
else:
bedrock_formatted_messages.append(_format_text_image_message(msg))
formatted_msg = _format_text_image_message(msg)
if cache_point:
formatted_msg["content"].append(cache_point)
bedrock_formatted_messages.append(formatted_msg)

repaired_bedrock_formatted_messages = _repair_tool_result_messages(bedrock_formatted_messages)

return system_prompts, repaired_bedrock_formatted_messages
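
A sketch of how callers opt in from the message side (the meta key and value mirror the tests below; the text is a placeholder):

from haystack.dataclasses import ChatMessage

user_msg = ChatMessage.from_user(
    "A long, reusable context worth caching...",
    meta={"cachePoint": {"type": "default"}},
)
system_prompts, bedrock_messages = _format_messages([user_msg])
# bedrock_messages[0]["content"] now ends with {"cachePoint": {"type": "default"}};
# for a system message, the block would be appended to system_prompts instead.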


@@ -310,6 +350,9 @@ def _parse_completion_response(response_body: dict[str, Any], model: str) -> lis
"prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0),
"completion_tokens": response_body.get("usage", {}).get("outputTokens", 0),
"total_tokens": response_body.get("usage", {}).get("totalTokens", 0),
"cache_read_input_tokens": response_body.get("usage", {}).get("cacheReadInputTokens", 0),
"cache_write_input_tokens": response_body.get("usage", {}).get("cacheWriteInputTokens", 0),
"cache_details": response_body.get("usage", {}).get("CacheDetails", {}),
},
}
# guardrail trace
@@ -461,6 +504,9 @@ def _convert_event_to_streaming_chunk(
"prompt_tokens": usage.get("inputTokens", 0),
"completion_tokens": usage.get("outputTokens", 0),
"total_tokens": usage.get("totalTokens", 0),
"cache_read_input_tokens": usage.get("cacheReadInputTokens", 0),
"cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0),
"cache_details": usage.get("cacheDetails", {}),
}
if "trace" in event_meta:
chunk_meta["trace"] = event_meta["trace"]
28 changes: 28 additions & 0 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -39,6 +39,10 @@
"us.anthropic.claude-sonnet-4-20250514-v1:0",
]

MODELS_TO_TEST_WITH_PROMPT_CACHING = [
"amazon.nova-micro-v1:0" # cheap, fast model
]


def hello_world():
return "Hello, World!"
@@ -164,6 +168,7 @@ def test_to_dict(self, mock_boto3_session, boto3_config):
"boto3_config": boto3_config,
"tools": None,
"guardrail_config": {"guardrailIdentifier": "test", "guardrailVersion": "test"},
"tools_cachepoint_config": None,
},
}

@@ -298,6 +303,7 @@ def test_serde_in_pipeline(self, mock_boto3_session, monkeypatch):
}
],
"guardrail_config": None,
"tools_cachepoint_config": None,
},
}
},
@@ -945,6 +951,28 @@ def test_live_run_with_guardrail(self, streaming_callback):
assert "trace" in results["replies"][0].meta
assert "guardrail" in results["replies"][0].meta["trace"]

@pytest.mark.parametrize("streaming_callback", [None, print_streaming_chunk])
@pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_PROMPT_CACHING)
def test_prompt_caching_live_run_with_user_message(self, model_name, streaming_callback):
generator = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback)

system_message = ChatMessage.from_system("Always respond with: 'Life is beatiful' (and nothing else).")

user_message = ChatMessage.from_user(
"User message that should be long enough to cache. " * 100, meta={"cachePoint": {"type": "default"}}
)
messages = [system_message, user_message]
result = generator.run(messages=messages)

assert "replies" in result
assert len(result["replies"]) == 1
usage = result["replies"][0].meta["usage"]

# tests run in parallel based on the workflow matrix, so this request should either hit the cache (read tokens)
# or populate it (write tokens)
assert usage["cache_read_input_tokens"] > 1000 or usage["cache_write_input_tokens"] > 1000
assert "cache_details" in usage

@pytest.mark.parametrize("model_name", [MODELS_TO_TEST_WITH_TOOLS[0]]) # just one model is enough
def test_pipeline_with_amazon_bedrock_chat_generator(self, model_name, tools):
"""