
models - anthropic #12


Merged
merged 1 commit on May 16, 2025
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -47,6 +47,9 @@ Documentation = "https://strandsagents.com"
 packages = ["src/strands"]
 
 [project.optional-dependencies]
+anthropic = [
+    "anthropic>=0.21.0,<1.0.0",
+]
 dev = [
     "commitizen>=4.4.0,<5.0.0",
     "hatch>=1.0.0,<2.0.0",
@@ -71,7 +74,7 @@ ollama = [
 ]
 
 [tool.hatch.envs.hatch-static-analysis]
-features = ["litellm", "ollama"]
+features = ["anthropic", "litellm", "ollama"]
 dependencies = [
     "mypy>=1.15.0,<2.0.0",
     "ruff>=0.11.6,<0.12.0",
@@ -94,7 +97,7 @@ lint-fix = [
 ]
 
 [tool.hatch.envs.hatch-test]
-features = ["litellm", "ollama"]
+features = ["anthropic", "litellm", "ollama"]
 extra-dependencies = [
     "moto>=5.1.0,<6.0.0",
     "pytest>=8.0.0,<9.0.0",
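With the new optional dependency group, the Anthropic SDK is installed only when the extra is requested, e.g. pip install "strands[anthropic]" (the distribution name here is assumed from the package layout; substitute the actual published name). The hatch static-analysis and test environments enable the extra so the new provider is type-checked and tested alongside the existing litellm and ollama providers.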
357 changes: 357 additions & 0 deletions src/strands/models/anthropic.py
@@ -0,0 +1,357 @@
"""Anthropic Claude model provider.

- Docs: https://docs.anthropic.com/claude/reference/getting-started-with-the-api
"""

import base64
import json
import logging
import mimetypes
from typing import Any, Iterable, Optional, TypedDict, cast

import anthropic
from typing_extensions import Required, Unpack, override

from ..types.content import ContentBlock, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.models import Model
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec

logger = logging.getLogger(__name__)


class AnthropicModel(Model):
    """Anthropic model provider implementation."""

    # Anthropic stream event types that are passed through to format_chunk.
    EVENT_TYPES = {
        "message_start",
        "content_block_start",
        "content_block_delta",
        "content_block_stop",
        "message_stop",
    }

    # Substrings of BadRequestError messages that indicate a context overflow.
    OVERFLOW_MESSAGES = {
        "input is too long",
        "input length exceeds context window",
        "input and output tokens exceed your context limit",
    }

    class AnthropicConfig(TypedDict, total=False):
        """Configuration options for Anthropic models.

        Attributes:
            max_tokens: Maximum number of tokens to generate.
            model_id: Claude model ID (e.g., "claude-3-7-sonnet-latest").
                For a complete list of supported models, see
                https://docs.anthropic.com/en/docs/about-claude/models/all-models.
            params: Additional model parameters (e.g., temperature).
                For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages.
        """

        max_tokens: Required[int]
        model_id: Required[str]
        params: Optional[dict[str, Any]]

    def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[AnthropicConfig]):
        """Initialize provider instance.

        Args:
            client_args: Arguments for the underlying Anthropic client (e.g., api_key).
                For a complete list of supported arguments, see https://docs.anthropic.com/en/api/client-sdks.
            **model_config: Configuration options for the Anthropic model.
        """
        self.config = AnthropicModel.AnthropicConfig(**model_config)

        logger.debug("config=<%s> | initializing", self.config)

        client_args = client_args or {}
        self.client = anthropic.Anthropic(**client_args)

    @override
    def update_config(self, **model_config: Unpack[AnthropicConfig]) -> None:  # type: ignore[override]
        """Update the Anthropic model configuration with the provided arguments.

        Args:
            **model_config: Configuration overrides.
        """
        self.config.update(model_config)

    @override
    def get_config(self) -> AnthropicConfig:
        """Get the Anthropic model configuration.

        Returns:
            The Anthropic model configuration.
        """
        return self.config
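
    # A minimal usage sketch (hypothetical values; assumes the Anthropic client can
    # pick up credentials, e.g. from the ANTHROPIC_API_KEY environment variable):
    #
    #   model = AnthropicModel(
    #       model_id="claude-3-7-sonnet-latest",
    #       max_tokens=1024,
    #       params={"temperature": 0.7},
    #   )
    #   model.update_config(model_id="claude-3-5-haiku-latest")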

    def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
        """Format an Anthropic content block.

        Args:
            content: Message content.

        Returns:
            Anthropic formatted content block.
        """
        if "document" in content:
            return {
                "source": {
                    "data": base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8"),
                    "media_type": mimetypes.types_map.get(
                        f".{content['document']['format']}", "application/octet-stream"
                    ),
                    "type": "base64",
                },
                "title": content["document"]["name"],
                "type": "document",
            }

        if "image" in content:
            return {
                "source": {
                    "data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"),
                    "media_type": mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream"),
                    "type": "base64",
                },
                "type": "image",
            }

        if "reasoningContent" in content:
            return {
                "signature": content["reasoningContent"]["reasoningText"]["signature"],
                "thinking": content["reasoningContent"]["reasoningText"]["text"],
                "type": "thinking",
            }

        if "text" in content:
            return {"text": content["text"], "type": "text"}

        if "toolUse" in content:
            return {
                "id": content["toolUse"]["toolUseId"],
                "input": content["toolUse"]["input"],
                "name": content["toolUse"]["name"],
                "type": "tool_use",
            }

        if "toolResult" in content:
            return {
                "content": [
                    self._format_request_message_content(cast(ContentBlock, tool_result_content))
                    for tool_result_content in content["toolResult"]["content"]
                ],
                "is_error": content["toolResult"]["status"] == "error",
                "tool_use_id": content["toolResult"]["toolUseId"],
                "type": "tool_result",
            }

        return {"text": json.dumps(content), "type": "text"}

    def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
        """Format an Anthropic messages array.

        Args:
            messages: List of message objects to be processed by the model.

        Returns:
            An Anthropic messages array.
        """
        formatted_messages = []

        for message in messages:
            formatted_contents: list[dict[str, Any]] = []

            for content in message["content"]:
                if "cachePoint" in content:
                    formatted_contents[-1]["cache_control"] = {"type": "ephemeral"}
                    continue

                formatted_contents.append(self._format_request_message_content(content))

            if formatted_contents:
                formatted_messages.append({"content": formatted_contents, "role": message["role"]})

        return formatted_messages
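
    # A cachePoint block attaches ephemeral cache control to the block formatted
    # just before it (illustrative input; assumes a cache point never appears first):
    #
    #   [{"role": "user", "content": [{"text": "long doc"}, {"cachePoint": {"type": "default"}}]}]
    #   # -> [{"role": "user", "content": [
    #   #        {"text": "long doc", "type": "text", "cache_control": {"type": "ephemeral"}}]}]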

    @override
    def format_request(
        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
    ) -> dict[str, Any]:
        """Format an Anthropic streaming request.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.

        Returns:
            An Anthropic streaming request.
        """
        return {
            "max_tokens": self.config["max_tokens"],
            "messages": self._format_request_messages(messages),
            "model": self.config["model_id"],
            "tools": [
                {
                    "name": tool_spec["name"],
                    "description": tool_spec["description"],
                    "input_schema": tool_spec["inputSchema"]["json"],
                }
                for tool_spec in tool_specs or []
            ],
            **({"system": system_prompt} if system_prompt else {}),
            **(self.config.get("params") or {}),
        }
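
    # A request built from a single user message looks roughly like this
    # (illustrative values; entries in config["params"] are merged last and can
    # override any key above):
    #
    #   {
    #       "max_tokens": 1024,
    #       "messages": [{"role": "user", "content": [{"text": "hi", "type": "text"}]}],
    #       "model": "claude-3-7-sonnet-latest",
    #       "tools": [],
    #   }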

    @override
    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
        """Format the Anthropic response events into standardized message chunks.

        Args:
            event: A response event from the Anthropic model.

        Returns:
            The formatted chunk.

        Raises:
            RuntimeError: If chunk_type is not recognized.
                This error should never be encountered as we control chunk_type in the stream method.
        """
        match event["type"]:
            case "message_start":
                return {"messageStart": {"role": "assistant"}}

            case "content_block_start":
                content = event["content_block"]

                if content["type"] == "tool_use":
                    return {
                        "contentBlockStart": {
                            "contentBlockIndex": event["index"],
                            "start": {
                                "toolUse": {
                                    "name": content["name"],
                                    "toolUseId": content["id"],
                                }
                            },
                        }
                    }

                return {"contentBlockStart": {"contentBlockIndex": event["index"], "start": {}}}

            case "content_block_delta":
                delta = event["delta"]

                match delta["type"]:
                    case "signature_delta":
                        return {
                            "contentBlockDelta": {
                                "contentBlockIndex": event["index"],
                                "delta": {
                                    "reasoningContent": {
                                        "signature": delta["signature"],
                                    },
                                },
                            },
                        }

                    case "thinking_delta":
                        return {
                            "contentBlockDelta": {
                                "contentBlockIndex": event["index"],
                                "delta": {
                                    "reasoningContent": {
                                        "text": delta["thinking"],
                                    },
                                },
                            },
                        }

                    case "input_json_delta":
                        return {
                            "contentBlockDelta": {
                                "contentBlockIndex": event["index"],
                                "delta": {
                                    "toolUse": {
                                        "input": delta["partial_json"],
                                    },
                                },
                            },
                        }

                    case "text_delta":
                        return {
                            "contentBlockDelta": {
                                "contentBlockIndex": event["index"],
                                "delta": {
                                    "text": delta["text"],
                                },
                            },
                        }

                    case _:
                        raise RuntimeError(
                            f"event_type=<content_block_delta>, delta_type=<{delta['type']}> | unknown type"
                        )

            case "content_block_stop":
                return {"contentBlockStop": {"contentBlockIndex": event["index"]}}

            case "message_stop":
                message = event["message"]

                return {"messageStop": {"stopReason": message["stop_reason"]}}

            case "metadata":
                usage = event["usage"]

                return {
                    "metadata": {
                        "usage": {
                            "inputTokens": usage["input_tokens"],
                            "outputTokens": usage["output_tokens"],
                            "totalTokens": usage["input_tokens"] + usage["output_tokens"],
                        },
                        "metrics": {
                            "latencyMs": 0,  # TODO
                        },
                    }
                }

            case _:
                raise RuntimeError(f"event_type=<{event['type']}> | unknown type")

    @override
    def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
        """Send the request to the Anthropic model and get the streaming response.

        Args:
            request: The formatted request to send to the Anthropic model.

        Returns:
            An iterable of response events from the Anthropic model.

        Raises:
            ContextWindowOverflowException: If the input exceeds the model's context window.
            ModelThrottledException: If the request is throttled by Anthropic.
        """
        try:
            with self.client.messages.stream(**request) as stream:
                for event in stream:
                    if event.type in AnthropicModel.EVENT_TYPES:
                        yield event.dict()

                # The final event is message_stop, whose accumulated message carries
                # the usage counts; surface them as a synthetic metadata event.
                usage = event.message.usage  # type: ignore
                yield {"type": "metadata", "usage": usage.dict()}

        except anthropic.RateLimitError as error:
            raise ModelThrottledException(str(error)) from error

        except anthropic.BadRequestError as error:
            if any(overflow_message in str(error).lower() for overflow_message in AnthropicModel.OVERFLOW_MESSAGES):
                raise ContextWindowOverflowException(str(error)) from error

            raise error
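
# End-to-end sketch of the streaming flow (hypothetical message and model values,
# using only the methods defined above):
#
#   model = AnthropicModel(model_id="claude-3-7-sonnet-latest", max_tokens=1024)
#   request = model.format_request([{"role": "user", "content": [{"text": "Hello"}]}])
#   for event in model.stream(request):
#       chunk = model.format_chunk(event)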