Skip to content

Commit

Permalink
mistralai[patch]: standardize model params (#20163)
Browse files Browse the repository at this point in the history
Related to #20085
  • Loading branch information
baskaryan authored Apr 8, 2024
1 parent 1718240 commit 3490d70
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 12 deletions.
6 changes: 3 additions & 3 deletions docs/docs/integrations/chat/mistralai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"source": [
"import getpass\n",
"\n",
"mistral_api_key = getpass.getpass()"
"api_key = getpass.getpass()"
]
},
{
Expand Down Expand Up @@ -81,8 +81,8 @@
},
"outputs": [],
"source": [
"# If mistral_api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n",
"chat = ChatMistralAI(mistral_api_key=mistral_api_key)"
"# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n",
"chat = ChatMistralAI(api_key=api_key)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/integrations/text_embedding/mistralai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"metadata": {},
"outputs": [],
"source": [
"embedding = MistralAIEmbeddings(mistral_api_key=\"your-api-key\")"
"embedding = MistralAIEmbeddings(api_key=\"your-api-key\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/modules/model_io/chat/quick_start.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
"<ChatModelTabs\n",
" openaiParams={`model=\"gpt-3.5-turbo-0125\", api_key=\"...\"`}\n",
" anthropicParams={`model=\"claude-3-sonnet-20240229\", anthropic_api_key=\"...\"`}\n",
" mistralParams={`model=\"mistral-large-latest\", api_key=\"...\"`}\n",
" fireworksParams={`model=\"accounts/fireworks/models/mixtral-8x7b-instruct\", api_key=\"...\"`}\n",
" mistralParams={`model=\"mistral-large-latest\", mistral_api_key=\"...\"`}\n",
" googleParams={`model=\"gemini-pro\", google_api_key=\"...\"`}\n",
" togetherParams={`, together_api_key=\"...\"`}\n",
" customVarName=\"chat\"\n",
Expand Down
8 changes: 7 additions & 1 deletion libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class ChatMistralAI(BaseChatModel):

client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = None
mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
endpoint: str = "https://api.mistral.ai/v1"
max_retries: int = 5
timeout: int = 120
Expand All @@ -202,6 +202,12 @@ class ChatMistralAI(BaseChatModel):
safe_mode: bool = False
streaming: bool = False

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True
arbitrary_types_allowed = True

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling the API."""
Expand Down
6 changes: 4 additions & 2 deletions libs/partners/mistralai/langchain_mistralai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
.. code-block:: python
from langchain_mistralai import MistralAIEmbeddings
mistral = MistralAIEmbeddings(
model="mistral-embed",
mistral_api_key="my-api-key"
api_key="my-api-key"
)
"""

client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
mistral_api_key: Optional[SecretStr] = None
mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
endpoint: str = "https://api.mistral.ai/v1/"
max_retries: int = 5
timeout: int = 120
Expand All @@ -49,6 +50,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
class Config:
extra = Extra.forbid
arbitrary_types_allowed = True
allow_population_by_field_name = True

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
Expand Down
9 changes: 7 additions & 2 deletions libs/partners/mistralai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test MistralAI Chat API wrapper."""

import os
from typing import Any, AsyncGenerator, Dict, Generator
from typing import Any, AsyncGenerator, Dict, Generator, cast
from unittest.mock import patch

import pytest
Expand All @@ -13,6 +13,7 @@
HumanMessage,
SystemMessage,
)
from langchain_core.pydantic_v1 import SecretStr

from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI,
Expand All @@ -31,7 +32,11 @@ def test_mistralai_initialization() -> None:
"""Test ChatMistralAI initialization."""
# Verify that ChatMistralAI can be initialized using a secret key provided
# as a parameter rather than an environment variable.
ChatMistralAI(model="test", mistral_api_key="test")
for model in [
ChatMistralAI(model="test", mistral_api_key="test"),
ChatMistralAI(model="test", api_key="test"),
]:
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"


@pytest.mark.parametrize(
Expand Down
11 changes: 9 additions & 2 deletions libs/partners/mistralai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import os
from typing import cast

from langchain_core.pydantic_v1 import SecretStr

from langchain_mistralai import MistralAIEmbeddings

os.environ["MISTRAL_API_KEY"] = "foo"


def test_mistral_init() -> None:
embeddings = MistralAIEmbeddings()
assert embeddings.model == "mistral-embed"
for model in [
MistralAIEmbeddings(model="mistral-embed", mistral_api_key="test"),
MistralAIEmbeddings(model="mistral-embed", api_key="test"),
]:
assert model.model == "mistral-embed"
assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"

0 comments on commit 3490d70

Please sign in to comment.