models - correct tool result content #154

Merged: 6 commits, Jun 2, 2025
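
All four provider changes in this diff handle the same case: a `toolResult` content block whose entries carry a structured `json` payload instead of plain text. As a reference point, here is a minimal sketch of that shape; the field names are taken from the diffs below, while the concrete values are purely illustrative:

```python
# Illustrative toolResult content block; the "json" entry is the case this PR fixes.
tool_result_block = {
    "toolResult": {
        "toolUseId": "tooluse_example_1",  # hypothetical id
        "status": "success",
        "content": [
            {"json": {"temperature": 72, "unit": "F"}},  # structured result
            {"text": "lookup complete"},                 # plain-text result
        ],
    }
}
```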
7 changes: 6 additions & 1 deletion src/strands/models/anthropic.py
@@ -4,6 +4,7 @@
"""

import base64
import json
import logging
import mimetypes
from typing import Any, Iterable, Optional, TypedDict, cast
@@ -145,7 +146,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
if "toolResult" in content:
return {
"content": [
self._format_request_message_content(cast(ContentBlock, tool_result_content))
self._format_request_message_content(
{"text": json.dumps(tool_result_content["json"])}
if "json" in tool_result_content
else cast(ContentBlock, tool_result_content)
)
for tool_result_content in content["toolResult"]["content"]
],
"is_error": content["toolResult"]["status"] == "error",
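
The Anthropic change serializes a `json` entry with `json.dumps` and re-feeds it through `_format_request_message_content` as a `text` block. Reduced to its essence (a standalone sketch, not the provider method itself):

```python
import json
from typing import Any

def normalize_tool_result_entry(entry: dict[str, Any]) -> dict[str, Any]:
    """Sketch of the new behavior: a 'json' entry becomes a 'text' content
    block so the existing text formatting path can handle it."""
    if "json" in entry:
        return {"text": json.dumps(entry["json"])}
    return entry

# {"json": {"temperature": 72}}  ->  {"text": '{"temperature": 72}'}
```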
6 changes: 3 additions & 3 deletions src/strands/models/litellm.py
@@ -67,8 +67,8 @@ def get_config(self) -> LiteLLMConfig:
return cast(LiteLLMModel.LiteLLMConfig, self.config)

@override
@staticmethod
def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
@classmethod
def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
"""Format a LiteLLM content block.

Args:
@@ -96,4 +96,4 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
},
}

return OpenAIModel.format_request_message_content(content)
return super().format_request_message_content(content)
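
The litellm.py change moves the override from `@staticmethod` to `@classmethod` so the fallback can go through `super()` and the method resolution order rather than hard-coding `OpenAIModel`. A minimal illustration of the pattern; the class names here are placeholders, not the real hierarchy:

```python
from typing import Any

class BaseProviderModel:
    @classmethod
    def format_request_message_content(cls, content: dict[str, Any]) -> dict[str, Any]:
        # Shared fallback formatting (placeholder logic).
        return {"type": "text", "text": content.get("text", "")}

class LiteLLMLikeModel(BaseProviderModel):
    @classmethod
    def format_request_message_content(cls, content: dict[str, Any]) -> dict[str, Any]:
        if "special" in content:  # hypothetical provider-specific case
            return {"type": "special", "value": content["special"]}
        # Zero-argument super() relies on the method receiving the class (or an
        # instance) as its first argument, which a staticmethod does not get;
        # with @classmethod the fallback dispatches through the MRO.
        return super().format_request_message_content(content)
```

The visible effect in the diff is only the decorator and the `super()` call; behavior is unchanged as long as `OpenAIModel` remains the parent.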
26 changes: 19 additions & 7 deletions src/strands/models/llamaapi.py
@@ -8,7 +8,7 @@
import json
import logging
import mimetypes
from typing import Any, Iterable, Optional
from typing import Any, Iterable, Optional, cast

import llama_api_client
from llama_api_client import LlamaAPIClient
@@ -139,18 +139,30 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
Returns:
Llama API formatted tool message.
"""
contents = cast(
list[ContentBlock],
[
{"text": json.dumps(content["json"])} if "json" in content else content
for content in tool_result["content"]
],
)

return {
"role": "tool",
"tool_call_id": tool_result["toolUseId"],
"content": json.dumps(
{
"content": tool_result["content"],
"status": tool_result["status"],
}
),
"content": [self._format_request_message_content(content) for content in contents],
}

def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
"""Format a LlamaAPI compatible messages array.

Args:
messages: List of message objects to be processed by the model.
system_prompt: System prompt to provide context to the model.

Returns:
An LlamaAPI compatible messages array.
"""
formatted_messages: list[dict[str, Any]]
formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []

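
For the Llama API, the tool message no longer packs the entire tool result into a single `json.dumps` string; each content entry is converted (with `json` entries becoming text) and then formatted individually. A rough standalone sketch of the new flow, where `format_content` merely stands in for the real `_format_request_message_content`:

```python
import json
from typing import Any

def format_content(content: dict[str, Any]) -> dict[str, Any]:
    # Stand-in for _format_request_message_content; the real method also
    # handles images and other block types.
    return {"type": "text", "text": content["text"]}

def format_tool_message(tool_result: dict[str, Any]) -> dict[str, Any]:
    contents = [
        {"text": json.dumps(entry["json"])} if "json" in entry else entry
        for entry in tool_result["content"]
    ]
    return {
        "role": "tool",
        "tool_call_id": tool_result["toolUseId"],
        # Previously a single json.dumps(...) string; now a list of formatted blocks.
        "content": [format_content(entry) for entry in contents],
    }

# format_tool_message({"toolUseId": "t1", "content": [{"json": {"ok": True}}]})
```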
213 changes: 114 additions & 99 deletions src/strands/models/ollama.py
@@ -5,13 +5,12 @@

import json
import logging
from typing import Any, Iterable, Optional, Union
from typing import Any, Iterable, Optional, cast

from ollama import Client as OllamaClient
from typing_extensions import TypedDict, Unpack, override

from ..types.content import ContentBlock, Message, Messages
from ..types.media import DocumentContent, ImageContent
from ..types.content import ContentBlock, Messages
from ..types.models import Model
from ..types.streaming import StopReason, StreamEvent
from ..types.tools import ToolSpec
@@ -92,35 +91,31 @@ def get_config(self) -> OllamaConfig:
"""
return self.config

@override
def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> dict[str, Any]:
"""Format an Ollama chat streaming request.
def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]:
"""Format Ollama compatible message contents.

Ollama doesn't support an array of contents, so we must flatten everything into separate message blocks.

Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
role: E.g., user.
content: Content block to format.

Returns:
An Ollama chat streaming request.
Ollama formatted message contents.

Raises:
TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
format.
TypeError: If the content block type cannot be converted to an Ollama-compatible format.
"""
if "text" in content:
return [{"role": role, "content": content["text"]}]

def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
if "text" in content:
return {"role": message["role"], "content": content["text"]}
if "image" in content:
return [{"role": role, "images": [content["image"]["source"]["bytes"]]}]

if "image" in content:
return {"role": message["role"], "images": [content["image"]["source"]["bytes"]]}

if "toolUse" in content:
return {
"role": "assistant",
if "toolUse" in content:
return [
{
"role": role,
"tool_calls": [
{
"function": {
@@ -130,45 +125,63 @@ def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
}
],
}
]

if "toolResult" in content:
return [
formatted_tool_result_content
for tool_result_content in content["toolResult"]["content"]
for formatted_tool_result_content in self._format_request_message_contents(
"tool",
(
{"text": json.dumps(tool_result_content["json"])}
if "json" in tool_result_content
else cast(ContentBlock, tool_result_content)
),
)
]

if "toolResult" in content:
result_content: Union[str, ImageContent, DocumentContent, Any] = None
result_images = []
for tool_result_content in content["toolResult"]["content"]:
if "text" in tool_result_content:
result_content = tool_result_content["text"]
elif "json" in tool_result_content:
result_content = tool_result_content["json"]
elif "image" in tool_result_content:
result_content = "see images"
result_images.append(tool_result_content["image"]["source"]["bytes"])
else:
result_content = content["toolResult"]["content"]
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

return {
"role": "tool",
"content": json.dumps(
{
"name": content["toolResult"]["toolUseId"],
"result": result_content,
"status": content["toolResult"]["status"],
}
),
**({"images": result_images} if result_images else {}),
}
def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
"""Format an Ollama compatible messages array.

raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
Args:
messages: List of message objects to be processed by the model.
system_prompt: System prompt to provide context to the model.

def format_messages() -> list[dict[str, Any]]:
return [format_message(message, content) for message in messages for content in message["content"]]
Returns:
An Ollama compatible messages array.
"""
system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []

formatted_messages = format_messages()
return system_message + [
formatted_message
for message in messages
for content in message["content"]
for formatted_message in self._format_request_message_contents(message["role"], content)
]

@override
def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> dict[str, Any]:
"""Format an Ollama chat streaming request.

Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
An Ollama chat streaming request.

Raises:
TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
format.
"""
return {
"messages": [
*([{"role": "system", "content": system_prompt}] if system_prompt else []),
*formatted_messages,
],
"messages": self._format_request_messages(messages, system_prompt),
"model": self.config["model_id"],
"options": {
**(self.config.get("options") or {}),
@@ -217,52 +230,54 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
RuntimeError: If chunk_type is not recognized.
This error should never be encountered as we control chunk_type in the stream method.
"""
if event["chunk_type"] == "message_start":
return {"messageStart": {"role": "assistant"}}

if event["chunk_type"] == "content_start":
if event["data_type"] == "text":
return {"contentBlockStart": {"start": {}}}

tool_name = event["data"].function.name
return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}}

if event["chunk_type"] == "content_delta":
if event["data_type"] == "text":
return {"contentBlockDelta": {"delta": {"text": event["data"]}}}

tool_arguments = event["data"].function.arguments
return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}}

if event["chunk_type"] == "content_stop":
return {"contentBlockStop": {}}

if event["chunk_type"] == "message_stop":
reason: StopReason
if event["data"] == "tool_use":
reason = "tool_use"
elif event["data"] == "length":
reason = "max_tokens"
else:
reason = "end_turn"

return {"messageStop": {"stopReason": reason}}

if event["chunk_type"] == "metadata":
return {
"metadata": {
"usage": {
"inputTokens": event["data"].eval_count,
"outputTokens": event["data"].prompt_eval_count,
"totalTokens": event["data"].eval_count + event["data"].prompt_eval_count,
},
"metrics": {
"latencyMs": event["data"].total_duration / 1e6,
match event["chunk_type"]:
case "message_start":
return {"messageStart": {"role": "assistant"}}

case "content_start":
if event["data_type"] == "text":
return {"contentBlockStart": {"start": {}}}

tool_name = event["data"].function.name
return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}}

case "content_delta":
if event["data_type"] == "text":
return {"contentBlockDelta": {"delta": {"text": event["data"]}}}

tool_arguments = event["data"].function.arguments
return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}}

case "content_stop":
return {"contentBlockStop": {}}

case "message_stop":
reason: StopReason
if event["data"] == "tool_use":
reason = "tool_use"
elif event["data"] == "length":
reason = "max_tokens"
else:
reason = "end_turn"

return {"messageStop": {"stopReason": reason}}

case "metadata":
return {
"metadata": {
"usage": {
"inputTokens": event["data"].eval_count,
"outputTokens": event["data"].prompt_eval_count,
"totalTokens": event["data"].eval_count + event["data"].prompt_eval_count,
},
"metrics": {
"latencyMs": event["data"].total_duration / 1e6,
},
},
},
}
}

raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
case _:
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")

@override
def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
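
The Ollama rewrite is the largest: because Ollama does not accept an array of content blocks per message, each block is flattened into its own message, and tool results (including `json` entries) are routed recursively through the same path as role `tool` messages. A condensed sketch of that flattening, with the image and toolUse branches omitted; the `format_chunk` rewrite is a separate restructuring of the same branches into a `match` statement:

```python
import json
from typing import Any, Optional

def flatten_content(role: str, content: dict[str, Any]) -> list[dict[str, Any]]:
    # Condensed version of _format_request_message_contents; the real method
    # also handles "image" and "toolUse" blocks.
    if "text" in content:
        return [{"role": role, "content": content["text"]}]
    if "toolResult" in content:
        return [
            flattened
            for entry in content["toolResult"]["content"]
            for flattened in flatten_content(
                "tool",
                {"text": json.dumps(entry["json"])} if "json" in entry else entry,
            )
        ]
    raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

def format_messages(messages: list[dict[str, Any]], system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
    # One Ollama message per content block, with an optional system message first.
    system = [{"role": "system", "content": system_prompt}] if system_prompt else []
    return system + [
        formatted
        for message in messages
        for content in message["content"]
        for formatted in flatten_content(message["role"], content)
    ]
```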