Skip to content

Add Anthropic provider classes #1120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,38 @@ agent = Agent(model)
...
```

### `api_key` argument
### `provider` argument

You can provide a custom [`Provider`][pydantic_ai.providers.Provider] via the [`provider` argument][pydantic_ai.models.anthropic.AnthropicModel.__init__]:

```py title="anthropic_model_provider.py"
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.anthropic import AnthropicProvider

model = AnthropicModel(
'claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key='your-api-key')
)
agent = Agent(model)
...
```

### Custom HTTP Client

If you don't want to or can't set the environment variable, you can pass it at runtime via the [`api_key` argument][pydantic_ai.models.anthropic.AnthropicModel.__init__]:
You can customize the `AnthropicProvider` with a custom `httpx.AsyncClient`:

```py title="anthropic_model_custom_provider.py"
from httpx import AsyncClient

```py title="anthropic_model_api_key.py"
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.anthropic import AnthropicProvider

model = AnthropicModel('claude-3-5-sonnet-latest', api_key='your-api-key')
custom_http_client = AsyncClient(timeout=30)
model = AnthropicModel(
'claude-3-5-sonnet-latest',
provider=AnthropicProvider(api_key='your-api-key', http_client=custom_http_client),
)
agent = Agent(model)
...
```
Expand Down
33 changes: 31 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from anthropic.types import DocumentBlockParam
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never
from typing_extensions import assert_never, deprecated

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
Expand All @@ -31,6 +31,7 @@
ToolReturnPart,
UserPromptPart,
)
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests
Expand Down Expand Up @@ -111,10 +112,31 @@ class AnthropicModel(Model):
_model_name: AnthropicModelName = field(repr=False)
_system: str = field(default='anthropic', repr=False)

@overload
def __init__(
self,
model_name: AnthropicModelName,
*,
provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
) -> None: ...

@deprecated('Use the `provider` parameter instead of `api_key`, `anthropic_client`, and `http_client`.')
@overload
def __init__(
self,
model_name: AnthropicModelName,
*,
provider: None = None,
api_key: str | None = None,
anthropic_client: AsyncAnthropic | None = None,
http_client: AsyncHTTPClient | None = None,
) -> None: ...

def __init__(
self,
model_name: AnthropicModelName,
*,
provider: Literal['anthropic'] | Provider[AsyncAnthropic] | None = None,
api_key: str | None = None,
anthropic_client: AsyncAnthropic | None = None,
http_client: AsyncHTTPClient | None = None,
Expand All @@ -124,6 +146,8 @@ def __init__(
Args:
model_name: The name of the Anthropic model to use. List of model names available
[here](https://docs.anthropic.com/en/docs/about-claude/models).
provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
will be used if available.
anthropic_client: An existing
Expand All @@ -132,7 +156,12 @@ def __init__(
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
"""
self._model_name = model_name
if anthropic_client is not None:

if provider is not None:
if isinstance(provider, str):
provider = infer_provider(provider)
self.client = provider.client
elif anthropic_client is not None:
assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
self.client = anthropic_client
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,9 @@ def infer_provider(provider: str) -> Provider[Any]:
from .groq import GroqProvider

return GroqProvider()
elif provider == 'anthropic': # pragma: no cover
from .anthropic import AnthropicProvider

return AnthropicProvider()
else: # pragma: no cover
raise ValueError(f'Unknown provider: {provider}')
74 changes: 74 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations as _annotations

import os
from typing import overload

import httpx

from pydantic_ai.models import cached_async_http_client

try:
from anthropic import AsyncAnthropic
except ImportError as _import_error: # pragma: no cover
raise ImportError(
'Please install the `anthropic` package to use the Anthropic provider, '
"you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
) from _import_error


from . import Provider


class AnthropicProvider(Provider[AsyncAnthropic]):
"""Provider for Anthropic API."""

@property
def name(self) -> str:
return 'anthropic'

@property
def base_url(self) -> str:
return str(self._client.base_url)

@property
def client(self) -> AsyncAnthropic:
return self._client

@overload
def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...

@overload
def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...

def __init__(
self,
*,
api_key: str | None = None,
anthropic_client: AsyncAnthropic | None = None,
http_client: httpx.AsyncClient | None = None,
) -> None:
"""Create a new Anthropic provider.

Args:
api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
will be used if available.
anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python)
client to use. If provided, the `api_key` and `http_client` arguments will be ignored.
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
"""
if anthropic_client is not None:
assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
self._client = anthropic_client
else:
api_key = api_key or os.environ.get('ANTHROPIC_API_KEY')
if api_key is None:
raise ValueError(
'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
'to use the Anthropic provider.'
)

if http_client is not None:
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
else:
self._client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
49 changes: 33 additions & 16 deletions tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import timezone
from functools import cached_property
from typing import Any, TypeVar, Union, cast
from unittest.mock import patch

import httpx
import pytest
Expand Down Expand Up @@ -53,6 +54,7 @@
from anthropic.types.raw_message_delta_event import Delta

from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
from pydantic_ai.providers.anthropic import AnthropicProvider

# note: we use Union here so that casting works with Python 3.9
MockAnthropicMessage = Union[AnthropicMessage, Exception]
Expand All @@ -68,7 +70,7 @@


def test_init():
m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key='foobar'))
assert m.client.api_key == 'foobar'
assert m.model_name == 'claude-3-5-haiku-latest'
assert m.system == 'anthropic'
Expand All @@ -81,6 +83,7 @@ class MockAnthropic:
stream: Sequence[MockRawMessageStreamEvent] | Sequence[Sequence[MockRawMessageStreamEvent]] | None = None
index = 0
chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)
base_url: str | None = None

@cached_property
def messages(self) -> Any:
Expand Down Expand Up @@ -134,7 +137,7 @@ def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> An
async def test_sync_request_text_response(allow_model_requests: None):
c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
mock_client = MockAnthropic.create_mock(c)
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

result = await agent.run('hello')
Expand Down Expand Up @@ -171,7 +174,7 @@ async def test_async_request_text_response(allow_model_requests: None):
usage=AnthropicUsage(input_tokens=3, output_tokens=5),
)
mock_client = MockAnthropic.create_mock(c)
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

result = await agent.run('hello')
Expand All @@ -185,7 +188,7 @@ async def test_request_structured_response(allow_model_requests: None):
usage=AnthropicUsage(input_tokens=3, output_tokens=5),
)
mock_client = MockAnthropic.create_mock(c)
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m, result_type=list[int])

result = await agent.run('hello')
Expand Down Expand Up @@ -235,7 +238,7 @@ async def test_request_tool_call(allow_model_requests: None):
]

mock_client = MockAnthropic.create_mock(responses)
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m, system_prompt='this is the system prompt')

@agent.tool_plain
Expand Down Expand Up @@ -327,7 +330,7 @@ async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_cal
]

mock_client = MockAnthropic.create_mock(responses)
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m, model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))

@agent.tool_plain
Expand Down Expand Up @@ -366,7 +369,7 @@ async def retrieve_entity_info(name: str) -> str:
# However, we do want to use the environment variable if present when rewriting VCR cassettes.
api_key = os.environ.get('ANTHROPIC_API_KEY', 'mock-value')
agent = Agent(
AnthropicModel('claude-3-5-haiku-latest', api_key=api_key),
AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key=api_key)),
system_prompt=system_prompt,
tools=[retrieve_entity_info],
)
Expand Down Expand Up @@ -436,7 +439,7 @@ async def retrieve_entity_info(name: str) -> str:
async def test_anthropic_specific_metadata(allow_model_requests: None) -> None:
c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
mock_client = MockAnthropic.create_mock(c)
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

result = await agent.run('hello', model_settings=AnthropicModelSettings(anthropic_metadata={'user_id': '123'}))
Expand Down Expand Up @@ -525,7 +528,7 @@ async def test_stream_structured(allow_model_requests: None):
]

mock_client = MockAnthropic.create_stream_mock([stream, done_stream])
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

tool_called = False
Expand Down Expand Up @@ -555,7 +558,7 @@ async def my_tool(first: str, second: str) -> int:

@pytest.mark.vcr()
async def test_image_url_input(allow_model_requests: None, anthropic_api_key: str):
m = AnthropicModel('claude-3-5-haiku-latest', api_key=anthropic_api_key)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
agent = Agent(m)

result = await agent.run(
Expand All @@ -573,7 +576,7 @@ async def test_image_url_input(allow_model_requests: None, anthropic_api_key: st

@pytest.mark.vcr()
async def test_image_url_input_invalid_mime_type(allow_model_requests: None, anthropic_api_key: str):
m = AnthropicModel('claude-3-5-haiku-latest', api_key=anthropic_api_key)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
agent = Agent(m)

result = await agent.run(
Expand All @@ -593,7 +596,7 @@ async def test_image_url_input_invalid_mime_type(allow_model_requests: None, ant
async def test_audio_as_binary_content_input(allow_model_requests: None, media_type: str):
c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
mock_client = MockAnthropic.create_mock(c)
m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
m = AnthropicModel('claude-3-5-haiku-latest', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)

base64_content = b'//uQZ'
Expand All @@ -610,7 +613,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
body={'error': 'test error'},
)
)
m = AnthropicModel('claude-3-5-sonnet-latest', anthropic_client=mock_client)
m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(anthropic_client=mock_client))
agent = Agent(m)
with pytest.raises(ModelHTTPError) as exc_info:
agent.run_sync('hello')
Expand All @@ -623,7 +626,7 @@ def test_model_status_error(allow_model_requests: None) -> None:
async def test_document_binary_content_input(
allow_model_requests: None, anthropic_api_key: str, document_content: BinaryContent
):
m = AnthropicModel('claude-3-5-sonnet-latest', api_key=anthropic_api_key)
m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
agent = Agent(m)

result = await agent.run(['What is the main content on this document?', document_content])
Expand All @@ -634,7 +637,7 @@ async def test_document_binary_content_input(

@pytest.mark.vcr()
async def test_document_url_input(allow_model_requests: None, anthropic_api_key: str):
m = AnthropicModel('claude-3-5-sonnet-latest', api_key=anthropic_api_key)
m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
agent = Agent(m)

document_url = DocumentUrl(url='https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf')
Expand All @@ -647,7 +650,7 @@ async def test_document_url_input(allow_model_requests: None, anthropic_api_key:

@pytest.mark.vcr()
async def test_text_document_url_input(allow_model_requests: None, anthropic_api_key: str):
m = AnthropicModel('claude-3-5-sonnet-latest', api_key=anthropic_api_key)
m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key))
agent = Agent(m)

text_document_url = DocumentUrl(url='https://example-files.online-convert.com/document/txt/example.txt')
Expand All @@ -668,3 +671,17 @@ async def test_text_document_url_input(allow_model_requests: None, anthropic_api

The document is formatted as a test file with metadata including its purpose, file type, and version. It also includes attribution information indicating the content is from Wikipedia and is licensed under Attribution-ShareAlike 4.0.\
""")


def test_init_with_provider():
provider = AnthropicProvider(api_key='api-key')
model = AnthropicModel('claude-3-opus-latest', provider=provider)
assert model.model_name == 'claude-3-opus-latest'
assert model.client == provider.client


def test_init_with_provider_string():
with patch.dict(os.environ, {'ANTHROPIC_API_KEY': 'env-api-key'}, clear=False):
model = AnthropicModel('claude-3-opus-latest', provider='anthropic')
assert model.model_name == 'claude-3-opus-latest'
assert model.client is not None
Loading