-
Notifications
You must be signed in to change notification settings - Fork 151
iterative streaming #241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
iterative streaming #241
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -146,14 +146,19 @@ def event_loop_cycle( | |
) | ||
|
||
try: | ||
stop_reason, message, usage, metrics, kwargs["request_state"] = stream_messages( | ||
model, | ||
system_prompt, | ||
messages, | ||
tool_config, | ||
callback_handler, | ||
**kwargs, | ||
) | ||
# TODO: As part of the migration to async-iterator, we will continue moving callback_handler calls up the | ||
# call stack. At this point, we converted all events that were previously passed to the handler in | ||
# `stream_messages` into yielded events that now have the "callback" key. To maintain backwards | ||
# compatability, we need to combine the event with kwargs before passing to the handler. This we will | ||
# revisit when migrating to strongly typed events. | ||
for event in stream_messages(model, system_prompt, messages, tool_config): | ||
if "callback" in event: | ||
inputs = {**event["callback"], **(kwargs if "delta" in event["callback"] else {})} | ||
callback_handler(**inputs) | ||
else: | ||
stop_reason, message, usage, metrics = event["stop"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
kwargs.setdefault("request_state", {}) | ||
|
||
if model_invoke_span: | ||
tracer.end_model_invoke_span(model_invoke_span, message, usage) | ||
break # Success! Break out of retry loop | ||
|
@@ -370,7 +375,7 @@ def _handle_tool_execution( | |
kwargs (Dict[str, Any]): Additional keyword arguments, including request state. | ||
|
||
Returns: | ||
Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: | ||
Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: | ||
- The stop reason, | ||
- The updated message, | ||
- The updated event loop metrics, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
import json | ||
import logging | ||
from typing import Any, Dict, Iterable, List, Optional, Tuple | ||
from typing import Any, Generator, Iterable, Optional | ||
|
||
from ..types.content import ContentBlock, Message, Messages | ||
from ..types.models import Model | ||
|
@@ -80,7 +80,7 @@ def handle_message_start(event: MessageStartEvent, message: Message) -> Message: | |
return message | ||
|
||
|
||
def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: | ||
def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: | ||
"""Handles the start of a content block by extracting tool usage information if any. | ||
|
||
Args: | ||
|
@@ -102,61 +102,59 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> Dict[str, Any]: | |
|
||
|
||
def handle_content_block_delta( | ||
event: ContentBlockDeltaEvent, state: Dict[str, Any], callback_handler: Any, **kwargs: Any | ||
) -> Dict[str, Any]: | ||
event: ContentBlockDeltaEvent, state: dict[str, Any] | ||
) -> tuple[dict[str, Any], dict[str, Any]]: | ||
"""Handles content block delta updates by appending text, tool input, or reasoning content to the state. | ||
|
||
Args: | ||
event: Delta event. | ||
state: The current state of message processing. | ||
callback_handler: Callback for processing events as they happen. | ||
**kwargs: Additional keyword arguments to pass to the callback handler. | ||
|
||
Returns: | ||
Updated state with appended text or tool input. | ||
""" | ||
delta_content = event["delta"] | ||
|
||
callback_event = {} | ||
|
||
if "toolUse" in delta_content: | ||
if "input" not in state["current_tool_use"]: | ||
state["current_tool_use"]["input"] = "" | ||
|
||
state["current_tool_use"]["input"] += delta_content["toolUse"]["input"] | ||
callback_handler(delta=delta_content, current_tool_use=state["current_tool_use"], **kwargs) | ||
callback_event["callback"] = {"delta": delta_content, "current_tool_use": state["current_tool_use"]} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a temporary format to maintain backwards compatibility. We are planning on formalizing and strongly typing the construction of our agent events for consistency and ease of use. |
||
|
||
elif "text" in delta_content: | ||
state["text"] += delta_content["text"] | ||
callback_handler(data=delta_content["text"], delta=delta_content, **kwargs) | ||
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} | ||
|
||
elif "reasoningContent" in delta_content: | ||
if "text" in delta_content["reasoningContent"]: | ||
if "reasoningText" not in state: | ||
state["reasoningText"] = "" | ||
|
||
state["reasoningText"] += delta_content["reasoningContent"]["text"] | ||
callback_handler( | ||
reasoningText=delta_content["reasoningContent"]["text"], | ||
delta=delta_content, | ||
reasoning=True, | ||
**kwargs, | ||
) | ||
callback_event["callback"] = { | ||
"reasoningText": delta_content["reasoningContent"]["text"], | ||
"delta": delta_content, | ||
"reasoning": True, | ||
} | ||
|
||
elif "signature" in delta_content["reasoningContent"]: | ||
if "signature" not in state: | ||
state["signature"] = "" | ||
|
||
state["signature"] += delta_content["reasoningContent"]["signature"] | ||
callback_handler( | ||
reasoning_signature=delta_content["reasoningContent"]["signature"], | ||
delta=delta_content, | ||
reasoning=True, | ||
**kwargs, | ||
) | ||
callback_event["callback"] = { | ||
"reasoning_signature": delta_content["reasoningContent"]["signature"], | ||
"delta": delta_content, | ||
"reasoning": True, | ||
} | ||
|
||
return state | ||
return state, callback_event | ||
|
||
|
||
def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: | ||
def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: | ||
"""Handles the end of a content block by finalizing tool usage, text content, or reasoning content. | ||
|
||
Args: | ||
|
@@ -165,7 +163,7 @@ def handle_content_block_stop(state: Dict[str, Any]) -> Dict[str, Any]: | |
Returns: | ||
Updated state with finalized content block. | ||
""" | ||
content: List[ContentBlock] = state["content"] | ||
content: list[ContentBlock] = state["content"] | ||
|
||
current_tool_use = state["current_tool_use"] | ||
text = state["text"] | ||
|
@@ -223,7 +221,7 @@ def handle_message_stop(event: MessageStopEvent) -> StopReason: | |
return event["stopReason"] | ||
|
||
|
||
def handle_redact_content(event: RedactContentEvent, messages: Messages, state: Dict[str, Any]) -> None: | ||
def handle_redact_content(event: RedactContentEvent, messages: Messages, state: dict[str, Any]) -> None: | ||
"""Handles redacting content from the input or output. | ||
|
||
Args: | ||
|
@@ -238,7 +236,7 @@ def handle_redact_content(event: RedactContentEvent, messages: Messages, state: | |
state["message"]["content"] = [{"text": event["redactAssistantContentMessage"]}] | ||
|
||
|
||
def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: | ||
def extract_usage_metrics(event: MetadataEvent) -> tuple[Usage, Metrics]: | ||
"""Extracts usage metrics from the metadata chunk. | ||
|
||
Args: | ||
|
@@ -255,25 +253,20 @@ def extract_usage_metrics(event: MetadataEvent) -> Tuple[Usage, Metrics]: | |
|
||
def process_stream( | ||
chunks: Iterable[StreamEvent], | ||
callback_handler: Any, | ||
messages: Messages, | ||
**kwargs: Any, | ||
) -> Tuple[StopReason, Message, Usage, Metrics, Any]: | ||
) -> Generator[dict[str, Any], None, None]: | ||
"""Processes the response stream from the API, constructing the final message and extracting usage metrics. | ||
|
||
Args: | ||
chunks: The chunks of the response stream from the model. | ||
callback_handler: Callback for processing events as they happen. | ||
messages: The agents messages. | ||
**kwargs: Additional keyword arguments that will be passed to the callback handler. | ||
And also returned in the request_state. | ||
|
||
Returns: | ||
The reason for stopping, the constructed message, the usage metrics, and the updated request state. | ||
The reason for stopping, the constructed message, and the usage metrics. | ||
""" | ||
stop_reason: StopReason = "end_turn" | ||
|
||
state: Dict[str, Any] = { | ||
state: dict[str, Any] = { | ||
"message": {"role": "assistant", "content": []}, | ||
"text": "", | ||
"current_tool_use": {}, | ||
|
@@ -285,18 +278,16 @@ def process_stream( | |
usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) | ||
metrics: Metrics = Metrics(latencyMs=0) | ||
|
||
kwargs.setdefault("request_state", {}) | ||
|
||
for chunk in chunks: | ||
# Callback handler call here allows each event to be visible to the caller | ||
callback_handler(event=chunk) | ||
yield {"callback": {"event": chunk}} | ||
|
||
if "messageStart" in chunk: | ||
state["message"] = handle_message_start(chunk["messageStart"], state["message"]) | ||
elif "contentBlockStart" in chunk: | ||
state["current_tool_use"] = handle_content_block_start(chunk["contentBlockStart"]) | ||
elif "contentBlockDelta" in chunk: | ||
state = handle_content_block_delta(chunk["contentBlockDelta"], state, callback_handler, **kwargs) | ||
state, callback_event = handle_content_block_delta(chunk["contentBlockDelta"], state) | ||
yield callback_event | ||
elif "contentBlockStop" in chunk: | ||
state = handle_content_block_stop(state) | ||
elif "messageStop" in chunk: | ||
|
@@ -306,35 +297,30 @@ def process_stream( | |
elif "redactContent" in chunk: | ||
handle_redact_content(chunk["redactContent"], messages, state) | ||
|
||
return stop_reason, state["message"], usage, metrics, kwargs["request_state"] | ||
yield {"stop": (stop_reason, state["message"], usage, metrics)} | ||
|
||
|
||
def stream_messages( | ||
model: Model, | ||
system_prompt: Optional[str], | ||
messages: Messages, | ||
tool_config: Optional[ToolConfig], | ||
callback_handler: Any, | ||
**kwargs: Any, | ||
) -> Tuple[StopReason, Message, Usage, Metrics, Any]: | ||
) -> Generator[dict[str, Any], None, None]: | ||
"""Streams messages to the model and processes the response. | ||
|
||
Args: | ||
model: Model provider. | ||
system_prompt: The system prompt to send. | ||
messages: List of messages to send. | ||
tool_config: Configuration for the tools to use. | ||
callback_handler: Callback for processing events as they happen. | ||
**kwargs: Additional keyword arguments that will be passed to the callback handler. | ||
And also returned in the request_state. | ||
|
||
Returns: | ||
The reason for stopping, the final message, the usage metrics, and updated request state. | ||
The reason for stopping, the final message, and the usage metrics | ||
""" | ||
logger.debug("model=<%s> | streaming messages", model) | ||
|
||
messages = remove_blank_messages_content_text(messages) | ||
tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None | ||
|
||
chunks = model.converse(messages, tool_specs, system_prompt) | ||
return process_stream(chunks, callback_handler, messages, **kwargs) | ||
yield from process_stream(chunks, messages) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -733,3 +733,103 @@ def test_event_loop_cycle_with_parent_span( | |
mock_tracer.start_event_loop_cycle_span.assert_called_once_with( | ||
event_loop_kwargs=unittest.mock.ANY, parent_span=parent_span, messages=messages | ||
) | ||
|
||
|
||
def test_event_loop_cycle_callback( | ||
model, | ||
model_id, | ||
system_prompt, | ||
messages, | ||
tool_config, | ||
callback_handler, | ||
tool_handler, | ||
tool_execution_handler, | ||
): | ||
model.converse.return_value = [ | ||
{"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}, | ||
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}, | ||
{"contentBlockStop": {}}, | ||
{"contentBlockStart": {"start": {}}}, | ||
{"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}, | ||
{"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}, | ||
{"contentBlockStop": {}}, | ||
{"contentBlockStart": {"start": {}}}, | ||
{"contentBlockDelta": {"delta": {"text": "value"}}}, | ||
{"contentBlockStop": {}}, | ||
] | ||
|
||
strands.event_loop.event_loop.event_loop_cycle( | ||
model=model, | ||
model_id=model_id, | ||
system_prompt=system_prompt, | ||
messages=messages, | ||
tool_config=tool_config, | ||
callback_handler=callback_handler, | ||
tool_handler=tool_handler, | ||
tool_execution_handler=tool_execution_handler, | ||
) | ||
|
||
callback_handler.assert_has_calls( | ||
[ | ||
call(start=True), | ||
call(start_event_loop=True), | ||
call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), | ||
call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), | ||
call( | ||
delta={"toolUse": {"input": '{"value"}'}}, | ||
current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, | ||
model_id="m1", | ||
event_loop_cycle_id=unittest.mock.ANY, | ||
request_state={}, | ||
event_loop_cycle_trace=unittest.mock.ANY, | ||
event_loop_cycle_span=None, | ||
), | ||
call(event={"contentBlockStop": {}}), | ||
call(event={"contentBlockStart": {"start": {}}}), | ||
call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"text": "value"}}}}), | ||
call( | ||
reasoningText="value", | ||
delta={"reasoningContent": {"text": "value"}}, | ||
reasoning=True, | ||
model_id="m1", | ||
event_loop_cycle_id=unittest.mock.ANY, | ||
request_state={}, | ||
event_loop_cycle_trace=unittest.mock.ANY, | ||
event_loop_cycle_span=None, | ||
), | ||
call(event={"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "value"}}}}), | ||
call( | ||
reasoning_signature="value", | ||
delta={"reasoningContent": {"signature": "value"}}, | ||
reasoning=True, | ||
model_id="m1", | ||
event_loop_cycle_id=unittest.mock.ANY, | ||
request_state={}, | ||
event_loop_cycle_trace=unittest.mock.ANY, | ||
event_loop_cycle_span=None, | ||
), | ||
call(event={"contentBlockStop": {}}), | ||
call(event={"contentBlockStart": {"start": {}}}), | ||
call(event={"contentBlockDelta": {"delta": {"text": "value"}}}), | ||
call( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a comment about duplicate events yielded while streaming. This is an example (813-814) and will be addressed as part of #242. |
||
data="value", | ||
delta={"text": "value"}, | ||
model_id="m1", | ||
event_loop_cycle_id=unittest.mock.ANY, | ||
request_state={}, | ||
event_loop_cycle_trace=unittest.mock.ANY, | ||
event_loop_cycle_span=None, | ||
), | ||
call(event={"contentBlockStop": {}}), | ||
call( | ||
message={ | ||
"role": "assistant", | ||
"content": [ | ||
{"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}, | ||
{"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, | ||
{"text": "value"}, | ||
], | ||
}, | ||
), | ||
], | ||
) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zastrowm had the following comment:
This was a good call out. We didn't actually have the best coverage on callbacks. I added a test that actually caught an issue with passing kwargs which we need for backwards compatibility. I personally don't think passing kwargs is useful and so we should revisit when we formally type events (tracking issue).