Skip to content

Commit 71da71e

Browse files
committed
fix: lint
1 parent 21a6deb commit 71da71e

File tree

7 files changed

+59
-49
lines changed

7 files changed

+59
-49
lines changed

libs/partners/ai21/langchain_ai21/ai21_base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
66
from langchain_core.utils import convert_to_secret_str
77

8+
_DEFAULT_TIMEOUT_SEC = 300
9+
810

911
class AI21Base(BaseModel):
1012
class Config:
1113
arbitrary_types_allowed = True
1214

13-
client: Optional[AI21Client] = Field(default=None)
15+
client: AI21Client = Field(default=None)
1416
api_key: Optional[SecretStr] = None
1517
api_host: Optional[str] = None
1618
timeout_sec: Optional[float] = None
@@ -30,14 +32,16 @@ def validate_environment(cls, values: Dict) -> Dict:
3032
)
3133
values["api_host"] = api_host
3234

33-
timeout_sec = values.get("timeout_sec") or os.getenv("AI21_TIMEOUT_SEC")
35+
timeout_sec = values.get("timeout_sec") or float(
36+
os.getenv("AI21_TIMEOUT_SEC", _DEFAULT_TIMEOUT_SEC)
37+
)
3438
values["timeout_sec"] = timeout_sec
3539

3640
if values.get("client") is None:
3741
values["client"] = AI21Client(
3842
api_key=api_key.get_secret_value(),
3943
api_host=api_host,
40-
timeout_sec=timeout_sec,
44+
timeout_sec=None if timeout_sec is None else float(timeout_sec),
4145
via="langchain",
4246
)
4347

libs/partners/ai21/langchain_ai21/chat_models.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@
1919
from langchain_ai21.ai21_base import AI21Base
2020

2121

22+
def _get_system_message_from_message(message: BaseMessage) -> str:
23+
if not isinstance(message.content, str):
24+
raise ValueError(
25+
f"System Message must be of type str. Got {type(message.content)}"
26+
)
27+
28+
return message.content
29+
30+
2231
def _convert_messages_to_ai21_messages(
2332
messages: List[BaseMessage],
2433
) -> Tuple[Optional[str], List[ChatMessage]]:
@@ -30,7 +39,7 @@ def _convert_messages_to_ai21_messages(
3039
if i != 0:
3140
raise ValueError("System message must be at beginning of message list.")
3241
else:
33-
system_message = message.content
42+
system_message = _get_system_message_from_message(message)
3443
else:
3544
converted_message = _convert_message_to_ai21_message(message)
3645
converted_messages.append(converted_message)
@@ -108,6 +117,11 @@ class ChatAI21(BaseChatModel, AI21Base):
108117
"""A penalty applied to tokens based on their frequency
109118
in the generated responses."""
110119

120+
class Config:
121+
"""Configuration for this pydantic object."""
122+
123+
arbitrary_types_allowed = True
124+
111125
@property
112126
def _llm_type(self) -> str:
113127
"""Return type of chat model."""

libs/partners/ai21/langchain_ai21/llms.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from functools import partial
33
from typing import (
44
Any,
5-
AsyncIterator,
6-
Iterator,
75
List,
86
Optional,
97
)
@@ -14,7 +12,7 @@
1412
CallbackManagerForLLMRun,
1513
)
1614
from langchain_core.language_models import BaseLLM
17-
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
15+
from langchain_core.outputs import Generation, LLMResult
1816

1917
from langchain_ai21.ai21_base import AI21Base
2018

@@ -64,6 +62,11 @@ class AI21(BaseLLM, AI21Base):
6462
custom_model: Optional[str] = None
6563
epoch: Optional[int] = None
6664

65+
class Config:
66+
"""Configuration for this pydantic object."""
67+
68+
allow_population_by_field_name = True
69+
6770
@property
6871
def _llm_type(self) -> str:
6972
"""Return type of LLM."""
@@ -102,30 +105,12 @@ async def _agenerate(
102105
None, partial(self._generate, **kwargs), prompts, stop, run_manager
103106
)
104107

105-
def _stream(
106-
self,
107-
prompt: str,
108-
stop: Optional[List[str]] = None,
109-
run_manager: Optional[CallbackManagerForLLMRun] = None,
110-
**kwargs: Any,
111-
) -> Iterator[GenerationChunk]:
112-
raise NotImplementedError
113-
114-
async def _astream(
115-
self,
116-
prompt: str,
117-
stop: Optional[List[str]] = None,
118-
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
119-
**kwargs: Any,
120-
) -> AsyncIterator[GenerationChunk]:
121-
raise NotImplementedError
122-
123108
def _invoke_completion(
124109
self,
125110
prompt: str,
126111
model: str,
127112
stop_sequences: Optional[List[str]] = None,
128-
**kwargs,
113+
**kwargs: Any,
129114
) -> CompletionsResponse:
130115
return self.client.completion.create(
131116
prompt=prompt,

libs/partners/ai21/tests/integration_tests/test_llms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Test AI21LLM llm."""
2+
23
import pytest
34
from ai21.models import Penalty
45

@@ -45,9 +46,8 @@ def test_stream() -> None:
4546
"""Test streaming tokens from AI21."""
4647
llm = AI21()
4748

48-
with pytest.raises(NotImplementedError):
49-
for token in llm.stream("I'm Pickle Rick"):
50-
assert isinstance(token, str)
49+
for token in llm.stream("I'm Pickle Rick"):
50+
assert isinstance(token, str)
5151

5252

5353
@pytest.mark.requires("ai21")
@@ -109,7 +109,7 @@ def test__generate() -> None:
109109
)
110110

111111
assert len(llm_result.generations) > 0
112-
assert llm_result.llm_output["token_count"] != 0
112+
assert llm_result.llm_output["token_count"] != 0 # type: ignore
113113

114114

115115
@pytest.mark.requires("ai21")
@@ -121,4 +121,4 @@ async def test__agenerate() -> None:
121121
)
122122

123123
assert len(llm_result.generations) > 0
124-
assert llm_result.llm_output["token_count"] != 0
124+
assert llm_result.llm_output["token_count"] != 0 # type: ignore

libs/partners/ai21/tests/unit_tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535

3636
@pytest.fixture
37-
def mocked_completion_response(mocker: MockerFixture):
37+
def mocked_completion_response(mocker: MockerFixture) -> Mock:
3838
mocked_response = mocker.MagicMock(spec=CompletionsResponse)
3939
mocked_response.prompt = "this is a test prompt"
4040
mocked_response.completions = [
@@ -48,7 +48,7 @@ def mocked_completion_response(mocker: MockerFixture):
4848

4949
@pytest.fixture
5050
def mock_client_with_completion(
51-
mocker: MockerFixture, mocked_completion_response
51+
mocker: MockerFixture, mocked_completion_response: Mock
5252
) -> Mock:
5353
mock_client = mocker.MagicMock(spec=AI21Client)
5454
mock_client.completion = mocker.MagicMock()

libs/partners/ai21/tests/unit_tests/test_chat_models.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Test chat model integration."""
22
from typing import List, Optional
3-
from unittest.mock import call
3+
from unittest.mock import Mock, call
44

55
import pytest
66
from ai21 import MissingApiKeyError
@@ -40,7 +40,7 @@ def test_initialization__when_default_parameters_in_init() -> None:
4040

4141

4242
@pytest.mark.requires("ai21")
43-
def test_initialization__when_custom_parameters_in_init():
43+
def test_initialization__when_custom_parameters_in_init() -> None:
4444
model = "j2-mid"
4545
num_results = 1
4646
max_tokens = 10
@@ -97,7 +97,7 @@ def test_initialization__when_custom_parameters_in_init():
9797
)
9898
def test_convert_message_to_ai21_message(
9999
message: BaseMessage, expected_ai21_message: ChatMessage
100-
):
100+
) -> None:
101101
ai21_message = _convert_message_to_ai21_message(message)
102102
assert ai21_message == expected_ai21_message
103103

@@ -115,8 +115,8 @@ def test_convert_message_to_ai21_message(
115115
],
116116
)
117117
def test_convert_message_to_ai21_message__when_invalid_role__should_raise_exception(
118-
message,
119-
):
118+
message: BaseMessage,
119+
) -> None:
120120
with pytest.raises(ValueError) as e:
121121
_convert_message_to_ai21_message(message)
122122
assert e.value.args[0] == (
@@ -157,15 +157,17 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
157157
],
158158
)
159159
def test_convert_messages(
160-
messages, expected_system: Optional[str], expected_messages: List[ChatMessage]
161-
):
160+
messages: List[BaseMessage],
161+
expected_system: Optional[str],
162+
expected_messages: List[ChatMessage],
163+
) -> None:
162164
system, ai21_messages = _convert_messages_to_ai21_messages(messages)
163165
assert ai21_messages == expected_messages
164166
assert system == expected_system
165167

166168

167169
@pytest.mark.requires("ai21")
168-
def test_convert_messages_when_system_is_not_first__should_raise_value_error():
170+
def test_convert_messages_when_system_is_not_first__should_raise_value_error() -> None:
169171
messages = [
170172
HumanMessage(content="Human Message Content 1"),
171173
SystemMessage(content="System Message Content 1"),
@@ -175,7 +177,7 @@ def test_convert_messages_when_system_is_not_first__should_raise_value_error():
175177

176178

177179
@pytest.mark.requires("ai21")
178-
def test_invoke(mock_client_with_chat):
180+
def test_invoke(mock_client_with_chat: Mock) -> None:
179181
chat_input = "I'm Pickle Rick"
180182

181183
llm = ChatAI21(
@@ -195,7 +197,7 @@ def test_invoke(mock_client_with_chat):
195197

196198

197199
@pytest.mark.requires("ai21")
198-
def test_generate(mock_client_with_chat):
200+
def test_generate(mock_client_with_chat: Mock) -> None:
199201
messages0 = [
200202
HumanMessage(content="I'm Pickle Rick"),
201203
AIMessage(content="Hello Pickle Rick! I am your AI Assistant"),
@@ -216,9 +218,14 @@ def test_generate(mock_client_with_chat):
216218
call(
217219
model="j2-ultra",
218220
messages=[
219-
ChatMessage(role=RoleType.USER, text=messages0[0].content),
220-
ChatMessage(role=RoleType.ASSISTANT, text=messages0[1].content),
221-
ChatMessage(role=RoleType.USER, text=messages0[2].content),
221+
ChatMessage(
222+
role=RoleType.USER,
223+
text=str(messages0[0].content),
224+
),
225+
ChatMessage(
226+
role=RoleType.ASSISTANT, text=str(messages0[1].content)
227+
),
228+
ChatMessage(role=RoleType.USER, text=str(messages0[2].content)),
222229
],
223230
system="",
224231
stop_sequences=None,
@@ -227,7 +234,7 @@ def test_generate(mock_client_with_chat):
227234
call(
228235
model="j2-ultra",
229236
messages=[
230-
ChatMessage(role=RoleType.USER, text=messages1[1].content),
237+
ChatMessage(role=RoleType.USER, text=str(messages1[1].content)),
231238
],
232239
system="system message",
233240
stop_sequences=None,

libs/partners/ai21/tests/unit_tests/test_llms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Test AI21 Chat API wrapper."""
2-
from unittest.mock import call
2+
from unittest.mock import Mock, call
33

44
import pytest
55
from ai21 import MissingApiKeyError
@@ -51,7 +51,7 @@ def test_initialization__when_custom_parameters_to_init() -> None:
5151

5252

5353
@pytest.mark.requires("ai21")
54-
def test_generate(mock_client_with_completion):
54+
def test_generate(mock_client_with_completion: Mock) -> None:
5555
# Setup test
5656
prompt0 = "Hi, my name is what?"
5757
prompt1 = "My name is who?"

0 commit comments

Comments (0)