Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
# ruff: noqa: I001
# customizing the import order to prioritize the openai adapter over the others
from .openai import OpenAIAdapter
from .anthropic import AnthropicAdapter
from .google import GoogleGenAIAdapter
from .langchain import LangChainModelAdapter
from .litellm import LiteLLMAdapter

__all__ = [
"AnthropicAdapter",
"GoogleGenAIAdapter",
"LangChainModelAdapter",
"LiteLLMAdapter",
"OpenAIAdapter",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Anthropic adapter package: re-exports the public AnthropicAdapter."""

from .adapter import AnthropicAdapter

__all__ = ["AnthropicAdapter"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
import logging
from typing import Any, Dict, Type, Union, cast

from phoenix.evals.legacy.templates import MultimodalPrompt, PromptPartContentType

from ...registries import register_adapter, register_provider
from ...types import BaseLLMAdapter, ObjectGenerationMethod
from .factories import AnthropicClientWrapper, create_anthropic_client

logger = logging.getLogger(__name__)


def identify_anthropic_client(client: Any) -> bool:
    """Heuristically decide whether *client* is an Anthropic client.

    Accepts our own ``AnthropicClientWrapper``, any object whose module
    path mentions ``anthropic``, or any object exposing a
    ``messages.create`` API (duck-typed fallback).
    """
    if isinstance(client, AnthropicClientWrapper):
        return True

    module = getattr(client, "__module__", None)
    if not hasattr(client, "__module__") or module is None:
        return False
    if "anthropic" in module:
        return True

    # Duck-typed fallback: the object quacks like Anthropic's Messages API.
    messages = getattr(client, "messages", None)
    return messages is not None and hasattr(messages, "create")


def get_anthropic_rate_limit_errors() -> list[Type[Exception]]:
    """Return the Anthropic SDK exception types that indicate rate limiting."""
    # Imported lazily so the anthropic dependency is only required when
    # this provider is actually used.
    from anthropic import RateLimitError

    errors: list[Type[Exception]] = [RateLimitError]
    return errors


# Registered twice: as an adapter (auto-detected from an existing client via
# identify_anthropic_client) and as the "anthropic" provider (constructed by
# name through create_anthropic_client).
@register_adapter(
identifier=identify_anthropic_client,
name="anthropic",
)
@register_provider(
provider="anthropic",
client_factory=create_anthropic_client,
get_rate_limit_errors=get_anthropic_rate_limit_errors,
dependencies=["anthropic"],
)
class AnthropicAdapter(BaseLLMAdapter):
"""LLM adapter around an Anthropic client (sync or async)."""
def __init__(self, client: Any):
# Accepts either a raw Anthropic SDK client or an AnthropicClientWrapper.
self.client = client
# Validate before probing async-ness: _check_if_async_client relies on
# messages.create existing.
self._validate_client()
# Cached once so sync/async entry points can guard against misuse.
self._is_async = self._check_if_async_client()

@classmethod
def client_name(cls) -> str:
    """Registry name identifying this adapter's client family."""
    name = "anthropic"
    return name

def _validate_client(self) -> None:
actual_client = getattr(self.client, "client", self.client)
if not (hasattr(actual_client, "messages") and hasattr(actual_client.messages, "create")):
raise ValueError(
"AnthropicAdapter requires an Anthropic client instance with messages.create, got "
f"{type(self.client)}"
)

def _check_if_async_client(self) -> bool:
actual_client = getattr(self.client, "client", self.client)

if hasattr(actual_client, "__module__") and actual_client.__module__:
if "anthropic" in actual_client.__module__:
class_name = actual_client.__class__.__name__
return "Async" in class_name

create_method = actual_client.messages.create
import inspect

return inspect.iscoroutinefunction(create_method)

def generate_text(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str:
    """Generate a text completion synchronously.

    Args:
        prompt: Plain string or MultimodalPrompt (only text parts are used).
        **kwargs: Forwarded to ``messages.create``; may include
            ``max_tokens`` to override the 4096 default.

    Returns:
        The text of the first content block in the response.

    Raises:
        ValueError: If called on an async client, or if the response
            contains no text content.
    """
    if self._is_async:
        raise ValueError("Cannot call sync method generate_text() on async Anthropic client.")
    messages = self._build_messages(prompt)
    # Let callers override max_tokens; previously it was hard-coded and
    # passing max_tokens in kwargs raised a duplicate-argument TypeError.
    kwargs.setdefault("max_tokens", 4096)

    try:
        response = self.client.messages.create(
            model=self.model_name, messages=messages, **kwargs
        )
        # Guard against an empty content list before indexing into it.
        if response.content and hasattr(response.content[0], "text"):
            return cast(str, response.content[0].text)
        # Include the payload so users can see what was actually returned.
        raise ValueError(
            f"Anthropic returned unexpected content format: {response.content!r}"
        )
    except Exception as e:
        logger.error(f"Anthropic completion failed: {e}")
        raise

async def async_generate_text(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str:
    """Generate a text completion asynchronously.

    Args:
        prompt: Plain string or MultimodalPrompt (only text parts are used).
        **kwargs: Forwarded to ``messages.create``; may include
            ``max_tokens`` to override the 4096 default.

    Returns:
        The text of the first content block in the response.

    Raises:
        ValueError: If called on a sync client, or if the response
            contains no text content.
    """
    if not self._is_async:
        raise ValueError(
            "Cannot call async method async_generate_text() on sync Anthropic client."
        )
    messages = self._build_messages(prompt)
    # Overridable max_tokens, mirroring the sync path.
    kwargs.setdefault("max_tokens", 4096)

    try:
        response = await self.client.messages.create(
            model=self.model_name, messages=messages, **kwargs
        )
        # Guard against an empty content list before indexing into it
        # (same IndexError risk as the sync path).
        if response.content and hasattr(response.content[0], "text"):
            return cast(str, response.content[0].text)
        raise ValueError(
            f"Anthropic returned unexpected content format: {response.content!r}"
        )
    except Exception as e:
        logger.error(f"Anthropic async completion failed: {e}")
        raise

def generate_object(
    self,
    prompt: Union[str, MultimodalPrompt],
    schema: Dict[str, Any],
    method: ObjectGenerationMethod = ObjectGenerationMethod.AUTO,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Generate a JSON object conforming to *schema* (sync).

    Anthropic has no native structured output, so both TOOL_CALLING and
    AUTO are served via forced tool calling.

    Raises:
        ValueError: On an async client, an unsupported method, or a model
            without tool-call support.
    """
    if self._is_async:
        raise ValueError(
            "Cannot call sync method generate_object() on async Anthropic client. "
            "Use async_generate_object() instead or provide a sync Anthropic client."
        )
    self._validate_schema(schema)

    tool_calls_supported = self._supports_tool_calls()

    if method == ObjectGenerationMethod.STRUCTURED_OUTPUT:
        raise ValueError(
            "Anthropic does not support native structured output. Use TOOL_CALLING or AUTO."
        )

    if method == ObjectGenerationMethod.TOOL_CALLING:
        if not tool_calls_supported:
            raise ValueError(f"Anthropic model {self.model_name} does not support tool calls")
        return self._generate_with_tool_calling(prompt, schema, **kwargs)

    if method == ObjectGenerationMethod.AUTO:
        if not tool_calls_supported:
            raise ValueError(
                f"Anthropic model {self.model_name} does not support tool calls "
                "or structured output"
            )
        return self._generate_with_tool_calling(prompt, schema, **kwargs)

    raise ValueError(f"Unsupported object generation method: {method}")

async def async_generate_object(
    self,
    prompt: Union[str, MultimodalPrompt],
    schema: Dict[str, Any],
    method: ObjectGenerationMethod = ObjectGenerationMethod.AUTO,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Generate a JSON object conforming to *schema* (async).

    Anthropic has no native structured output, so both TOOL_CALLING and
    AUTO are served via forced tool calling.

    Raises:
        ValueError: On a sync client, an unsupported method, or a model
            without tool-call support.
    """
    if not self._is_async:
        raise ValueError(
            "Cannot call async method async_generate_object() on sync Anthropic client."
        )
    self._validate_schema(schema)

    tool_calls_supported = self._supports_tool_calls()

    if method == ObjectGenerationMethod.STRUCTURED_OUTPUT:
        raise ValueError(
            "Anthropic does not support native structured output. Use TOOL_CALLING or AUTO."
        )

    if method == ObjectGenerationMethod.TOOL_CALLING:
        if not tool_calls_supported:
            raise ValueError(f"Anthropic model {self.model_name} does not support tool calls")
        return await self._async_generate_with_tool_calling(prompt, schema, **kwargs)

    if method == ObjectGenerationMethod.AUTO:
        if not tool_calls_supported:
            raise ValueError(
                f"Anthropic model {self.model_name} does not support tool calls "
                "or structured output"
            )
        return await self._async_generate_with_tool_calling(prompt, schema, **kwargs)

    raise ValueError(f"Unsupported object generation method: {method}")

def _generate_with_tool_calling(
self,
prompt: Union[str, MultimodalPrompt],
schema: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, Any]:
messages = self._build_messages(prompt)
tool_definition = self._schema_to_tool(schema)

response = self.client.messages.create(
model=self.model_name,
messages=messages,
tools=[tool_definition],
tool_choice={"type": "tool", "name": "extract_structured_data"},
max_tokens=4096,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't max_tokens be one of the configurable kwargs?

**kwargs,
)

for content_block in response.content:
if hasattr(content_block, "type") and content_block.type == "tool_use":
return cast(Dict[str, Any], content_block.input)

raise ValueError("No tool use in response")

async def _async_generate_with_tool_calling(
self,
prompt: Union[str, MultimodalPrompt],
schema: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, Any]:
messages = self._build_messages(prompt)
tool_definition = self._schema_to_tool(schema)

response = await self.client.messages.create(
model=self.model_name,
messages=messages,
tools=[tool_definition],
tool_choice={"type": "tool", "name": "extract_structured_data"},
max_tokens=4096,
**kwargs,
)

for content_block in response.content:
if hasattr(content_block, "type") and content_block.type == "tool_use":
return cast(Dict[str, Any], content_block.input)

raise ValueError("No tool use in response")

@property
def model_name(self) -> str:
    """Model identifier to request from the API.

    AnthropicClientWrapper carries a ``model`` attribute; raw SDK clients
    do not, in which case a default Claude model is used.
    """
    if hasattr(self.client, "model"):
        return str(self.client.model)
    return "claude-3-5-sonnet-20241022"

def _supports_tool_calls(self) -> bool:
model_name = self.model_name.lower()
return "claude" in model_name

def _schema_to_tool(self, schema: Dict[str, Any]) -> Dict[str, Any]:
description = schema.get(
"description", "Extract structured data according to the provided schema"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit: the task is more like "Respond in a format matching the provided schema" than it is about extracting structured data. This may not impact things much, but LLMs are sensitive to wording so 🤷🏼‍♀️

)

tool_definition = {
"name": "extract_structured_data",
"description": description,
"input_schema": schema,
}

return tool_definition

def _build_messages(self, prompt: Union[str, MultimodalPrompt]) -> list[dict[str, Any]]:
if isinstance(prompt, str):
return [{"role": "user", "content": prompt}]

text_parts = []
for part in prompt.parts:
if part.content_type == PromptPartContentType.TEXT:
text_parts.append(part.content)

combined_text = "\n".join(text_parts)
return [{"role": "user", "content": combined_text}]

def _validate_schema(self, schema: Dict[str, Any]) -> None:
if not schema:
raise ValueError(f"Schema must be a non-empty dictionary, got {type(schema)}")

properties = schema.get("properties", {})
required = schema.get("required", [])

if properties and required:
property_names = set(properties.keys())
required_names = set(required)

missing_properties = required_names - property_names
if missing_properties:
raise ValueError(
f"Schema validation error: Required fields {list(missing_properties)} "
f"are not defined in properties. "
f"Properties: {list(property_names)}, Required: {list(required_names)}"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Any, Union


class AnthropicClientWrapper:
    """Attach a model name to an Anthropic SDK client.

    Attribute lookups that miss on the wrapper itself (everything except
    ``client`` and ``model``) are delegated to the wrapped client.
    """

    def __init__(self, client: Any, model: str):
        self.client = client
        self.model = model

    def __getattr__(self, name: str) -> Any:
        # __getattr__ fires only when normal lookup fails, so delegation
        # cannot shadow the wrapper's own attributes.
        return getattr(self.client, name)


def create_anthropic_client(model: str, is_async: bool = False, **kwargs: Any) -> Any:
    """Create an Anthropic client (sync or async) wrapped with *model*.

    Args:
        model: Model name carried on the returned wrapper.
        is_async: When True, build AsyncAnthropic instead of Anthropic.
        **kwargs: Forwarded to the SDK client constructor.

    Raises:
        ImportError: If the anthropic package is not installed.
    """
    # Keep the try narrow: only a missing package should be reported as
    # "not installed" — an ImportError raised while constructing the client
    # previously would have been misreported. Chain the cause for debugging.
    try:
        from anthropic import Anthropic, AsyncAnthropic
    except ImportError as e:
        raise ImportError("Anthropic package not installed. Run: pip install anthropic") from e

    # NOTE(review): max_retries=0 — presumably retries are handled by the
    # caller's rate-limit machinery (see get_rate_limit_errors); confirm.
    if is_async:
        client: Union[Anthropic, AsyncAnthropic] = AsyncAnthropic(max_retries=0, **kwargs)
    else:
        client = Anthropic(max_retries=0, **kwargs)

    return AnthropicClientWrapper(client, model)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Google GenAI adapter package: re-exports the public GoogleGenAIAdapter."""

from .adapter import GoogleGenAIAdapter

__all__ = ["GoogleGenAIAdapter"]
Loading
Loading