Skip to content

Add Mistral provider #1118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ pip/uv-add 'pydantic-ai-slim[mistral]'

To use [Mistral](https://mistral.ai) through their API, go to [console.mistral.ai/api-keys/](https://console.mistral.ai/api-keys/) and follow your nose until you find the place to generate an API key.

[`MistralModelName`][pydantic_ai.models.mistral.MistralModelName] contains a list of the most popular Mistral models.
[`LatestMistralModelNames`][pydantic_ai.models.mistral.LatestMistralModelNames] contains a list of the most popular Mistral models.

### Environment variable

Expand Down Expand Up @@ -537,15 +537,37 @@ agent = Agent(model)
...
```

### `api_key` argument
### `provider` argument

You can provide a custom [`Provider`][pydantic_ai.providers.Provider] via the
[`provider` argument][pydantic_ai.models.mistral.MistralModel.__init__]:

```python {title="groq_model_provider.py"}
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel
from pydantic_ai.providers.mistral import MistralProvider

model = MistralModel(
'mistral-large-latest', provider=MistralProvider(api_key='your-api-key')
)
agent = Agent(model)
...
```

If you don't want to or can't set the environment variable, you can pass it at runtime via the [`api_key` argument][pydantic_ai.models.mistral.MistralModel.__init__]:
You can also customize the provider with a custom `httpx.AsyncHTTPClient`:

```python {title="groq_model_custom_provider.py"}
from httpx import AsyncClient

```python {title="mistral_model_api_key.py"}
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel
from pydantic_ai.providers.mistral import MistralProvider

model = MistralModel('mistral-small-latest', api_key='your-api-key')
custom_http_client = AsyncClient(timeout=30)
model = MistralModel(
'mistral-large-latest',
provider=MistralProvider(api_key='your-api-key', http_client=custom_http_client),
)
agent = Agent(model)
...
```
Expand Down
8 changes: 3 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,9 @@ def __init__(

if provider is not None:
if isinstance(provider, str):
self._system = provider
self.client = infer_provider(provider).client
else:
self._system = provider.name
self.client = provider.client
provider = infer_provider(provider)
self._system = provider.name
self.client = provider.client
self._url = str(self.client.base_url)
else:
if api_key is None:
Expand Down
5 changes: 2 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,8 @@ def __init__(

if provider is not None:
if isinstance(provider, str):
self.client = infer_provider(provider).client
else:
self.client = provider.client
provider = infer_provider(provider)
self.client = provider.client
elif groq_client is not None:
assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
Expand Down
42 changes: 37 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Any, Callable, Literal, Union, cast
from typing import Any, Callable, Literal, Union, cast, overload

import pydantic_core
from httpx import AsyncClient as AsyncHTTPClient, Timeout
from typing_extensions import assert_never
from typing_extensions import assert_never, deprecated

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
from .._utils import now_utc as _now_utc
Expand All @@ -31,6 +31,7 @@
ToolReturnPart,
UserPromptPart,
)
from ..providers import Provider, infer_provider
from ..result import Usage
from ..settings import ModelSettings
from ..tools import ToolDefinition
Expand Down Expand Up @@ -112,10 +113,33 @@ class MistralModel(Model):
_model_name: MistralModelName = field(repr=False)
_system: str = field(default='mistral_ai', repr=False)

@overload
def __init__(
self,
model_name: MistralModelName,
*,
provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
) -> None: ...

@overload
@deprecated('Use the `provider` parameter instead of `api_key`, `client` and `http_client`.')
def __init__(
self,
model_name: MistralModelName,
*,
provider: None = None,
api_key: str | Callable[[], str | None] | None = None,
client: Mistral | None = None,
http_client: AsyncHTTPClient | None = None,
json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
) -> None: ...

def __init__(
self,
model_name: MistralModelName,
*,
provider: Literal['mistral'] | Provider[Mistral] | None = None,
api_key: str | Callable[[], str | None] | None = None,
client: Mistral | None = None,
http_client: AsyncHTTPClient | None = None,
Expand All @@ -124,6 +148,9 @@ def __init__(
"""Initialize a Mistral model.

Args:
provider: The provider to use for authentication and API access. Can be either the string
'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be
created using the other parameters.
model_name: The name of the model to use.
api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
Expand All @@ -133,17 +160,22 @@ def __init__(
self._model_name = model_name
self.json_mode_schema_prompt = json_mode_schema_prompt

if client is not None:
if provider is not None:
if isinstance(provider, str):
# TODO(Marcelo): We should add an integration test with VCR when I get the API key.
provider = infer_provider(provider) # pragma: no cover
self.client = provider.client
elif client is not None:
assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
self.client = client
else:
api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
api_key = api_key or os.getenv('MISTRAL_API_KEY')
self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())

@property
def base_url(self) -> str:
return str(self.client.sdk_configuration.get_server_details()[0])
return self.client.sdk_configuration.get_server_details()[0]

async def request(
self,
Expand Down
5 changes: 2 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,8 @@ def __init__(

if provider is not None:
if isinstance(provider, str):
self.client = infer_provider(provider).client
else:
self.client = provider.client
provider = infer_provider(provider)
self.client = provider.client
else: # pragma: no cover
# This is a workaround for the OpenAI client requiring an API key, whilst locally served,
# openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
Expand Down
6 changes: 5 additions & 1 deletion pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,13 @@ def infer_provider(provider: str) -> Provider[Any]:
from .groq import GroqProvider

return GroqProvider()
elif provider == 'anthropic': # pragma: no cover
elif provider == 'anthropic':
from .anthropic import AnthropicProvider

return AnthropicProvider()
elif provider == 'mistral':
from .mistral import MistralProvider

return MistralProvider()
else: # pragma: no cover
raise ValueError(f'Unknown provider: {provider}')
73 changes: 73 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations as _annotations

import os
from typing import overload

from httpx import AsyncClient as AsyncHTTPClient

from pydantic_ai.models import cached_async_http_client

try:
from mistralai import Mistral
except ImportError as e: # pragma: no cover
raise ImportError(
'Please install the `mistral` package to use the Mistral provider, '
"you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
) from e


from . import Provider


class MistralProvider(Provider[Mistral]):
"""Provider for Mistral API."""

@property
def name(self) -> str:
return 'mistral'

@property
def base_url(self) -> str:
return self.client.sdk_configuration.get_server_details()[0]

@property
def client(self) -> Mistral:
return self._client

@overload
def __init__(self, *, mistral_client: Mistral | None = None) -> None: ...

@overload
def __init__(self, *, api_key: str | None = None, http_client: AsyncHTTPClient | None = None) -> None: ...

def __init__(
self,
*,
api_key: str | None = None,
mistral_client: Mistral | None = None,
http_client: AsyncHTTPClient | None = None,
) -> None:
"""Create a new Mistral provider.

Args:
api_key: The API key to use for authentication, if not provided, the `MISTRAL_API_KEY` environment variable
will be used if available.
mistral_client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
http_client: An existing async client to use for making HTTP requests.
"""
api_key = api_key or os.environ.get('MISTRAL_API_KEY')

if api_key is None and mistral_client is None:
raise ValueError(
'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`'
'to use the Mistral provider.'
)

if mistral_client is not None:
assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
self._client = mistral_client
elif http_client is not None:
self._client = Mistral(api_key=api_key, async_client=http_client)
else:
self._client = Mistral(api_key=api_key, async_client=cached_async_http_client())
Loading