Commit ca4a567

models - litellm - async (#414)
1 parent 289abae commit ca4a567

File tree

3 files changed (+64, -46 lines)


src/strands/models/litellm.py

Lines changed: 6 additions & 10 deletions
@@ -48,13 +48,11 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
                 https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
             **model_config: Configuration options for the LiteLLM model.
         """
+        self.client_args = client_args or {}
         self.config = dict(model_config)
 
         logger.debug("config=<%s> | initializing", self.config)
 
-        client_args = client_args or {}
-        self.client = litellm.LiteLLM(**client_args)
-
     @override
     def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type: ignore[override]
         """Update the LiteLLM model configuration with the provided arguments.
@@ -124,15 +122,15 @@ async def stream(
         logger.debug("formatted request=<%s>", request)
 
         logger.debug("invoking model")
-        response = self.client.chat.completions.create(**request)
+        response = await litellm.acompletion(**self.client_args, **request)
 
         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
         yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
 
         tool_calls: dict[int, list[Any]] = {}
 
-        for event in response:
+        async for event in response:
             # Defensive: skip events with empty or missing choices
             if not getattr(event, "choices", None):
                 continue
@@ -171,7 +169,7 @@ async def stream(
                 yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
 
         # Skip remaining events as we don't have use for anything except the final usage payload
-        for event in response:
+        async for event in response:
             _ = event
 
         yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
@@ -191,10 +189,8 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        # The LiteLLM `Client` inits with Chat().
-        # Chat() inits with self.completions
-        # completions() has a method `create()` which wraps the real completion API of Litellm
-        response = self.client.chat.completions.create(
+        response = await litellm.acompletion(
+            **self.client_args,
             model=self.get_config()["model_id"],
             messages=self.format_request(prompt)["messages"],
             response_format=output_model,
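
A minimal standalone sketch of the call pattern this commit adopts: litellm.acompletion is awaited once, and with stream=True the awaited result is an async iterator, which is why the provider's loops switch to `async for`. The model id and message below are illustrative, not taken from this commit.

    import asyncio

    import litellm


    async def main() -> None:
        # stream=True makes the awaited acompletion return an async
        # iterator of chunks, mirroring the stream() method above.
        response = await litellm.acompletion(
            model="gpt-4o-mini",  # illustrative model id
            messages=[{"role": "user", "content": "calculate 2+2"}],
            stream=True,
            stream_options={"include_usage": True},
        )

        async for event in response:
            # Defensive: skip chunks with empty or missing choices,
            # e.g. the final usage-only chunk.
            if not getattr(event, "choices", None):
                continue
            print(event.choices[0].delta.content or "", end="")


    asyncio.run(main())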

tests/strands/models/test_litellm.py

Lines changed: 17 additions & 26 deletions
@@ -8,14 +8,14 @@
 
 
 @pytest.fixture
-def litellm_client_cls():
-    with unittest.mock.patch.object(strands.models.litellm.litellm, "LiteLLM") as mock_client_cls:
-        yield mock_client_cls
+def litellm_acompletion():
+    with unittest.mock.patch.object(strands.models.litellm.litellm, "acompletion") as mock_acompletion:
+        yield mock_acompletion
 
 
 @pytest.fixture
-def litellm_client(litellm_client_cls):
-    return litellm_client_cls.return_value
+def api_key():
+    return "a1"
 
 
 @pytest.fixture
@@ -24,10 +24,10 @@ def model_id():
 
 
 @pytest.fixture
-def model(litellm_client, model_id):
-    _ = litellm_client
+def model(litellm_acompletion, api_key, model_id):
+    _ = litellm_acompletion
 
-    return LiteLLMModel(model_id=model_id)
+    return LiteLLMModel(client_args={"api_key": api_key}, model_id=model_id)
 
 
 @pytest.fixture
@@ -49,17 +49,6 @@ class TestOutputModel(pydantic.BaseModel):
     return TestOutputModel
 
 
-def test__init__(litellm_client_cls, model_id):
-    model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})
-
-    tru_config = model.get_config()
-    exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}
-
-    assert tru_config == exp_config
-
-    litellm_client_cls.assert_called_once_with(api_key="k1")
-
-
 def test_update_config(model, model_id):
     model.update_config(model_id=model_id)
 
@@ -116,7 +105,7 @@ def test_format_request_message_content(content, exp_result):
 
 
 @pytest.mark.asyncio
-async def test_stream(litellm_client, model, alist):
+async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, alist):
     mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
     mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
     mock_delta_1 = unittest.mock.Mock(
@@ -148,8 +137,8 @@ async def test_stream(litellm_client, model, alist):
     mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
     mock_event_6 = unittest.mock.Mock()
 
-    litellm_client.chat.completions.create.return_value = iter(
-        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6])
     )
 
     messages = [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]
@@ -196,18 +185,20 @@ async def test_stream(litellm_client, model, alist):
     ]
 
     assert tru_events == exp_events
+
     expected_request = {
-        "model": "m1",
+        "api_key": api_key,
+        "model": model_id,
         "messages": [{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}],
         "stream": True,
        "stream_options": {"include_usage": True},
         "tools": [],
     }
-    litellm_client.chat.completions.create.assert_called_once_with(**expected_request)
+    litellm_acompletion.assert_called_once_with(**expected_request)
 
 
 @pytest.mark.asyncio
-async def test_structured_output(litellm_client, model, test_output_model_cls, alist):
+async def test_structured_output(litellm_acompletion, model, test_output_model_cls, alist):
     messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
 
     mock_choice = unittest.mock.Mock()
@@ -216,7 +207,7 @@ async def test_structured_output(litellm_client, model, test_output_model_cls, alist):
     mock_response = unittest.mock.Mock()
     mock_response.choices = [mock_choice]
 
-    litellm_client.chat.completions.create.return_value = mock_response
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response)
 
     with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=True):
         stream = model.structured_output(test_output_model_cls, messages)
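
The reworked test_stream leans on an `agenerator` helper to stand in for the async streaming response, and on unittest.mock.AsyncMock so the patched acompletion is awaitable. The actual `agenerator` fixture lives in the repo's shared conftest and is not part of this diff; a hypothetical sketch of such a helper:

    import pytest


    @pytest.fixture
    def agenerator():
        # Hypothetical reconstruction: wrap a plain iterable in an
        # async generator so the mocked litellm.acompletion result
        # can be consumed with `async for`.
        def factory(items):
            async def gen():
                for item in items:
                    yield item

            return gen()

        return factory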

tests_integ/models/test_model_litellm.py

Lines changed: 41 additions & 10 deletions
@@ -29,6 +29,17 @@ def agent(model, tools):
     return Agent(model=model, tools=tools)
 
 
+@pytest.fixture
+def weather():
+    class Weather(pydantic.BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""
+
+        time: str
+        weather: str
+
+    return Weather(time="12:00", weather="sunny")
+
+
 @pytest.fixture
 def yellow_color():
     class Color(pydantic.BaseModel):
@@ -44,24 +55,44 @@ def lower(_, value):
     return Color(name="yellow")
 
 
-def test_agent(agent):
+def test_agent_invoke(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
 
     assert all(string in text for string in ["12:00", "sunny"])
 
 
-def test_structured_output(model):
-    class Weather(pydantic.BaseModel):
-        time: str
-        weather: str
+@pytest.mark.asyncio
+async def test_agent_invoke_async(agent):
+    result = await agent.invoke_async("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+async def test_agent_stream_async(agent):
+    stream = agent.stream_async("What is the time and weather in New York?")
+    async for event in stream:
+        _ = event
+
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+def test_structured_output(agent, weather):
+    tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
 
-    agent_no_tools = Agent(model=model)
 
-    result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny")
-    assert isinstance(result, Weather)
-    assert result.time == "12:00"
-    assert result.weather == "sunny"
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(agent, weather):
+    tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
 
 
 def test_invoke_multi_modal_input(agent, yellow_img):
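
The new integration tests exercise the agent's async entry points (invoke_async, stream_async, structured_output_async) alongside the existing synchronous ones. A rough usage sketch of the first two outside pytest; the API key and model id are placeholders, and the imports assume the package layout used by these tests:

    import asyncio

    from strands import Agent
    from strands.models.litellm import LiteLLMModel


    async def main() -> None:
        model = LiteLLMModel(
            client_args={"api_key": "<your-key>"},  # placeholder credentials
            model_id="gpt-4o-mini",  # illustrative model id
        )
        agent = Agent(model=model)

        # Async invocation, as in test_agent_invoke_async.
        result = await agent.invoke_async("What is the time and weather in New York?")
        print(result.message["content"][0]["text"])

        # Async streaming, as in test_agent_stream_async; the last
        # event carries the final result under the "result" key.
        async for event in agent.stream_async("What is the time and weather in New York?"):
            last_event = event
        print(last_event["result"].message["content"][0]["text"])


    asyncio.run(main())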
