iterative streaming #241

Open · wants to merge 3 commits into base: main
23 changes: 14 additions & 9 deletions src/strands/event_loop/event_loop.py
@@ -146,14 +146,19 @@ def event_loop_cycle(
)

try:
stop_reason, message, usage, metrics, kwargs["request_state"] = stream_messages(
model,
system_prompt,
messages,
tool_config,
callback_handler,
**kwargs,
)
            # TODO: As part of the migration to async-iterator, we will continue moving callback_handler calls up the
            #       call stack. At this point, we converted all events that were previously passed to the handler in
            #       `stream_messages` into yielded events that now have the "callback" key. To maintain backwards
            #       compatibility, we need to combine the event with kwargs before passing to the handler. We will
            #       revisit this when migrating to strongly typed events.
for event in stream_messages(model, system_prompt, messages, tool_config):
if "callback" in event:
inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}
callback_handler(**inputs)
@pgrayy (Member Author) commented on Jun 18, 2025:

@zastrowm had the following comment:

I presume that test_event_loop has enough UT coverage verifying the callbacks of events are as expected? At which point we have a lot of confidence that we maintained backwards compatibility.

This was a good call-out. We didn't actually have the best coverage on callbacks. I added a test that caught an issue with passing kwargs, which we need for backwards compatibility. I personally don't think passing kwargs is useful, so we should revisit this when we formally type events (tracking issue).

else:
stop_reason, message, usage, metrics = event["stop"]
@pgrayy (Member Author):

@zastrowm had the following comment:

So stop right now is not an event that callback gets passed. Given that we don't need backwards compatibility here, what would you think about turning this into a typed event? Could also be a follow-up PR.

Will handle in a follow up PR as part of #242

kwargs.setdefault("request_state", {})

if model_invoke_span:
tracer.end_model_invoke_span(model_invoke_span, message, usage)
break # Success! Break out of retry loop
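The consumption pattern introduced in this hunk can be sketched in isolation. The snippet below is a minimal, self-contained illustration (the generator is a stand-in for `stream_messages`, and all names are illustrative, not the library's actual API): callback events are forwarded to the handler, `kwargs` are merged in only for delta events to preserve the old call signature, and the final `"stop"` event carries the aggregate results.

```python
def fake_stream_messages():
    # Stand-in for stream_messages: yields callback events, then a stop event.
    yield {"callback": {"data": "hi", "delta": {"text": "hi"}}}
    yield {"callback": {"event": {"contentBlockStop": {}}}}
    yield {"stop": ("end_turn", {"role": "assistant"}, {}, {})}

def run_cycle(callback_handler, **kwargs):
    for event in fake_stream_messages():
        if "callback" in event:
            # Merge kwargs only for delta events, mirroring the diff above.
            inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}
            callback_handler(**inputs)
        else:
            stop_reason, message, usage, metrics = event["stop"]
    return stop_reason, message

calls = []
stop_reason, message = run_cycle(lambda **kw: calls.append(kw), request_state={})
```

Note that the second event (which has no `"delta"` key) reaches the handler without the merged `kwargs`, which is exactly the backwards-compatibility behavior the TODO comment describes.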
@@ -370,7 +375,7 @@ def _handle_tool_execution(
kwargs (Dict[str, Any]): Additional keyword arguments, including request state.

Returns:
Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
- The stop reason,
- The updated message,
- The updated event loop metrics,
80 changes: 33 additions & 47 deletions src/strands/event_loop/streaming.py
@@ -2,7 +2,7 @@

import json
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Generator, Iterable, Optional

from ..types.content import ContentBlock, Message, Messages
from ..types.models import Model
@@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message:
return message


def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]:
def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:
"""Handles the start of a content block by extracting tool usage information if any.

Args:
@@ -102,61 +102,59 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]:


def handle_content_block_delta(
event: ContentBlockDeltaEvent, state: Dict[str, Any], callback_handler: Any, **kwargs: Any
) -> Dict[str, Any]:
event: ContentBlockDeltaEvent, state: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Handles content block delta updates by appending text, tool input, or reasoning content to the state.

Args:
event: Delta event.
state: The current state of message processing.
callback_handler: Callback for processing events as they happen.
**kwargs: Additional keyword arguments to pass to the callback handler.

Returns:
Updated state with appended text or tool input.
"""
delta_content = event["delta"]

callback_event = {}

if "toolUse" in delta_content:
if "input" not in state["current_tool_use"]:
state["current_tool_use"]["input"] = ""

state["current_tool_use"]["input"] += delta_content["toolUse"]["input"]
callback_handler(delta=delta_content, current_tool_use=state["current_tool_use"], **kwargs)
callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]}
@pgrayy (Member Author):

This is a temporary format to maintain backwards compatibility. We are planning on formalizing and strongly typing the construction of our agent events for consistency and ease of use.

#242


elif "text" in delta_content:
state["text"] += delta_content["text"]
callback_handler(data=delta_content["text"], delta=delta_content, **kwargs)
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}

elif "reasoningContent" in delta_content:
if "text" in delta_content["reasoningContent"]:
if "reasoningText" not in state:
state["reasoningText"] = ""

state["reasoningText"] += delta_content["reasoningContent"]["text"]
callback_handler(
reasoningText=delta_content["reasoningContent"]["text"],
delta=delta_content,
reasoning=True,
**kwargs,
)
callback_event["callback"] = {
"reasoningText": delta_content["reasoningContent"]["text"],
"delta": delta_content,
"reasoning": True,
}

elif "signature" in delta_content["reasoningContent"]:
if "signature" not in state:
state["signature"] = ""

state["signature"] += delta_content["reasoningContent"]["signature"]
callback_handler(
reasoning_signature=delta_content["reasoningContent"]["signature"],
delta=delta_content,
reasoning=True,
**kwargs,
)
callback_event["callback"] = {
"reasoning_signature": delta_content["reasoningContent"]["signature"],
"delta": delta_content,
"reasoning": True,
}

return state
return state, callback_event
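The new return contract for the delta handler can be sketched with a stripped-down version covering only the text branch (an illustrative reduction, not the full function): instead of invoking a callback handler directly, the function mutates the state and hands back a callback event for the caller to yield.

```python
from typing import Any

def handle_text_delta(delta_content: dict, state: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    # Mutate state, and build a callback event instead of calling a handler.
    callback_event: dict[str, Any] = {}
    if "text" in delta_content:
        state["text"] += delta_content["text"]
        callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}
    return state, callback_event

state = {"text": ""}
state, event = handle_text_delta({"text": "hello"}, state)
```

Separating state updates from event emission this way is what lets `process_stream` become a generator: the caller decides when and how each callback event is surfaced.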


def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]:
def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
"""Handles the end of a content block by finalizing tool usage, text content, or reasoning content.

Args:
@@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]:
Returns:
Updated state with finalized content block.
"""
content: List[ContentBlock] = state["content"]
content: list[ContentBlock] = state["content"]

current_tool_use = state["current_tool_use"]
text = state["text"]
@@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason:
return event["stopReason"]


def handle_redact_content(event: RedactContentEvent, messages: Messages, state: Dict[str, Any]) -> None:
def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None:
"""Handles redacting content from the input or output.

Args:
Expand All @@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state:
state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}]


def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]:
def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]:
"""Extracts usage metrics from the metadata chunk.

Args:
@@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]:

def process_stream(
chunks: Iterable[StreamEvent],
callback_handler: Any,
messages: Messages,
**kwargs: Any,
) -> Tuple[StopReason, Message, Usage, Metrics, Any]:
) -> Generator[dict[str, Any], None, None]:
"""Processes the response stream from the API, constructing the final message and extracting usage metrics.

Args:
chunks: The chunks of the response stream from the model.
callback_handler: Callback for processing events as they happen.
messages: The agent's messages.
**kwargs: Additional keyword arguments that will be passed to the callback handler.
And also returned in the request_state.

Returns:
The reason for stopping, the constructed message, the usage metrics, and the updated request state.
The reason for stopping, the constructed message, and the usage metrics.
"""
stop_reason: StopReason = "end_turn"

state: Dict[str, Any] = {
state: dict[str, Any] = {
"message": {"role": "assistant", "content": []},
"text": "",
"current_tool_use": {},
@@ -285,18 +278,16 @@ def process_stream(
usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
metrics: Metrics = Metrics(latencyMs=0)

kwargs.setdefault("request_state", {})

for chunk in chunks:
# Callback handler call here allows each event to be visible to the caller
callback_handler(event=chunk)
yield {"callback": {"event": chunk}}

if "messageStart" in chunk:
state["message"] = handle_message_start(chunk["messageStart"], state["message"])
elif "contentBlockStart" in chunk:
state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"])
elif "contentBlockDelta" in chunk:
state = handle_content_block_delta(chunk["contentBlockDelta"], state, callback_handler, **kwargs)
state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state)
yield callback_event
elif "contentBlockStop" in chunk:
state = handle_content_block_stop(state)
elif "messageStop" in chunk:
@@ -306,35 +297,30 @@ def stream_messages(
elif "redactContent" in chunk:
handle_redact_content(chunk["redactContent"], messages, state)

return stop_reason, state["message"], usage, metrics, kwargs["request_state"]
yield {"stop": (stop_reason, state["message"], usage, metrics)}
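With the trailing `return` replaced by a `"stop"` yield, callers consume the whole stream as a generator. The sketch below shows that consumption shape with a stand-in generator (all event payloads here are illustrative, not taken from a real model response):

```python
from typing import Any, Generator

def fake_process_stream() -> Generator[dict[str, Any], None, None]:
    # Each streamed chunk becomes a {"callback": ...} event; the aggregate
    # results travel in a final {"stop": ...} event instead of a return value.
    yield {"callback": {"event": {"messageStart": {"role": "assistant"}}}}
    yield {"callback": {"data": "ok", "delta": {"text": "ok"}}}
    yield {"stop": ("end_turn",
                    {"role": "assistant", "content": [{"text": "ok"}]},
                    {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0},
                    {"latencyMs": 0})}

callbacks = []
for event in fake_process_stream():
    if "stop" in event:
        stop_reason, message, usage, metrics = event["stop"]
    else:
        callbacks.append(event["callback"])
```

Carrying the results in a terminal event rather than a generator `return` value keeps the consumer to a single `for` loop, with no need to catch `StopIteration` to recover the return value.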


def stream_messages(
model: Model,
system_prompt: Optional[str],
messages: Messages,
tool_config: Optional[ToolConfig],
callback_handler: Any,
**kwargs: Any,
) -> Tuple[StopReason, Message, Usage, Metrics, Any]:
) -> Generator[dict[str, Any], None, None]:
"""Streams messages to the model and processes the response.

Args:
model: Model provider.
system_prompt: The system prompt to send.
messages: List of messages to send.
tool_config: Configuration for the tools to use.
callback_handler: Callback for processing events as they happen.
**kwargs: Additional keyword arguments that will be passed to the callback handler.
And also returned in the request_state.

Returns:
The reason for stopping, the final message, the usage metrics, and updated request state.
The reason for stopping, the final message, and the usage metrics.
"""
logger.debug("model=<%s> | streaming messages", model)

messages = remove_blank_messages_content_text(messages)
tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None

chunks = model.converse(messages, tool_specs, system_prompt)
return process_stream(chunks, callback_handler, messages, **kwargs)
yield from process_stream(chunks, messages)
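The `yield from` at the end of `stream_messages` is plain generator delegation: every event from `process_stream` is forwarded unchanged to the caller. A minimal sketch of that mechanic (stand-in functions, not the library's API):

```python
def inner(chunks):
    # Stand-in for process_stream: one callback event per chunk, then a stop.
    for chunk in chunks:
        yield {"callback": {"event": chunk}}
    yield {"stop": "done"}

def outer(chunks):
    # Stand-in for stream_messages: preprocessing would happen here
    # (e.g. removing blank messages), then every inner event is forwarded.
    yield from inner(chunks)

out = list(outer(["c1", "c2"]))
```

Because the outer function delegates rather than iterating and re-yielding, it adds no per-event overhead and stays transparent to whatever event types the inner generator produces later.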
100 changes: 100 additions & 0 deletions tests/strands/event_loop/test_event_loop.py
@@ -733,3 +733,103 @@ def test_event_loop_cycle_with_parent_span(
mock_tracer.start_event_loop_cycle_span.assert_called_once_with(
event_loop_kwargs=unittest.mock.ANY, parent_span=parent_span, messages=messages
)


def test_event_loop_cycle_callback(
model,
model_id,
system_prompt,
messages,
tool_config,
callback_handler,
tool_handler,
tool_execution_handler,
):
model.converse.return_value = [
{"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}},
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}},
{"contentBlockStop": {}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}},
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}},
{"contentBlockStop": {}},
{"contentBlockStart": {"start": {}}},
{"contentBlockDelta": {"delta": {"text": "value"}}},
{"contentBlockStop": {}},
]

strands.event_loop.event_loop.event_loop_cycle(
model=model,
model_id=model_id,
system_prompt=system_prompt,
messages=messages,
tool_config=tool_config,
callback_handler=callback_handler,
tool_handler=tool_handler,
tool_execution_handler=tool_execution_handler,
)

callback_handler.assert_has_calls(
[
call(start=True),
call(start_event_loop=True),
call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}),
call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}),
call(
delta={"toolUse": {"input": '{"value"}'}},
current_tool_use={"toolUseId": "123", "name": "test", "input": {}},
model_id="m1",
event_loop_cycle_id=unittest.mock.ANY,
request_state={},
event_loop_cycle_trace=unittest.mock.ANY,
event_loop_cycle_span=None,
),
call(event={"contentBlockStop": {}}),
call(event={"contentBlockStart": {"start": {}}}),
call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}),
call(
reasoningText="value",
delta={"reasoningContent": {"text": "value"}},
reasoning=True,
model_id="m1",
event_loop_cycle_id=unittest.mock.ANY,
request_state={},
event_loop_cycle_trace=unittest.mock.ANY,
event_loop_cycle_span=None,
),
call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}),
call(
reasoning_signature="value",
delta={"reasoningContent": {"signature": "value"}},
reasoning=True,
model_id="m1",
event_loop_cycle_id=unittest.mock.ANY,
request_state={},
event_loop_cycle_trace=unittest.mock.ANY,
event_loop_cycle_span=None,
),
call(event={"contentBlockStop": {}}),
call(event={"contentBlockStart": {"start": {}}}),
call(event={"contentBlockDelta": {"delta": {"text": "value"}}}),
call(
@pgrayy (Member Author):

There was a comment about duplicate events yielded while streaming. This is an example (813-814) and will be addressed as part of #242.

data="value",
delta={"text": "value"},
model_id="m1",
event_loop_cycle_id=unittest.mock.ANY,
request_state={},
event_loop_cycle_trace=unittest.mock.ANY,
event_loop_cycle_span=None,
),
call(event={"contentBlockStop": {}}),
call(
message={
"role": "assistant",
"content": [
{"toolUse": {"toolUseId": "123", "name": "test", "input": {}}},
{"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}},
{"text": "value"},
],
},
),
],
)