feature: models - openai #65

Merged: 4 commits, May 22, 2025
13 changes: 8 additions & 5 deletions pyproject.toml
@@ -54,7 +54,7 @@ dev = [
"commitizen>=4.4.0,<5.0.0",
"hatch>=1.0.0,<2.0.0",
"moto>=5.1.0,<6.0.0",
"mypy>=0.981,<1.0.0",
"mypy>=1.15.0,<2.0.0",
"pre-commit>=3.2.0,<4.2.0",
"pytest>=8.0.0,<9.0.0",
"pytest-asyncio>=0.26.0,<0.27.0",
@@ -69,15 +69,18 @@ docs = [
litellm = [
"litellm>=1.69.0,<2.0.0",
]
llamaapi = [
"llama-api-client>=0.1.0,<1.0.0",
]
ollama = [
"ollama>=0.4.8,<1.0.0",
]
llamaapi = [
"llama-api-client>=0.1.0,<1.0.0",
openai = [
"openai>=1.68.0,<2.0.0",
]

[tool.hatch.envs.hatch-static-analysis]
features = ["anthropic", "litellm", "llamaapi", "ollama"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
@@ -100,7 +103,7 @@ lint-fix = [
]

[tool.hatch.envs.hatch-test]
features = ["anthropic", "litellm", "llamaapi", "ollama"]
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
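
With the new openai optional dependency group in place, installing and constructing the provider could look like the sketch below. This is a hedged usage sketch, not part of the diff: the install target strands-agents[openai] is inferred from this repository, and the constructor shape (client_args plus model_id/params) is assumed to mirror the LiteLLMModel pattern visible in the litellm.py diff below rather than confirmed by it.

# Assumed install step: pip install 'strands-agents[openai]'
from strands.models.openai import OpenAIModel  # module path matches the import added in litellm.py below

model = OpenAIModel(
    client_args={"api_key": "sk-..."},  # assumed to be passed through to the underlying OpenAI client
    model_id="gpt-4o",                  # hypothetical model id, for illustration only
    params={"max_tokens": 1000},        # assumed optional request parameters
)
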
265 changes: 10 additions & 255 deletions src/strands/models/litellm.py
@@ -3,23 +3,19 @@
- Docs: https://docs.litellm.ai/
"""

import json
import logging
import mimetypes
from typing import Any, Iterable, Optional, TypedDict
from typing import Any, Optional, TypedDict, cast

import litellm
from typing_extensions import Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.models import Model
from ..types.streaming import StreamEvent
from ..types.tools import ToolResult, ToolSpec, ToolUse
from ..types.content import ContentBlock
from .openai import OpenAIModel

logger = logging.getLogger(__name__)


class LiteLLMModel(Model):
class LiteLLMModel(OpenAIModel):
"""LiteLLM model provider implementation."""

class LiteLLMConfig(TypedDict, total=False):
@@ -45,7 +41,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
**model_config: Configuration options for the LiteLLM model.
"""
self.config = LiteLLMModel.LiteLLMConfig(**model_config)
self.config = dict(model_config)

logger.debug("config=<%s> | initializing", self.config)

@@ -68,9 +64,11 @@ def get_config(self) -> LiteLLMConfig:
Returns:
The LiteLLM model configuration.
"""
return self.config
return cast(LiteLLMModel.LiteLLMConfig, self.config)

def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
@override
@staticmethod
def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
"""Format a LiteLLM content block.

Args:
@@ -79,28 +77,13 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
Returns:
LiteLLM formatted content block.
"""
if "image" in content:
mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
image_data = content["image"]["source"]["bytes"].decode("utf-8")
return {
"image_url": {
"detail": "auto",
"format": mime_type,
"url": f"data:{mime_type};base64,{image_data}",
},
"type": "image_url",
}

if "reasoningContent" in content:
return {
"signature": content["reasoningContent"]["reasoningText"]["signature"],
"thinking": content["reasoningContent"]["reasoningText"]["text"],
"type": "thinking",
}

if "text" in content:
return {"text": content["text"], "type": "text"}

if "video" in content:
return {
"type": "video_url",
@@ -110,232 +93,4 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
},
}

return {"text": json.dumps(content), "type": "text"}

def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
"""Format a LiteLLM tool call.

Args:
tool_use: Tool use requested by the model.

Returns:
LiteLLM formatted tool call.
"""
return {
"function": {
"arguments": json.dumps(tool_use["input"]),
"name": tool_use["name"],
},
"id": tool_use["toolUseId"],
"type": "function",
}

def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
"""Format a LiteLLM tool message.

Args:
tool_result: Tool result collected from a tool execution.

Returns:
LiteLLM formatted tool message.
"""
return {
"role": "tool",
"tool_call_id": tool_result["toolUseId"],
"content": json.dumps(
{
"content": tool_result["content"],
"status": tool_result["status"],
}
),
}

def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
"""Format a LiteLLM messages array.

Args:
messages: List of message objects to be processed by the model.
system_prompt: System prompt to provide context to the model.

Returns:
A LiteLLM messages array.
"""
formatted_messages: list[dict[str, Any]]
formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []

for message in messages:
contents = message["content"]

formatted_contents = [
self._format_request_message_content(content)
for content in contents
if not any(block_type in content for block_type in ["toolResult", "toolUse"])
]
formatted_tool_calls = [
self._format_request_message_tool_call(content["toolUse"])
for content in contents
if "toolUse" in content
]
formatted_tool_messages = [
self._format_request_tool_message(content["toolResult"])
for content in contents
if "toolResult" in content
]

formatted_message = {
"role": message["role"],
"content": formatted_contents,
**({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
}
formatted_messages.append(formatted_message)
formatted_messages.extend(formatted_tool_messages)

return [message for message in formatted_messages if message["content"] or "tool_calls" in message]

@override
def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> dict[str, Any]:
"""Format a LiteLLM chat streaming request.

Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
A LiteLLM chat streaming request.
"""
return {
"messages": self._format_request_messages(messages, system_prompt),
"model": self.config["model_id"],
"stream": True,
"stream_options": {"include_usage": True},
"tools": [
{
"type": "function",
"function": {
"name": tool_spec["name"],
"description": tool_spec["description"],
"parameters": tool_spec["inputSchema"]["json"],
},
}
for tool_spec in tool_specs or []
],
**(self.config.get("params") or {}),
}

@override
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the LiteLLM response events into standardized message chunks.

Args:
event: A response event from the LiteLLM model.

Returns:
The formatted chunk.

Raises:
RuntimeError: If chunk_type is not recognized.
This error should never be encountered as we control chunk_type in the stream method.
"""
match event["chunk_type"]:
case "message_start":
return {"messageStart": {"role": "assistant"}}

case "content_start":
if event["data_type"] == "tool":
return {
"contentBlockStart": {
"start": {
"toolUse": {
"name": event["data"].function.name,
"toolUseId": event["data"].id,
}
}
}
}

return {"contentBlockStart": {"start": {}}}

case "content_delta":
if event["data_type"] == "tool":
return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}

return {"contentBlockDelta": {"delta": {"text": event["data"]}}}

case "content_stop":
return {"contentBlockStop": {}}

case "message_stop":
match event["data"]:
case "tool_calls":
return {"messageStop": {"stopReason": "tool_use"}}
case "length":
return {"messageStop": {"stopReason": "max_tokens"}}
case _:
return {"messageStop": {"stopReason": "end_turn"}}

case "metadata":
return {
"metadata": {
"usage": {
"inputTokens": event["data"].prompt_tokens,
"outputTokens": event["data"].completion_tokens,
"totalTokens": event["data"].total_tokens,
},
"metrics": {
"latencyMs": 0, # TODO
},
},
}

case _:
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")

@override
def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
"""Send the request to the LiteLLM model and get the streaming response.

Args:
request: The formatted request to send to the LiteLLM model.

Returns:
An iterable of response events from the LiteLLM model.
"""
response = self.client.chat.completions.create(**request)

yield {"chunk_type": "message_start"}
yield {"chunk_type": "content_start", "data_type": "text"}

tool_calls: dict[int, list[Any]] = {}

for event in response:
choice = event.choices[0]
if choice.finish_reason:
break

if choice.delta.content:
yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

yield {"chunk_type": "content_stop", "data_type": "text"}

for tool_deltas in tool_calls.values():
tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start}

for tool_delta in tool_deltas:
yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}

yield {"chunk_type": "content_stop", "data_type": "tool"}

yield {"chunk_type": "message_stop", "data": choice.finish_reason}

# Skip remaining events as we don't have use for anything except the final usage payload
for event in response:
_ = event

yield {"chunk_type": "metadata", "data": event.usage}
return OpenAIModel.format_request_message_content(content)
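
To make the refactor concrete, here is a minimal standalone sketch of the delegation pattern this diff introduces: the LiteLLM subclass formats only the content blocks that need LiteLLM-specific handling and defers everything else to the shared OpenAI-compatible base class. The class names below are illustrative stand-ins, not the real strands implementations.

from typing import Any


class OpenAICompatibleFormatter:
    # Stand-in for the shared formatting that OpenAIModel now owns.
    @staticmethod
    def format_request_message_content(content: dict[str, Any]) -> dict[str, Any]:
        if "text" in content:
            return {"type": "text", "text": content["text"]}
        raise TypeError(f"unsupported content block: {list(content)}")


class LiteLLMFormatter(OpenAICompatibleFormatter):
    # Stand-in for LiteLLMModel: only LiteLLM-specific blocks are handled here.
    @staticmethod
    def format_request_message_content(content: dict[str, Any]) -> dict[str, Any]:
        if "reasoningContent" in content:
            return {
                "type": "thinking",
                "thinking": content["reasoningContent"]["reasoningText"]["text"],
                "signature": content["reasoningContent"]["reasoningText"]["signature"],
            }
        # Everything else (text, images, tool calls, ...) is delegated to the base class.
        return OpenAICompatibleFormatter.format_request_message_content(content)


print(LiteLLMFormatter.format_request_message_content({"text": "hello"}))  # handled by the base class
print(LiteLLMFormatter.format_request_message_content(
    {"reasoningContent": {"reasoningText": {"text": "chain of thought", "signature": "sig"}}}
))  # handled by the LiteLLM-specific override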