11"""Test chat model integration."""
22from typing import List , Optional
3- from unittest .mock import call
3+ from unittest .mock import Mock , call
44
55import pytest
66from ai21 import MissingApiKeyError
@@ -40,7 +40,7 @@ def test_initialization__when_default_parameters_in_init() -> None:
4040
4141
4242@pytest .mark .requires ("ai21" )
43- def test_initialization__when_custom_parameters_in_init ():
43+ def test_initialization__when_custom_parameters_in_init () -> None :
4444 model = "j2-mid"
4545 num_results = 1
4646 max_tokens = 10
@@ -97,7 +97,7 @@ def test_initialization__when_custom_parameters_in_init():
9797)
9898def test_convert_message_to_ai21_message (
9999 message : BaseMessage , expected_ai21_message : ChatMessage
100- ):
100+ ) -> None :
101101 ai21_message = _convert_message_to_ai21_message (message )
102102 assert ai21_message == expected_ai21_message
103103
@@ -115,8 +115,8 @@ def test_convert_message_to_ai21_message(
115115 ],
116116)
117117def test_convert_message_to_ai21_message__when_invalid_role__should_raise_exception (
118- message ,
119- ):
118+ message : BaseMessage ,
119+ ) -> None :
120120 with pytest .raises (ValueError ) as e :
121121 _convert_message_to_ai21_message (message )
122122 assert e .value .args [0 ] == (
@@ -157,15 +157,17 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
157157 ],
158158)
159159def test_convert_messages (
160- messages , expected_system : Optional [str ], expected_messages : List [ChatMessage ]
161- ):
160+ messages : List [BaseMessage ],
161+ expected_system : Optional [str ],
162+ expected_messages : List [ChatMessage ],
163+ ) -> None :
162164 system , ai21_messages = _convert_messages_to_ai21_messages (messages )
163165 assert ai21_messages == expected_messages
164166 assert system == expected_system
165167
166168
167169@pytest .mark .requires ("ai21" )
168- def test_convert_messages_when_system_is_not_first__should_raise_value_error ():
170+ def test_convert_messages_when_system_is_not_first__should_raise_value_error () -> None :
169171 messages = [
170172 HumanMessage (content = "Human Message Content 1" ),
171173 SystemMessage (content = "System Message Content 1" ),
@@ -175,7 +177,7 @@ def test_convert_messages_when_system_is_not_first__should_raise_value_error():
175177
176178
177179@pytest .mark .requires ("ai21" )
178- def test_invoke (mock_client_with_chat ) :
180+ def test_invoke (mock_client_with_chat : Mock ) -> None :
179181 chat_input = "I'm Pickle Rick"
180182
181183 llm = ChatAI21 (
@@ -195,7 +197,7 @@ def test_invoke(mock_client_with_chat):
195197
196198
197199@pytest .mark .requires ("ai21" )
198- def test_generate (mock_client_with_chat ) :
200+ def test_generate (mock_client_with_chat : Mock ) -> None :
199201 messages0 = [
200202 HumanMessage (content = "I'm Pickle Rick" ),
201203 AIMessage (content = "Hello Pickle Rick! I am your AI Assistant" ),
@@ -216,9 +218,14 @@ def test_generate(mock_client_with_chat):
216218 call (
217219 model = "j2-ultra" ,
218220 messages = [
219- ChatMessage (role = RoleType .USER , text = messages0 [0 ].content ),
220- ChatMessage (role = RoleType .ASSISTANT , text = messages0 [1 ].content ),
221- ChatMessage (role = RoleType .USER , text = messages0 [2 ].content ),
221+ ChatMessage (
222+ role = RoleType .USER ,
223+ text = str (messages0 [0 ].content ),
224+ ),
225+ ChatMessage (
226+ role = RoleType .ASSISTANT , text = str (messages0 [1 ].content )
227+ ),
228+ ChatMessage (role = RoleType .USER , text = str (messages0 [2 ].content )),
222229 ],
223230 system = "" ,
224231 stop_sequences = None ,
@@ -227,7 +234,7 @@ def test_generate(mock_client_with_chat):
227234 call (
228235 model = "j2-ultra" ,
229236 messages = [
230- ChatMessage (role = RoleType .USER , text = messages1 [1 ].content ),
237+ ChatMessage (role = RoleType .USER , text = str ( messages1 [1 ].content ) ),
231238 ],
232239 system = "system message" ,
233240 stop_sequences = None ,
0 commit comments