Commit 0805ebf

fix: Merge pull request #2 from AI21Labs/ai21_prepare_pr_for_merge
fix: AI21 prepare pr for merge
2 parents 3498321 + 71da71e commit 0805ebf

File tree

10 files changed: +100 -86 lines changed

libs/partners/ai21/langchain_ai21/ai21_base.py

Lines changed: 13 additions & 10 deletions

@@ -1,17 +1,18 @@
 import os
-from typing import Optional, Dict
+from typing import Dict, Optional

 from ai21 import AI21Client
-
-from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, SecretStr
+from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str

+_DEFAULT_TIMEOUT_SEC = 300
+

 class AI21Base(BaseModel):
     class Config:
         arbitrary_types_allowed = True

-    client: Optional[AI21Client] = Field(default=None)
+    client: AI21Client = Field(default=None)
     api_key: Optional[SecretStr] = None
     api_host: Optional[str] = None
     timeout_sec: Optional[float] = None

@@ -25,20 +26,22 @@ def validate_environment(cls, values: Dict) -> Dict:
         values["api_key"] = api_key

         api_host = (
-            values.get("api_host")
-            or os.getenv("AI21_API_URL")
-            or "https://api.ai21.com"
+            values.get("api_host")
+            or os.getenv("AI21_API_URL")
+            or "https://api.ai21.com"
         )
         values["api_host"] = api_host

-        timeout_sec = values.get("timeout_sec") or os.getenv("AI21_TIMEOUT_SEC")
+        timeout_sec = values.get("timeout_sec") or float(
+            os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
+        )
         values["timeout_sec"] = timeout_sec

-        if values.get('client') is None:
+        if values.get("client") is None:
             values["client"] = AI21Client(
                 api_key=api_key.get_secret_value(),
                 api_host=api_host,
-                timeout_sec=timeout_sec,
+                timeout_sec=None if timeout_sec is None else float(timeout_sec),
                 via="langchain",
             )
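The timeout change is the substantive part of this file: AI21_TIMEOUT_SEC now falls back to a 300-second default and is coerced to float before being handed to AI21Client. A minimal standalone sketch of the resolution order (the resolve_timeout helper and its explicit_timeout parameter are illustrative stand-ins, not part of the package):

import os

_DEFAULT_TIMEOUT_SEC = 300


def resolve_timeout(explicit_timeout=None):
    # An explicit value wins; otherwise read AI21_TIMEOUT_SEC, else the default.
    timeout_sec = explicit_timeout or float(
        os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
    )
    # Defensive float coercion, mirroring the client-construction change above.
    return None if timeout_sec is None else float(timeout_sec)


assert resolve_timeout() == 300.0   # assuming AI21_TIMEOUT_SEC is unset
assert resolve_timeout(10) == 10.0  # explicit value wins

One quirk worth noting: because the fallback chains with `or`, an explicit timeout of 0 would fall through to the environment variable or default.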

libs/partners/ai21/langchain_ai21/chat_models.py

Lines changed: 22 additions & 7 deletions

@@ -1,10 +1,8 @@
 import asyncio
 from functools import partial
-from typing import Any, AsyncIterator, Iterator, List, Optional, cast, Tuple
+from typing import Any, List, Optional, Tuple, cast

-from ai21.models import ChatMessage, RoleType, Penalty
-
-from langchain_ai21.ai21_base import AI21Base
+from ai21.models import ChatMessage, Penalty, RoleType
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,

@@ -16,7 +14,18 @@
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatResult
+
+from langchain_ai21.ai21_base import AI21Base
+
+
+def _get_system_message_from_message(message: BaseMessage) -> str:
+    if not isinstance(message.content, str):
+        raise ValueError(
+            f"System Message must be of type str. Got {type(message.content)}"
+        )
+
+    return message.content


 def _convert_messages_to_ai21_messages(

@@ -30,7 +39,7 @@ def _convert_messages_to_ai21_messages(
             if i != 0:
                 raise ValueError("System message must be at beginning of message list.")
             else:
-                system_message = message.content
+                system_message = _get_system_message_from_message(message)
         else:
             converted_message = _convert_message_to_ai21_message(message)
             converted_messages.append(converted_message)

@@ -105,7 +114,13 @@ class ChatAI21(BaseChatModel, AI21Base):
     """ A penalty applied to tokens that are already present in the prompt."""

     count_penalty: Optional[Penalty] = None
-    """A penalty applied to tokens based on their frequency in the generated responses."""
+    """A penalty applied to tokens based on their frequency
+    in the generated responses."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True

     @property
     def _llm_type(self) -> str:
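The new _get_system_message_from_message helper exists because BaseMessage.content in langchain-core is typed as either a string or a list of content blocks, and only the string form can be forwarded as an AI21 system prompt. A short sketch of the guard's behaviour (calling the private helper directly, purely for illustration):

from langchain_core.messages import SystemMessage

from langchain_ai21.chat_models import _get_system_message_from_message

# Plain string content passes through unchanged.
msg = SystemMessage(content="You are a helpful assistant.")
assert _get_system_message_from_message(msg) == "You are a helpful assistant."

# List-style (block) content is rejected.
msg = SystemMessage(content=[{"type": "text", "text": "hi"}])
try:
    _get_system_message_from_message(msg)
except ValueError as err:
    print(err)  # System Message must be of type str. Got <class 'list'>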

libs/partners/ai21/langchain_ai21/embeddings.py

Lines changed: 3 additions & 2 deletions

@@ -1,9 +1,9 @@
-from typing import List, Any
+from typing import Any, List

 from ai21.models import EmbedType
+from langchain_core.embeddings import Embeddings

 from langchain_ai21.ai21_base import AI21Base
-from langchain_core.embeddings import Embeddings


 class AI21Embeddings(Embeddings, AI21Base):

@@ -19,6 +19,7 @@ class AI21Embeddings(Embeddings, AI21Base):
         embeddings = AI21Embeddings()
         query_result = embeddings.embed_query("Hello embeddings world!")
     """
+
     def embed_documents(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
         """Embed search docs."""
         response = self.client.embed.create(

libs/partners/ai21/langchain_ai21/llms.py

Lines changed: 11 additions & 25 deletions

@@ -2,21 +2,19 @@
 from functools import partial
 from typing import (
     Any,
-    AsyncIterator,
-    Iterator,
     List,
     Optional,
 )

 from ai21.models import CompletionsResponse, Penalty
-
-from langchain_ai21.ai21_base import AI21Base
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import BaseLLM
-from langchain_core.outputs import GenerationChunk, LLMResult, Generation, RunInfo
+from langchain_core.outputs import Generation, LLMResult
+
+from langchain_ai21.ai21_base import AI21Base


 class AI21(BaseLLM, AI21Base):

@@ -58,11 +56,17 @@ class AI21(BaseLLM, AI21Base):
     """ A penalty applied to tokens that are already present in the prompt."""

     count_penalty: Optional[Penalty] = None
-    """A penalty applied to tokens based on their frequency in the generated responses."""
+    """A penalty applied to tokens based on their frequency
+    in the generated responses."""

     custom_model: Optional[str] = None
     epoch: Optional[int] = None

+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
     @property
     def _llm_type(self) -> str:
         """Return type of LLM."""

@@ -101,30 +105,12 @@ async def _agenerate(
             None, partial(self._generate, **kwargs), prompts, stop, run_manager
         )

-    def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        raise NotImplementedError
-
-    async def _astream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[GenerationChunk]:
-        raise NotImplementedError
-
     def _invoke_completion(
         self,
         prompt: str,
         model: str,
         stop_sequences: Optional[List[str]] = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> CompletionsResponse:
         return self.client.completion.create(
             prompt=prompt,
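Deleting the _stream/_astream stubs changes observable behaviour: BaseLLM only dispatches to _stream when a subclass overrides it, so AI21.stream now falls back to a non-streaming call that yields the full completion instead of raising NotImplementedError (the integration-test change below relies on exactly this). A hedged usage sketch, assuming AI21_API_KEY is set in the environment:

from langchain_ai21.llms import AI21

llm = AI21()

# Previously raised NotImplementedError; now the BaseLLM fallback
# generates the whole completion and yields it as a single string chunk.
for token in llm.stream("I'm Pickle Rick"):
    assert isinstance(token, str)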

libs/partners/ai21/tests/integration_tests/test_chat_models.py

Lines changed: 2 additions & 2 deletions

@@ -1,10 +1,10 @@
 """Test ChatAI21 chat model."""
 import pytest
-
-from langchain_ai21.chat_models import ChatAI21
 from langchain_core.messages import HumanMessage
 from langchain_core.outputs import ChatGeneration

+from langchain_ai21.chat_models import ChatAI21
+

 @pytest.mark.requires("ai21")
 def test_invoke() -> None:

libs/partners/ai21/tests/integration_tests/test_llms.py

Lines changed: 5 additions & 5 deletions

@@ -1,4 +1,5 @@
 """Test AI21LLM llm."""
+
 import pytest
 from ai21.models import Penalty

@@ -45,9 +46,8 @@ def test_stream() -> None:
     """Test streaming tokens from AI21."""
     llm = AI21()

-    with pytest.raises(NotImplementedError):
-        for token in llm.stream("I'm Pickle Rick"):
-            assert isinstance(token, str)
+    for token in llm.stream("I'm Pickle Rick"):
+        assert isinstance(token, str)


 @pytest.mark.requires("ai21")

@@ -109,7 +109,7 @@ def test__generate() -> None:
     )

     assert len(llm_result.generations) > 0
-    assert llm_result.llm_output["token_count"] != 0
+    assert llm_result.llm_output["token_count"] != 0  # type: ignore


 @pytest.mark.requires("ai21")

@@ -121,4 +121,4 @@ async def test__agenerate() -> None:
     )

     assert len(llm_result.generations) > 0
-    assert llm_result.llm_output["token_count"] != 0
+    assert llm_result.llm_output["token_count"] != 0  # type: ignore
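The # type: ignore comments are needed because LLMResult.llm_output is typed Optional[Dict[str, Any]], so indexing it unconditionally fails mypy. A narrower, ignore-free alternative (a stylistic variation, not what this commit does) would be:

assert llm_result.llm_output is not None
assert llm_result.llm_output["token_count"] != 0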

libs/partners/ai21/tests/unit_tests/conftest.py

Lines changed: 6 additions & 6 deletions

@@ -3,15 +3,15 @@
 import pytest
 from ai21 import AI21Client
 from ai21.models import (
-    CompletionsResponse,
+    ChatOutput,
+    ChatResponse,
     Completion,
     CompletionData,
     CompletionFinishReason,
-    ChatResponse,
+    CompletionsResponse,
+    FinishReason,
     Penalty,
-    ChatOutput,
     RoleType,
-    FinishReason,
 )
 from pytest_mock import MockerFixture

@@ -34,7 +34,7 @@


 @pytest.fixture
-def mocked_completion_response(mocker: MockerFixture):
+def mocked_completion_response(mocker: MockerFixture) -> Mock:
     mocked_response = mocker.MagicMock(spec=CompletionsResponse)
     mocked_response.prompt = "this is a test prompt"
     mocked_response.completions = [

@@ -48,7 +48,7 @@ def mocked_completion_response(mocker: MockerFixture):

 @pytest.fixture
 def mock_client_with_completion(
-    mocker: MockerFixture, mocked_completion_response
+    mocker: MockerFixture, mocked_completion_response: Mock
 ) -> Mock:
     mock_client = mocker.MagicMock(spec=AI21Client)
     mock_client.completion = mocker.MagicMock()
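The annotations type the fixture chain end to end: pytest builds mocked_completion_response first, then injects it into mock_client_with_completion by parameter name, and both fixtures now advertise Mock as their return type. A hypothetical test consuming the chained fixture, for illustration only:

from unittest.mock import Mock


def test_fixture_chain(mock_client_with_completion: Mock) -> None:
    # MagicMock is a subclass of Mock, so the spec'd AI21Client mock
    # satisfies the annotation the fixtures now declare.
    assert isinstance(mock_client_with_completion, Mock)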
