8
8
9
9
10
10
@pytest .fixture
11
- def litellm_client_cls ():
12
- with unittest .mock .patch .object (strands .models .litellm .litellm , "LiteLLM " ) as mock_client_cls :
13
- yield mock_client_cls
11
+ def litellm_acompletion ():
12
+ with unittest .mock .patch .object (strands .models .litellm .litellm , "acompletion " ) as mock_acompletion :
13
+ yield mock_acompletion
14
14
15
15
16
16
@pytest .fixture
17
- def litellm_client ( litellm_client_cls ):
18
- return litellm_client_cls . return_value
17
+ def api_key ( ):
18
+ return "a1"
19
19
20
20
21
21
@pytest .fixture
@@ -24,10 +24,10 @@ def model_id():
24
24
25
25
26
26
@pytest .fixture
27
- def model (litellm_client , model_id ):
28
- _ = litellm_client
27
+ def model (litellm_acompletion , api_key , model_id ):
28
+ _ = litellm_acompletion
29
29
30
- return LiteLLMModel (model_id = model_id )
30
+ return LiteLLMModel (client_args = { "api_key" : api_key }, model_id = model_id )
31
31
32
32
33
33
@pytest .fixture
@@ -49,17 +49,6 @@ class TestOutputModel(pydantic.BaseModel):
49
49
return TestOutputModel
50
50
51
51
52
- def test__init__ (litellm_client_cls , model_id ):
53
- model = LiteLLMModel ({"api_key" : "k1" }, model_id = model_id , params = {"max_tokens" : 1 })
54
-
55
- tru_config = model .get_config ()
56
- exp_config = {"model_id" : "m1" , "params" : {"max_tokens" : 1 }}
57
-
58
- assert tru_config == exp_config
59
-
60
- litellm_client_cls .assert_called_once_with (api_key = "k1" )
61
-
62
-
63
52
def test_update_config (model , model_id ):
64
53
model .update_config (model_id = model_id )
65
54
@@ -116,7 +105,7 @@ def test_format_request_message_content(content, exp_result):
116
105
117
106
118
107
@pytest .mark .asyncio
119
- async def test_stream (litellm_client , model , alist ):
108
+ async def test_stream (litellm_acompletion , api_key , model_id , model , agenerator , alist ):
120
109
mock_tool_call_1_part_1 = unittest .mock .Mock (index = 0 )
121
110
mock_tool_call_2_part_1 = unittest .mock .Mock (index = 1 )
122
111
mock_delta_1 = unittest .mock .Mock (
@@ -148,8 +137,8 @@ async def test_stream(litellm_client, model, alist):
148
137
mock_event_5 = unittest .mock .Mock (choices = [unittest .mock .Mock (finish_reason = "tool_calls" , delta = mock_delta_5 )])
149
138
mock_event_6 = unittest .mock .Mock ()
150
139
151
- litellm_client . chat . completions . create . return_value = iter (
152
- [mock_event_1 , mock_event_2 , mock_event_3 , mock_event_4 , mock_event_5 , mock_event_6 ]
140
+ litellm_acompletion . side_effect = unittest . mock . AsyncMock (
141
+ return_value = agenerator ( [mock_event_1 , mock_event_2 , mock_event_3 , mock_event_4 , mock_event_5 , mock_event_6 ])
153
142
)
154
143
155
144
messages = [{"role" : "user" , "content" : [{"type" : "text" , "text" : "calculate 2+2" }]}]
@@ -196,18 +185,20 @@ async def test_stream(litellm_client, model, alist):
196
185
]
197
186
198
187
assert tru_events == exp_events
188
+
199
189
expected_request = {
200
- "model" : "m1" ,
190
+ "api_key" : api_key ,
191
+ "model" : model_id ,
201
192
"messages" : [{"role" : "user" , "content" : [{"text" : "calculate 2+2" , "type" : "text" }]}],
202
193
"stream" : True ,
203
194
"stream_options" : {"include_usage" : True },
204
195
"tools" : [],
205
196
}
206
- litellm_client . chat . completions . create .assert_called_once_with (** expected_request )
197
+ litellm_acompletion .assert_called_once_with (** expected_request )
207
198
208
199
209
200
@pytest .mark .asyncio
210
- async def test_structured_output (litellm_client , model , test_output_model_cls , alist ):
201
+ async def test_structured_output (litellm_acompletion , model , test_output_model_cls , alist ):
211
202
messages = [{"role" : "user" , "content" : [{"text" : "Generate a person" }]}]
212
203
213
204
mock_choice = unittest .mock .Mock ()
@@ -216,7 +207,7 @@ async def test_structured_output(litellm_client, model, test_output_model_cls, a
216
207
mock_response = unittest .mock .Mock ()
217
208
mock_response .choices = [mock_choice ]
218
209
219
- litellm_client . chat . completions . create . return_value = mock_response
210
+ litellm_acompletion . side_effect = unittest . mock . AsyncMock ( return_value = mock_response )
220
211
221
212
with unittest .mock .patch .object (strands .models .litellm , "supports_response_schema" , return_value = True ):
222
213
stream = model .structured_output (test_output_model_cls , messages )
0 commit comments