Commit 2baacdf

[REFACTOR] Unify Model Interface Around Single Entry Point (model.stream) (#400)

1 parent ea81326 · commit 2baacdf

24 files changed: +646 additions, -497 deletions

src/strands/event_loop/streaming.py

Lines changed: 2 additions & 1 deletion

@@ -321,6 +321,7 @@ async def stream_messages(
 
     messages = remove_blank_messages_content_text(messages)
 
-    chunks = model.converse(messages, tool_specs if tool_specs else None, system_prompt)
+    chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt)
+
     async for event in process_stream(chunks, messages):
         yield event
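
The event loop no longer pre-formats a provider request and hands it to `model.converse`; it passes `messages`, `tool_specs`, and `system_prompt` straight to `model.stream` and consumes standardized chunks. A minimal sketch of that calling pattern (the `FakeModel` class and the chunk shapes below are illustrative stand-ins, not library code):

```python
import asyncio
from typing import Any, AsyncGenerator, Optional


class FakeModel:
    """Stand-in for any provider that implements the unified stream() entry point."""

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tool_specs: Optional[list[dict[str, Any]]] = None,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[dict[str, Any], None]:
        # A real provider would format a request and call its SDK here;
        # this stub just yields chunks in an illustrative standardized shape.
        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockDelta": {"delta": {"text": "hello"}}}
        yield {"messageStop": {"stopReason": "end_turn"}}


async def main() -> None:
    model = FakeModel()
    # The event loop now streams directly instead of calling model.converse.
    async for event in model.stream(
        messages=[{"role": "user", "content": [{"text": "hi"}]}],
        tool_specs=None,
        system_prompt="You are helpful.",
    ):
        print(event)


asyncio.run(main())
```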

src/strands/models/anthropic.py

Lines changed: 20 additions & 10 deletions

@@ -191,7 +191,6 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
 
         return formatted_messages
 
-    @override
     def format_request(
         self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
     ) -> dict[str, Any]:
@@ -225,7 +224,6 @@ def format_request(
             **(self.config.get("params") or {}),
         }
 
-    @override
     def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         """Format the Anthropic response events into standardized message chunks.
 
@@ -344,27 +342,37 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         raise RuntimeError(f"event_type=<{event['type']} | unknown type")
 
     @override
-    async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
-        """Send the request to the Anthropic model and get the streaming response.
+    async def stream(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Stream conversation with the Anthropic model.
 
         Args:
-            request: The formatted request to send to the Anthropic model.
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
 
-        Returns:
-            An iterable of response events from the Anthropic model.
+        Yields:
+            Formatted message chunks from the model.
 
         Raises:
            ContextWindowOverflowException: If the input exceeds the model's context window.
            ModelThrottledException: If the request is throttled by Anthropic.
        """
+        logger.debug("formatting request")
+        request = self.format_request(messages, tool_specs, system_prompt)
+        logger.debug("formatted request=<%s>", request)
+
+        logger.debug("invoking model")
        try:
            async with self.client.messages.stream(**request) as stream:
+                logger.debug("got response from model")
                async for event in stream:
                    if event.type in AnthropicModel.EVENT_TYPES:
-                        yield event.model_dump()
+                        yield self.format_chunk(event.model_dump())
 
                usage = event.message.usage  # type: ignore
-                yield {"type": "metadata", "usage": usage.model_dump()}
+                yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})
 
        except anthropic.RateLimitError as error:
            raise ModelThrottledException(str(error)) from error
@@ -375,6 +383,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
 
            raise error
 
+        logger.debug("finished streaming response from model")
+
    @override
    async def structured_output(
        self, output_model: Type[T], prompt: Messages
@@ -390,7 +400,7 @@ async def structured_output(
        """
        tool_spec = convert_pydantic_to_tool_spec(output_model)
 
-        response = self.converse(messages=prompt, tool_specs=[tool_spec])
+        response = self.stream(messages=prompt, tool_specs=[tool_spec])
        async for event in process_stream(response, prompt):
            yield event
 
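
Across providers the commit moves request formatting and chunk normalization inside `stream` itself: `format_request` is called at the top of the method, and every raw SDK event passes through `format_chunk` before it is yielded. A rough, self-contained sketch of that provider-side shape (the `StubSDKClient`, its event types, and the chunk layout are hypothetical; only the method names mirror the diff):

```python
import asyncio
from typing import Any, AsyncGenerator, Optional


class StubSDKClient:
    """Pretend SDK client that emits raw, provider-specific events."""

    async def stream_events(self, **request: Any) -> AsyncGenerator[dict[str, Any], None]:
        yield {"type": "message_start"}
        yield {"type": "text_delta", "text": "hi"}
        yield {"type": "message_stop"}


class SketchProvider:
    """Illustrative provider following the unified interface of this commit."""

    def __init__(self) -> None:
        self.client = StubSDKClient()

    def format_request(
        self,
        messages: list[dict[str, Any]],
        tool_specs: Optional[list[dict[str, Any]]],
        system_prompt: Optional[str],
    ) -> dict[str, Any]:
        # Request formatting now happens inside stream(), not in the caller.
        return {"messages": messages, "tools": tool_specs or [], "system": system_prompt}

    def format_chunk(self, event: dict[str, Any]) -> dict[str, Any]:
        # Raw SDK events are normalized before they leave the provider.
        return {"standardized": event}

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tool_specs: Optional[list[dict[str, Any]]] = None,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[dict[str, Any], None]:
        request = self.format_request(messages, tool_specs, system_prompt)
        async for event in self.client.stream_events(**request):
            yield self.format_chunk(event)


async def demo() -> None:
    provider = SketchProvider()
    async for chunk in provider.stream([{"role": "user", "content": [{"text": "hello"}]}]):
        print(chunk)


asyncio.run(demo())
```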

src/strands/models/bedrock.py

Lines changed: 19 additions & 9 deletions

@@ -162,7 +162,6 @@ def get_config(self) -> BedrockConfig:
         """
         return self.config
 
-    @override
     def format_request(
         self,
         messages: Messages,
@@ -246,7 +245,6 @@ def format_request(
             ),
         }
 
-    @override
     def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         """Format the Bedrock response events into standardized message chunks.
 
@@ -315,25 +313,35 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
         return events
 
     @override
-    async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
-        """Send the request to the Bedrock model and get the response.
+    async def stream(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Stream conversation with the Bedrock model.
 
         This method calls either the Bedrock converse_stream API or the converse API
         based on the streaming parameter in the configuration.
 
         Args:
-            request: The formatted request to send to the Bedrock model
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
 
-        Returns:
-            An iterable of response events from the Bedrock model
+        Yields:
+            Formatted message chunks from the model.
 
         Raises:
            ContextWindowOverflowException: If the input exceeds the model's context window.
            ModelThrottledException: If the model service is throttling requests.
        """
+        logger.debug("formatting request")
+        request = self.format_request(messages, tool_specs, system_prompt)
+        logger.debug("formatted request=<%s>", request)
+
+        logger.debug("invoking model")
        streaming = self.config.get("streaming", True)
 
        try:
+            logger.debug("got response from model")
            if streaming:
                # Streaming implementation
                response = self.client.converse_stream(**request)
@@ -347,7 +355,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, N
                    if self._has_blocked_guardrail(guardrail_data):
                        for event in self._generate_redaction_events():
                            yield event
-                    yield chunk
+                    yield self.format_chunk(chunk)
            else:
                # Non-streaming implementation
                response = self.client.converse(**request)
@@ -406,6 +414,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, N
            # Otherwise raise the error
            raise e
 
+        logger.debug("finished streaming response from model")
+
    def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
        """Convert a non-streaming response to the streaming format.
 
@@ -531,7 +541,7 @@ async def structured_output(
        """
        tool_spec = convert_pydantic_to_tool_spec(output_model)
 
-        response = self.converse(messages=prompt, tool_specs=[tool_spec])
+        response = self.stream(messages=prompt, tool_specs=[tool_spec])
        async for event in process_stream(response, prompt):
            yield event
 
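
The Bedrock provider keeps its `streaming` config flag: when it is `False`, `stream` calls the non-streaming converse API and feeds the single response through `_convert_non_streaming_to_streaming`, so callers still receive the same chunk sequence. A simplified standalone sketch of that conversion idea (the helper below and the event keys are illustrative, not the library's implementation):

```python
from typing import Any, Iterable


def non_streaming_to_events(response: dict[str, Any]) -> Iterable[dict[str, Any]]:
    """Yield start / delta / stop events for one complete (non-streaming) response."""
    yield {"messageStart": {"role": "assistant"}}
    for block in response["output"]["message"]["content"]:
        if "text" in block:
            yield {"contentBlockDelta": {"delta": {"text": block["text"]}}}
            yield {"contentBlockStop": {}}
    yield {"messageStop": {"stopReason": response.get("stopReason", "end_turn")}}


full_response = {"output": {"message": {"content": [{"text": "42"}]}}, "stopReason": "end_turn"}
for event in non_streaming_to_events(full_response):
    print(event)
```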

src/strands/models/litellm.py

Lines changed: 38 additions & 20 deletions

@@ -14,6 +14,8 @@
 
 from ..types.content import ContentBlock, Messages
 from ..types.models.openai import OpenAIModel
+from ..types.streaming import StreamEvent
+from ..types.tools import ToolSpec
 
 logger = logging.getLogger(__name__)
 
@@ -104,19 +106,29 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
         return super().format_request_message_content(content)
 
     @override
-    async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
-        """Send the request to the LiteLLM model and get the streaming response.
+    async def stream(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Stream conversation with the LiteLLM model.
 
         Args:
-            request: The formatted request to send to the LiteLLM model.
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
 
-        Returns:
-            An iterable of response events from the LiteLLM model.
+        Yields:
+            Formatted message chunks from the model.
        """
+        logger.debug("formatting request")
+        request = self.format_request(messages, tool_specs, system_prompt)
+        logger.debug("formatted request=<%s>", request)
+
+        logger.debug("invoking model")
        response = self.client.chat.completions.create(**request)
 
-        yield {"chunk_type": "message_start"}
-        yield {"chunk_type": "content_start", "data_type": "text"}
+        logger.debug("got response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+        yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
 
        tool_calls: dict[int, list[Any]] = {}
 
@@ -127,38 +139,44 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
            choice = event.choices[0]
 
            if choice.delta.content:
-                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+                yield self.format_chunk(
+                    {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+                )
 
            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
-                yield {
-                    "chunk_type": "content_delta",
-                    "data_type": "reasoning_content",
-                    "data": choice.delta.reasoning_content,
-                }
+                yield self.format_chunk(
+                    {
+                        "chunk_type": "content_delta",
+                        "data_type": "reasoning_content",
+                        "data": choice.delta.reasoning_content,
+                    }
+                )
 
            for tool_call in choice.delta.tool_calls or []:
                tool_calls.setdefault(tool_call.index, []).append(tool_call)
 
            if choice.finish_reason:
                break
 
-        yield {"chunk_type": "content_stop", "data_type": "text"}
+        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
 
        for tool_deltas in tool_calls.values():
-            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
 
            for tool_delta in tool_deltas:
-                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
+                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
 
-            yield {"chunk_type": "content_stop", "data_type": "tool"}
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
 
-        yield {"chunk_type": "message_stop", "data": choice.finish_reason}
+        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
 
        # Skip remaining events as we don't have use for anything except the final usage payload
        for event in response:
            _ = event
 
-        yield {"chunk_type": "metadata", "data": event.usage}
+        yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+
+        logger.debug("finished streaming response from model")
 
    @override
    async def structured_output(
@@ -178,7 +196,7 @@ async def structured_output(
        # completions() has a method `create()` which wraps the real completion API of Litellm
        response = self.client.chat.completions.create(
            model=self.get_config()["model_id"],
-            messages=super().format_request(prompt)["messages"],
+            messages=self.format_request(prompt)["messages"],
            response_format=output_model,
        )
 
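
The LiteLLM provider now builds intermediate `chunk_type` events from the OpenAI-style deltas and runs each one through `format_chunk` before yielding, so consumers only ever see standardized stream events. A hedged sketch of what such a normalization step could look like (this `sketch_format_chunk` and its output shapes are assumptions for illustration, not the library's `format_chunk`):

```python
from typing import Any


def sketch_format_chunk(event: dict[str, Any]) -> dict[str, Any]:
    """Map an intermediate chunk_type event to an illustrative standardized shape."""
    chunk_type = event["chunk_type"]
    if chunk_type == "message_start":
        return {"messageStart": {"role": "assistant"}}
    if chunk_type == "content_delta" and event.get("data_type") == "text":
        return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
    if chunk_type == "message_stop":
        return {"messageStop": {"stopReason": event.get("data") or "end_turn"}}
    return {"unhandled": event}  # the real providers raise on unknown chunk types


raw_events = [
    {"chunk_type": "message_start"},
    {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"},
    {"chunk_type": "message_stop", "data": "stop"},
]
for raw in raw_events:
    print(sketch_format_chunk(raw))
```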

src/strands/models/llamaapi.py

Lines changed: 29 additions & 17 deletions

@@ -202,7 +202,6 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s
 
         return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
 
-    @override
     def format_request(
         self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
     ) -> dict[str, Any]:
@@ -249,7 +248,6 @@ def format_request(
 
         return request
 
-    @override
     def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         """Format the Llama API model response events into standardized message chunks.
 
@@ -324,24 +322,34 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
 
     @override
-    async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
-        """Send the request to the model and get a streaming response.
+    async def stream(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Stream conversation with the LlamaAPI model.
 
         Args:
-            request: The formatted request to send to the model.
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
 
-        Returns:
-            The model's response.
+        Yields:
+            Formatted message chunks from the model.
 
         Raises:
            ModelThrottledException: When the model service is throttling requests from the client.
        """
+        logger.debug("formatting request")
+        request = self.format_request(messages, tool_specs, system_prompt)
+        logger.debug("formatted request=<%s>", request)
+
+        logger.debug("invoking model")
        try:
            response = self.client.chat.completions.create(**request)
        except llama_api_client.RateLimitError as e:
            raise ModelThrottledException(str(e)) from e
 
-        yield {"chunk_type": "message_start"}
+        logger.debug("got response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
 
        stop_reason = None
        tool_calls: dict[Any, list[Any]] = {}
@@ -350,9 +358,11 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
        metrics_event = None
        for chunk in response:
            if chunk.event.event_type == "start":
-                yield {"chunk_type": "content_start", "data_type": "text"}
+                yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
            elif chunk.event.event_type in ["progress", "complete"] and chunk.event.delta.type == "text":
-                yield {"chunk_type": "content_delta", "data_type": "text", "data": chunk.event.delta.text}
+                yield self.format_chunk(
+                    {"chunk_type": "content_delta", "data_type": "text", "data": chunk.event.delta.text}
+                )
            else:
                if chunk.event.delta.type == "tool_call":
                    if chunk.event.delta.id:
@@ -364,29 +374,31 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
                elif chunk.event.event_type == "metrics":
                    metrics_event = chunk.event.metrics
                else:
-                    yield chunk
+                    yield self.format_chunk(chunk)
 
            if stop_reason is None:
                stop_reason = chunk.event.stop_reason
 
        # stopped generation
        if stop_reason:
-            yield {"chunk_type": "content_stop", "data_type": "text"}
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
 
        for tool_deltas in tool_calls.values():
            tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
-            yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start}
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_start})
 
            for tool_delta in tool_deltas:
-                yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
+                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
 
-            yield {"chunk_type": "content_stop", "data_type": "tool"}
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
 
-        yield {"chunk_type": "message_stop", "data": stop_reason}
+        yield self.format_chunk({"chunk_type": "message_stop", "data": stop_reason})
 
        # we may have a metrics event here
        if metrics_event:
-            yield {"chunk_type": "metadata", "data": metrics_event}
+            yield self.format_chunk({"chunk_type": "metadata", "data": metrics_event})
+
+        logger.debug("finished streaming response from model")
 
    @override
    def structured_output(
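
Both the LiteLLM and LlamaAPI `stream` methods buffer tool-call deltas in a dict while the provider is still streaming, then replay each buffered call as `content_start` / `content_delta` / `content_stop` chunks once generation stops. A standalone sketch of that buffer-and-replay pattern (the `FakeToolDelta` type and the exact replay loop below are illustrative assumptions, not either provider's real types):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeToolDelta:
    """Illustrative stand-in for a provider's tool-call delta object."""

    index: int
    fragment: str


def replay_tool_calls(deltas: list[FakeToolDelta]) -> list[dict[str, Any]]:
    # Buffer deltas per tool-call index while the provider is still streaming...
    tool_calls: dict[int, list[FakeToolDelta]] = {}
    for delta in deltas:
        tool_calls.setdefault(delta.index, []).append(delta)

    # ...then replay each buffered call as start / delta / stop chunks.
    chunks: list[dict[str, Any]] = []
    for buffered in tool_calls.values():
        chunks.append({"chunk_type": "content_start", "data_type": "tool", "data": buffered[0]})
        for delta in buffered:
            chunks.append({"chunk_type": "content_delta", "data_type": "tool", "data": delta})
        chunks.append({"chunk_type": "content_stop", "data_type": "tool"})
    return chunks


for chunk in replay_tool_calls([FakeToolDelta(0, '{"city": '), FakeToolDelta(0, '"Paris"}')]):
    print(chunk)
```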
