
models - litellm - async #414


Merged · 1 commit · Jul 11, 2025
models - litellm - async
pgrayy committed Jul 11, 2025
commit 0d6676f7c52ecea3a52ad5c20cb0893c8b880f16
16 changes: 6 additions & 10 deletions src/strands/models/litellm.py
@@ -48,13 +48,11 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
                 https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
             **model_config: Configuration options for the LiteLLM model.
         """
+        self.client_args = client_args or {}
         self.config = dict(model_config)
 
         logger.debug("config=<%s> | initializing", self.config)
 
-        client_args = client_args or {}
-        self.client = litellm.LiteLLM(**client_args)
-
     @override
     def update_config(self, **model_config: Unpack[LiteLLMConfig]) -> None:  # type: ignore[override]
         """Update the LiteLLM model configuration with the provided arguments.
@@ -124,15 +122,15 @@ async def stream(
         logger.debug("formatted request=<%s>", request)
 
         logger.debug("invoking model")
-        response = self.client.chat.completions.create(**request)
+        response = await litellm.acompletion(**self.client_args, **request)
 
         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
         yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
 
         tool_calls: dict[int, list[Any]] = {}
 
-        for event in response:
+        async for event in response:
             # Defensive: skip events with empty or missing choices
             if not getattr(event, "choices", None):
                 continue
@@ -171,7 +169,7 @@ async def stream(
             yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
 
         # Skip remaining events as we don't have use for anything except the final usage payload
-        for event in response:
+        async for event in response:
             _ = event
 
         yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
@@ -191,10 +189,8 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        # The LiteLLM `Client` inits with Chat().
-        # Chat() inits with self.completions
-        # completions() has a method `create()` which wraps the real completion API of Litellm
-        response = self.client.chat.completions.create(
+        response = await litellm.acompletion(
+            **self.client_args,
             model=self.get_config()["model_id"],
             messages=self.format_request(prompt)["messages"],
             response_format=output_model,
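For reference, the new call pattern reduces to plain `litellm.acompletion` usage. A minimal sketch, not part of this PR — the API key and model id are placeholders, and error handling is omitted:

```python
import asyncio

import litellm


async def main() -> None:
    client_args = {"api_key": "sk-..."}  # placeholder; mirrors self.client_args

    # Streaming: with stream=True, acompletion returns an object that can be
    # consumed with `async for`, matching the loops in stream() above.
    response = await litellm.acompletion(
        **client_args,
        model="gpt-4o",  # placeholder model id
        messages=[{"role": "user", "content": "calculate 2+2"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for event in response:
        if event.choices and event.choices[0].delta.content:
            print(event.choices[0].delta.content, end="")


asyncio.run(main())
```

`structured_output()` follows the same shape minus streaming: it passes the pydantic class via `response_format` and reads the parsed result off the first choice.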
43 changes: 17 additions & 26 deletions tests/strands/models/test_litellm.py
@@ -8,14 +8,14 @@
 
 
 @pytest.fixture
-def litellm_client_cls():
-    with unittest.mock.patch.object(strands.models.litellm.litellm, "LiteLLM") as mock_client_cls:
-        yield mock_client_cls
+def litellm_acompletion():
+    with unittest.mock.patch.object(strands.models.litellm.litellm, "acompletion") as mock_acompletion:
+        yield mock_acompletion
 
 
 @pytest.fixture
-def litellm_client(litellm_client_cls):
-    return litellm_client_cls.return_value
+def api_key():
+    return "a1"
 
 
 @pytest.fixture
@@ -24,10 +24,10 @@ def model_id():
 
 
 @pytest.fixture
-def model(litellm_client, model_id):
-    _ = litellm_client
+def model(litellm_acompletion, api_key, model_id):
+    _ = litellm_acompletion
 
-    return LiteLLMModel(model_id=model_id)
+    return LiteLLMModel(client_args={"api_key": api_key}, model_id=model_id)
 
 
 @pytest.fixture
@@ -49,17 +49,6 @@ class TestOutputModel(pydantic.BaseModel):
     return TestOutputModel
 
 
-def test__init__(litellm_client_cls, model_id):
-    model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})
-
-    tru_config = model.get_config()
-    exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}
-
-    assert tru_config == exp_config
-
-    litellm_client_cls.assert_called_once_with(api_key="k1")
-
-
 def test_update_config(model, model_id):
     model.update_config(model_id=model_id)
 
@@ -116,7 +105,7 @@ def test_format_request_message_content(content, exp_result):
 
 
 @pytest.mark.asyncio
-async def test_stream(litellm_client, model, alist):
+async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, alist):
     mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
     mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
     mock_delta_1 = unittest.mock.Mock(
@@ -148,8 +137,8 @@ async def test_stream(litellm_client, model, alist):
     mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
     mock_event_6 = unittest.mock.Mock()
 
-    litellm_client.chat.completions.create.return_value = iter(
-        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6])
     )
 
     messages = [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]
@@ -196,18 +185,20 @@ async def test_stream(litellm_client, model, alist):
     ]
 
     assert tru_events == exp_events
+
     expected_request = {
-        "model": "m1",
+        "api_key": api_key,
+        "model": model_id,
         "messages": [{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}],
         "stream": True,
         "stream_options": {"include_usage": True},
         "tools": [],
     }
-    litellm_client.chat.completions.create.assert_called_once_with(**expected_request)
+    litellm_acompletion.assert_called_once_with(**expected_request)
 
 
 @pytest.mark.asyncio
-async def test_structured_output(litellm_client, model, test_output_model_cls, alist):
+async def test_structured_output(litellm_acompletion, model, test_output_model_cls, alist):
     messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
 
     mock_choice = unittest.mock.Mock()
@@ -216,7 +207,7 @@ async def test_structured_output(litellm_client, model, test_output_model_cls, alist):
     mock_response = unittest.mock.Mock()
     mock_response.choices = [mock_choice]
 
-    litellm_client.chat.completions.create.return_value = mock_response
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response)
 
     with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=True):
         stream = model.structured_output(test_output_model_cls, messages)
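The updated tests lean on `agenerator` and `alist` fixtures that are defined outside this diff. A sketch of what they plausibly look like, assuming shared helpers in a conftest.py (the real definitions may differ):

```python
import pytest


@pytest.fixture
def agenerator():
    # Factory that wraps a plain iterable in an async generator, standing in
    # for the async stream returned by litellm.acompletion.
    def factory(items):
        async def gen():
            for item in items:
                yield item

        return gen()

    return factory


@pytest.fixture
def alist():
    # Drains an async iterable into a list so tests can assert on all events.
    async def collect(aiterable):
        return [item async for item in aiterable]

    return collect
```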
51 changes: 41 additions & 10 deletions tests_integ/models/test_model_litellm.py
@@ -29,6 +29,17 @@ def agent(model, tools):
     return Agent(model=model, tools=tools)
 
 
+@pytest.fixture
+def weather():
+    class Weather(pydantic.BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""
+
+        time: str
+        weather: str
+
+    return Weather(time="12:00", weather="sunny")
+
+
 @pytest.fixture
 def yellow_color():
     class Color(pydantic.BaseModel):
@@ -44,24 +55,44 @@ def lower(_, value):
     return Color(name="yellow")
 
 
-def test_agent(agent):
+def test_agent_invoke(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
 
     assert all(string in text for string in ["12:00", "sunny"])
 
 
-def test_structured_output(model):
-    class Weather(pydantic.BaseModel):
-        time: str
-        weather: str
+@pytest.mark.asyncio
+async def test_agent_invoke_async(agent):
+    result = await agent.invoke_async("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+async def test_agent_stream_async(agent):
+    stream = agent.stream_async("What is the time and weather in New York?")
+    async for event in stream:
+        _ = event
+
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+def test_structured_output(agent, weather):
+    tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
 
-    agent_no_tools = Agent(model=model)
 
-    result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny")
-    assert isinstance(result, Weather)
-    assert result.time == "12:00"
-    assert result.weather == "sunny"
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(agent, weather):
+    tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
 
 
 def test_invoke_multi_modal_input(agent, yellow_img):
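The new tests exercise the async entry points (`invoke_async`, `stream_async`, `structured_output_async`). Outside of pytest-asyncio, the same calls would be driven with `asyncio.run`; a sketch assuming the `strands` import paths used by these tests, with placeholder credentials and model id:

```python
import asyncio

from strands import Agent  # assumed import path
from strands.models.litellm import LiteLLMModel


async def main() -> None:
    # Placeholder api_key and model_id; any litellm-supported model works here.
    model = LiteLLMModel(client_args={"api_key": "sk-..."}, model_id="gpt-4o")
    agent = Agent(model=model)

    result = await agent.invoke_async("What is the time and weather in New York?")
    print(result.message["content"][0]["text"])


asyncio.run(main())
```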