[REFACTOR] Unify Model Interface Around Single Entry Point (model.stream) #400

Merged
merged 5 commits into from
Jul 10, 2025
3 changes: 2 additions & 1 deletion src/strands/event_loop/streaming.py
@@ -321,6 +321,7 @@ async def stream_messages(

messages = remove_blank_messages_content_text(messages)

chunks = model.converse(messages, tool_specs if tool_specs else None, system_prompt)
chunks = model.stream(messages, tool_specs if tool_specs else None, system_prompt)

async for event in process_stream(chunks, messages):
yield event
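
For orientation, a minimal caller-side sketch of the unified entry point this hunk switches to. The print_stream helper and the example message are illustrative only and not part of the diff; model stands for any provider implementing the new interface.

import asyncio

async def print_stream(model, messages, tool_specs=None, system_prompt=None):
    # The event loop no longer builds a provider-specific request; it hands the
    # conversation to model.stream() and receives already-formatted chunks back.
    async for event in model.stream(messages, tool_specs, system_prompt):
        print(event)

# asyncio.run(print_stream(model, [{"role": "user", "content": [{"text": "Hello"}]}]))
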
30 changes: 20 additions & 10 deletions src/strands/models/anthropic.py
@@ -191,7 +191,6 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:

return formatted_messages

@override
def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> dict[str, Any]:
@@ -225,7 +224,6 @@ def format_request(
**(self.config.get("params") or {}),
}

@override
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the Anthropic response events into standardized message chunks.

@@ -344,27 +342,37 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
raise RuntimeError(f"event_type=<{event['type']} | unknown type")

@override
async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Send the request to the Anthropic model and get the streaming response.
async def stream(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Anthropic model.

Args:
request: The formatted request to send to the Anthropic model.
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
An iterable of response events from the Anthropic model.
Yields:
Formatted message chunks from the model.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled by Anthropic.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("formatted request=<%s>", request)

logger.debug("invoking model")
try:
async with self.client.messages.stream(**request) as stream:
logger.debug("got response from model")
async for event in stream:
if event.type in AnthropicModel.EVENT_TYPES:
yield event.model_dump()
yield self.format_chunk(event.model_dump())

usage = event.message.usage # type: ignore
yield {"type": "metadata", "usage": usage.model_dump()}
yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})

except anthropic.RateLimitError as error:
raise ModelThrottledException(str(error)) from error
@@ -375,6 +383,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:

raise error

logger.debug("finished streaming response from model")

@override
async def structured_output(
self, output_model: Type[T], prompt: Messages
@@ -390,7 +400,7 @@ async def structured_output(
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.converse(messages=prompt, tool_specs=[tool_spec])
response = self.stream(messages=prompt, tool_specs=[tool_spec])
async for event in process_stream(response, prompt):
yield event

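
With @override removed from format_request and format_chunk above, those methods become provider-internal helpers rather than interface requirements; only stream (and structured_output) face the caller. A rough skeleton of that shape, with every name below invented for illustration:

from typing import Any, AsyncGenerator, Optional

class HypotheticalProvider:
    """Toy provider showing the shape each model class now follows."""

    def format_request(
        self, messages, tool_specs: Optional[list] = None, system_prompt: Optional[str] = None
    ) -> dict[str, Any]:
        # Provider-specific request shape; internal helper, not an interface method.
        return {"messages": messages, "tools": tool_specs or [], "system": system_prompt}

    def format_chunk(self, raw_event: dict[str, Any]) -> dict[str, Any]:
        # Map the provider's raw event into the standardized chunk format.
        return raw_event

    async def _invoke_sdk(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
        # Stand-in for the real provider client call.
        yield {"type": "message_start"}
        yield {"type": "message_stop"}

    async def stream(
        self, messages, tool_specs: Optional[list] = None, system_prompt: Optional[str] = None
    ) -> AsyncGenerator[dict[str, Any], None]:
        # Single entry point: format internally, invoke the SDK, format each chunk.
        request = self.format_request(messages, tool_specs, system_prompt)
        async for raw_event in self._invoke_sdk(request):
            yield self.format_chunk(raw_event)
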
28 changes: 19 additions & 9 deletions src/strands/models/bedrock.py
@@ -162,7 +162,6 @@ def get_config(self) -> BedrockConfig:
"""
return self.config

@override
def format_request(
self,
messages: Messages,
@@ -246,7 +245,6 @@ def format_request(
),
}

@override
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the Bedrock response events into standardized message chunks.

@@ -315,25 +313,35 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
return events

@override
async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
"""Send the request to the Bedrock model and get the response.
async def stream(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the Bedrock model.

This method calls either the Bedrock converse_stream API or the converse API
based on the streaming parameter in the configuration.

Args:
request: The formatted request to send to the Bedrock model
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
An iterable of response events from the Bedrock model
Yields:
Formatted message chunks from the model.

Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the model service is throttling requests.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("formatted request=<%s>", request)

logger.debug("invoking model")
streaming = self.config.get("streaming", True)

try:
logger.debug("got response from model")
if streaming:
# Streaming implementation
response = self.client.converse_stream(**request)
@@ -347,7 +355,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
if self._has_blocked_guardrail(guardrail_data):
for event in self._generate_redaction_events():
yield event
yield chunk
yield self.format_chunk(chunk)
else:
# Non-streaming implementation
response = self.client.converse(**request)
@@ -406,6 +414,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
# Otherwise raise the error
raise e

logger.debug("finished streaming response from model")

def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.

@@ -531,7 +541,7 @@ async def structured_output(
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)

response = self.converse(messages=prompt, tool_specs=[tool_spec])
response = self.stream(messages=prompt, tool_specs=[tool_spec])
async for event in process_stream(response, prompt):
yield event

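
As the docstring above notes, Bedrock serves either converse_stream or converse depending on the streaming config flag, and a non-streaming response is converted into streaming-format events internally, so consumer code is identical either way. A hedged sketch of that consumer (the contentBlockDelta access assumes the standardized chunk layout, and the constructor keyword is assumed to land in BedrockConfig's "streaming" key):

import asyncio

async def collect_text(model, messages):
    # Same code whether the provider streams natively or converts a
    # non-streaming converse() response into streaming-format events.
    parts = []
    async for event in model.stream(messages):
        delta = event.get("contentBlockDelta", {}).get("delta", {})  # assumed chunk shape
        if "text" in delta:
            parts.append(delta["text"])
    return "".join(parts)

# model = BedrockModel(model_id="...", streaming=False)  # assumed constructor signature
# print(asyncio.run(collect_text(model, [{"role": "user", "content": [{"text": "Hello"}]}])))
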
58 changes: 38 additions & 20 deletions src/strands/models/litellm.py
@@ -14,6 +14,8 @@

from ..types.content import ContentBlock, Messages
from ..types.models.openai import OpenAIModel
from ..types.streaming import StreamEvent
from ..types.tools import ToolSpec

logger = logging.getLogger(__name__)

@@ -104,19 +106,29 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
return super().format_request_message_content(content)

@override
async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Send the request to the LiteLLM model and get the streaming response.
async def stream(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the LiteLLM model.

Args:
request: The formatted request to send to the LiteLLM model.
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
An iterable of response events from the LiteLLM model.
Yields:
Formatted message chunks from the model.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("formatted request=<%s>", request)

logger.debug("invoking model")
response = self.client.chat.completions.create(**request)

yield {"chunk_type": "message_start"}
yield {"chunk_type": "content_start", "data_type": "text"}
logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})

tool_calls: dict[int, list[Any]] = {}

@@ -127,38 +139,44 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
choice = event.choices[0]

if choice.delta.content:
yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
)

if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
yield {
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data": choice.delta.reasoning_content,
}
yield self.format_chunk(
{
"chunk_type": "content_delta",
"data_type": "reasoning_content",
"data": choice.delta.reasoning_content,
}
)

for tool_call in choice.delta.tool_calls or []:
tool_calls.setdefault(tool_call.index, []).append(tool_call)

if choice.finish_reason:
break

yield {"chunk_type": "content_stop", "data_type": "text"}
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

for tool_deltas in tool_calls.values():
yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})

for tool_delta in tool_deltas:
yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})

yield {"chunk_type": "content_stop", "data_type": "tool"}
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

yield {"chunk_type": "message_stop", "data": choice.finish_reason}
yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})

# Skip remaining events as we don't have use for anything except the final usage payload
for event in response:
_ = event

yield {"chunk_type": "metadata", "data": event.usage}
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})

logger.debug("finished streaming response from model")

@override
async def structured_output(
@@ -178,7 +196,7 @@ async def structured_output(
# completions() has a method `create()` which wraps the real completion API of Litellm
response = self.client.chat.completions.create(
model=self.get_config()["model_id"],
messages=super().format_request(prompt)["messages"],
messages=self.format_request(prompt)["messages"],
response_format=output_model,
)

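
The stream implementation above buffers partial tool_call deltas by index and only emits them after the text portion finishes. A small standalone illustration of that accumulation pattern (the FakeDelta objects are invented stand-ins for the SDK's delta objects):

from dataclasses import dataclass

@dataclass
class FakeDelta:
    index: int
    arguments: str

deltas = [
    FakeDelta(0, '{"city": "Par'),
    FakeDelta(1, '{"unit": "C"}'),
    FakeDelta(0, 'is"}'),
]

# Group deltas by tool-call index as they arrive, then flush each group in order.
tool_calls: dict[int, list[FakeDelta]] = {}
for delta in deltas:
    tool_calls.setdefault(delta.index, []).append(delta)

for index, parts in tool_calls.items():
    # The first part opens the tool-use block; the rest extend its arguments.
    print(index, "".join(part.arguments for part in parts))
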
46 changes: 29 additions & 17 deletions src/strands/models/llamaapi.py
@@ -202,7 +202,6 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:

return [message for message in formatted_messages if message["content"] or "tool_calls" in message]

@override
def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> dict[str, Any]:
@@ -249,7 +248,6 @@ def format_request(

return request

@override
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Format the Llama API model response events into standardized message chunks.

@@ -324,24 +322,34 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")

@override
async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
"""Send the request to the model and get a streaming response.
async def stream(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with the LlamaAPI model.

Args:
request: The formatted request to send to the model.
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.

Returns:
The model's response.
Yields:
Formatted message chunks from the model.

Raises:
ModelThrottledException: When the model service is throttling requests from the client.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt)
logger.debug("formatted request=<%s>", request)

logger.debug("invoking model")
try:
response = self.client.chat.completions.create(**request)
except llama_api_client.RateLimitError as e:
raise ModelThrottledException(str(e)) from e

yield {"chunk_type": "message_start"}
logger.debug("got response from model")
yield self.format_chunk({"chunk_type": "message_start"})

stop_reason = None
tool_calls: dict[Any, list[Any]] = {}
@@ -350,9 +358,11 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
metrics_event = None
for chunk in response:
if chunk.event.event_type == "start":
yield {"chunk_type": "content_start", "data_type": "text"}
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
elif chunk.event.event_type in ["progress", "complete"] and chunk.event.delta.type == "text":
yield {"chunk_type": "content_delta", "data_type": "text", "data": chunk.event.delta.text}
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": chunk.event.delta.text}
)
else:
if chunk.event.delta.type == "tool_call":
if chunk.event.delta.id:
@@ -364,29 +374,31 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
elif chunk.event.event_type == "metrics":
metrics_event = chunk.event.metrics
else:
yield chunk
yield self.format_chunk(chunk)

if stop_reason is None:
stop_reason = chunk.event.stop_reason

# stopped generation
if stop_reason:
yield {"chunk_type": "content_stop", "data_type": "text"}
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})

for tool_deltas in tool_calls.values():
tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start}
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_start})

for tool_delta in tool_deltas:
yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})

yield {"chunk_type": "content_stop", "data_type": "tool"}
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})

yield {"chunk_type": "message_stop", "data": stop_reason}
yield self.format_chunk({"chunk_type": "message_stop", "data": stop_reason})

# we may have a metrics event here
if metrics_event:
yield {"chunk_type": "metadata", "data": metrics_event}
yield self.format_chunk({"chunk_type": "metadata", "data": metrics_event})

logger.debug("finished streaming response from model")

@override
def structured_output(
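
Like the other providers in this PR, the LlamaAPI implementation above maps the SDK's rate-limit error onto ModelThrottledException. A hedged sketch of a caller-side retry wrapper (the import path and the backoff policy are assumptions, not library behavior):

import asyncio

from strands.types.exceptions import ModelThrottledException  # assumed import path

async def stream_with_retry(model, messages, tool_specs=None, system_prompt=None, attempts=3):
    for attempt in range(attempts):
        yielded = False
        try:
            async for event in model.stream(messages, tool_specs, system_prompt):
                yielded = True
                yield event
            return
        except ModelThrottledException:
            # Only safe to retry if the throttle hit before any event was emitted.
            if yielded or attempt == attempts - 1:
                raise
            await asyncio.sleep(2 ** attempt)  # simple exponential backoff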