
Commit 53e45ac

models - openai - async client
1 parent 5190645 commit 53e45ac

File tree

6 files changed: +226 additions, -44 deletions

src/strands/models/litellm.py

Lines changed: 58 additions & 1 deletion
@@ -13,7 +13,7 @@
 from typing_extensions import Unpack, override
 
 from ..types.content import ContentBlock, Messages
-from .openai import OpenAIModel
+from ..types.models.openai import OpenAIModel
 
 logger = logging.getLogger(__name__)
 
@@ -103,6 +103,63 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
 
         return super().format_request_message_content(content)
 
+    @override
+    async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
+        """Send the request to the LiteLLM model and get the streaming response.
+
+        Args:
+            request: The formatted request to send to the LiteLLM model.
+
+        Returns:
+            An iterable of response events from the LiteLLM model.
+        """
+        response = self.client.chat.completions.create(**request)
+
+        yield {"chunk_type": "message_start"}
+        yield {"chunk_type": "content_start", "data_type": "text"}
+
+        tool_calls: dict[int, list[Any]] = {}
+
+        for event in response:
+            # Defensive: skip events with empty or missing choices
+            if not getattr(event, "choices", None):
+                continue
+            choice = event.choices[0]
+
+            if choice.delta.content:
+                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+
+            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
+                yield {
+                    "chunk_type": "content_delta",
+                    "data_type": "reasoning_content",
+                    "data": choice.delta.reasoning_content,
+                }
+
+            for tool_call in choice.delta.tool_calls or []:
+                tool_calls.setdefault(tool_call.index, []).append(tool_call)
+
+            if choice.finish_reason:
+                break
+
+        yield {"chunk_type": "content_stop", "data_type": "text"}
+
+        for tool_deltas in tool_calls.values():
+            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+
+            for tool_delta in tool_deltas:
+                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
+
+            yield {"chunk_type": "content_stop", "data_type": "tool"}
+
+        yield {"chunk_type": "message_stop", "data": choice.finish_reason}
+
+        # Skip remaining events as we don't have use for anything except the final usage payload
+        for event in response:
+            _ = event
+
+        yield {"chunk_type": "metadata", "data": event.usage}
+
     @override
     async def structured_output(
         self, output_model: Type[T], prompt: Messages
src/strands/models/openai.py

Lines changed: 5 additions & 5 deletions
@@ -61,7 +61,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
         logger.debug("config=<%s> | initializing", self.config)
 
         client_args = client_args or {}
-        self.client = openai.OpenAI(**client_args)
+        self.client = openai.AsyncOpenAI(**client_args)
 
     @override
     def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None:  # type: ignore[override]
@@ -91,14 +91,14 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
         Returns:
             An iterable of response events from the OpenAI model.
         """
-        response = self.client.chat.completions.create(**request)
+        response = await self.client.chat.completions.create(**request)
 
         yield {"chunk_type": "message_start"}
         yield {"chunk_type": "content_start", "data_type": "text"}
 
         tool_calls: dict[int, list[Any]] = {}
 
-        for event in response:
+        async for event in response:
             # Defensive: skip events with empty or missing choices
             if not getattr(event, "choices", None):
                 continue
@@ -133,7 +133,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
         yield {"chunk_type": "message_stop", "data": choice.finish_reason}
 
         # Skip remaining events as we don't have use for anything except the final usage payload
-        for event in response:
+        async for event in response:
             _ = event
 
         yield {"chunk_type": "metadata", "data": event.usage}
@@ -151,7 +151,7 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        response: ParsedChatCompletion = self.client.beta.chat.completions.parse(  # type: ignore
+        response: ParsedChatCompletion = await self.client.beta.chat.completions.parse(  # type: ignore
             model=self.get_config()["model_id"],
             messages=super().format_request(prompt)["messages"],
             response_format=output_model,
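The migration pattern in isolation, as a minimal standalone sketch against the public openai-python v1 API (the model and prompt here are placeholders, not taken from this commit): construct the async client, await create(), and iterate with async for.

import asyncio

import openai

async def main():
    client = openai.AsyncOpenAI()  # API key read from OPENAI_API_KEY
    # With AsyncOpenAI, create() must be awaited; with stream=True it
    # resolves to an async stream rather than a plain iterator.
    response = await client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
    )
    async for event in response:  # iteration switches from `for` to `async for`
        if event.choices and event.choices[0].delta.content:
            print(event.choices[0].delta.content, end="")

asyncio.run(main())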

tests-integ/test_model_litellm.py

Lines changed: 44 additions & 13 deletions
@@ -6,12 +6,12 @@
 from strands.models.litellm import LiteLLMModel
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def model():
     return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def tools():
     @strands.tool
     def tool_time() -> str:
@@ -24,26 +24,57 @@ def tool_weather() -> str:
     return [tool_time, tool_weather]
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def agent(model, tools):
     return Agent(model=model, tools=tools)
 
 
-def test_agent(agent):
+@pytest.fixture(scope="module")
+def weather():
+    class Weather(BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""
+
+        time: str
+        weather: str
+
+    return Weather(time="12:00", weather="sunny")
+
+
+def test_agent_invoke(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
 
     assert all(string in text for string in ["12:00", "sunny"])
 
 
-def test_structured_output(model):
-    class Weather(BaseModel):
-        time: str
-        weather: str
+@pytest.mark.asyncio
+async def test_agent_invoke_async(agent):
+    result = await agent.invoke_async("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+async def test_agent_stream_async(agent):
+    stream = agent.stream_async("What is the time and weather in New York?")
+    async for event in stream:
+        _ = event
+
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+def test_agent_structured_output(agent, weather):
+    tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
 
-    agent_no_tools = Agent(model=model)
 
-    result = agent_no_tools.structured_output(Weather, "The time is 12:00 and the weather is sunny")
-    assert isinstance(result, Weather)
-    assert result.time == "12:00"
-    assert result.weather == "sunny"
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(agent, weather):
+    tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
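Besides the new async variants, the other change here is scope="module": pytest now builds the model, tools, and agent once per test module and shares them across all five tests instead of rebuilding them per test. A minimal sketch of that caching behavior, using stock pytest and a hypothetical resource (nothing from this commit):

import pytest

_instances = []

@pytest.fixture(scope="module")
def shared_resource():
    # Runs once per module; the returned object is cached and reused.
    obj = object()
    _instances.append(obj)
    return obj

def test_first(shared_resource):
    assert shared_resource is _instances[0]

def test_second(shared_resource):
    # Same instance as in test_first: the fixture body did not run again.
    assert shared_resource is _instances[0] and len(_instances) == 1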

tests-integ/test_model_openai.py

Lines changed: 44 additions & 15 deletions
@@ -12,7 +12,7 @@
 from strands.models.openai import OpenAIModel
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def model():
     return OpenAIModel(
         model_id="gpt-4o",
@@ -22,7 +22,7 @@ def model():
     )
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def tools():
     @strands.tool
     def tool_time() -> str:
@@ -35,36 +35,65 @@ def tool_weather() -> str:
     return [tool_time, tool_weather]
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def agent(model, tools):
     return Agent(model=model, tools=tools)
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
+def weather():
+    class Weather(BaseModel):
+        """Extracts the time and weather from the user's message with the exact strings."""
+
+        time: str
+        weather: str
+
+    return Weather(time="12:00", weather="sunny")
+
+
+@pytest.fixture(scope="module")
 def test_image_path(request):
     return request.config.rootpath / "tests-integ" / "test_image.png"
 
 
-def test_agent(agent):
+def test_agent_invoke(agent):
     result = agent("What is the time and weather in New York?")
     text = result.message["content"][0]["text"].lower()
 
     assert all(string in text for string in ["12:00", "sunny"])
 
 
-def test_structured_output(model):
-    class Weather(BaseModel):
-        """Extracts the time and weather from the user's message with the exact strings."""
+@pytest.mark.asyncio
+async def test_agent_invoke_async(agent):
+    result = await agent.invoke_async("What is the time and weather in New York?")
+    text = result.message["content"][0]["text"].lower()
 
-        time: str
-        weather: str
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+@pytest.mark.asyncio
+async def test_agent_stream_async(agent):
+    stream = agent.stream_async("What is the time and weather in New York?")
+    async for event in stream:
+        _ = event
+
+    result = event["result"]
+    text = result.message["content"][0]["text"].lower()
+
+    assert all(string in text for string in ["12:00", "sunny"])
+
+
+def test_agent_structured_output(agent, weather):
+    tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
 
-    agent = Agent(model=model)
 
-    result = agent.structured_output(Weather, "The time is 12:00 and the weather is sunny")
-    assert isinstance(result, Weather)
-    assert result.time == "12:00"
-    assert result.weather == "sunny"
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(agent, weather):
+    tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
 
 
 def test_tool_returning_images(model, test_image_path):
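The rewritten structured-output tests compare whole instances (tru_weather == exp_weather) rather than asserting field by field; that works because pydantic's BaseModel defines field-wise equality. A minimal standalone sketch of that property (plain pydantic, not code from this commit):

from pydantic import BaseModel

class Weather(BaseModel):
    time: str
    weather: str

a = Weather(time="12:00", weather="sunny")
b = Weather(time="12:00", weather="sunny")

assert a == b       # equal field values => equal models
assert a is not b   # distinct instances; equality is structural, not identity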

tests/strands/models/test_litellm.py

Lines changed: 63 additions & 0 deletions
@@ -115,6 +115,69 @@ def test_format_request_message_content(content, exp_result):
     assert tru_result == exp_result
 
 
+@pytest.mark.asyncio
+async def test_stream(litellm_client, model, alist):
+    mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
+    mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
+    mock_delta_1 = unittest.mock.Mock(
+        reasoning_content="",
+        content=None,
+        tool_calls=None,
+    )
+    mock_delta_2 = unittest.mock.Mock(
+        reasoning_content="\nI'm thinking",
+        content=None,
+        tool_calls=None,
+    )
+    mock_delta_3 = unittest.mock.Mock(
+        content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1], reasoning_content=None
+    )
+
+    mock_tool_call_1_part_2 = unittest.mock.Mock(index=0)
+    mock_tool_call_2_part_2 = unittest.mock.Mock(index=1)
+    mock_delta_4 = unittest.mock.Mock(
+        content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2], reasoning_content=None
+    )
+
+    mock_delta_5 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None)
+
+    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)])
+    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)])
+    mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)])
+    mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)])
+    mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
+    mock_event_6 = unittest.mock.Mock()
+
+    litellm_client.chat.completions.create.return_value = iter(
+        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]
+    )
+
+    request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
+    response = model.stream(request)
+    tru_events = await alist(response)
+    exp_events = [
+        {"chunk_type": "message_start"},
+        {"chunk_type": "content_start", "data_type": "text"},
+        {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "\nI'm thinking"},
+        {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"},
+        {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"},
+        {"chunk_type": "content_stop", "data_type": "text"},
+        {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1},
+        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1},
+        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2},
+        {"chunk_type": "content_stop", "data_type": "tool"},
+        {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1},
+        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1},
+        {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2},
+        {"chunk_type": "content_stop", "data_type": "tool"},
+        {"chunk_type": "message_stop", "data": "tool_calls"},
+        {"chunk_type": "metadata", "data": mock_event_6.usage},
+    ]
+
+    assert tru_events == exp_events
+    litellm_client.chat.completions.create.assert_called_once_with(**request)
+
+
 @pytest.mark.asyncio
 async def test_structured_output(litellm_client, model, test_output_model_cls, alist):
     messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
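Note the mocked create() returns a plain iter([...]): the LiteLLM stream() drains its response synchronously even though the method itself is an async generator. The alist fixture used above is defined elsewhere in the test suite and is not part of this diff; a plausible sketch of what it provides, a helper that collects an async iterable into an ordinary list so the events can be compared with ==:

import pytest

@pytest.fixture
def alist():
    async def collect(aiterable):
        # Drain an async generator into a list for plain assertions.
        return [item async for item in aiterable]

    return collect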
