generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 466
fix(conversation): preserve tool result JSON structure in sliding window management #94
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,10 @@ | ||
| """Sliding window conversation history management.""" | ||
|
|
||
| import json | ||
| import logging | ||
| from typing import List, Optional, cast | ||
| from typing import List, Optional | ||
|
|
||
| from ...types.content import ContentBlock, Message, Messages | ||
| from ...types.content import Message, Messages | ||
| from ...types.exceptions import ContextWindowOverflowException | ||
| from ...types.tools import ToolResult | ||
| from .conversation_manager import ConversationManager | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -36,6 +34,34 @@ def is_assistant_message(message: Message) -> bool: | |
| return message["role"] == "assistant" | ||
|
|
||
|
|
||
| def has_tool_use(message: Message) -> bool: | ||
| """Check if a message contains toolUse content.""" | ||
| return any("toolUse" in content for content in message["content"]) | ||
|
|
||
|
|
||
| def has_tool_result(message: Message) -> bool: | ||
| """Check if a message contains toolResult content.""" | ||
| return any("toolResult" in content for content in message["content"]) | ||
|
|
||
|
|
||
| def get_tool_use_ids(message: Message) -> List[str]: | ||
| """Get all toolUse IDs from a message.""" | ||
| ids = [] | ||
| for content in message["content"]: | ||
| if "toolUse" in content: | ||
| ids.append(content["toolUse"]["toolUseId"]) | ||
| return ids | ||
|
|
||
|
|
||
| def get_tool_result_ids(message: Message) -> List[str]: | ||
| """Get all toolResult IDs from a message.""" | ||
| ids = [] | ||
| for content in message["content"]: | ||
| if "toolResult" in content: | ||
| ids.append(content["toolResult"]["toolUseId"]) | ||
| return ids | ||
|
|
||
|
|
||
| class SlidingWindowConversationManager(ConversationManager): | ||
| """Implements a sliding window strategy for managing conversation history. | ||
|
|
||
|
|
@@ -95,23 +121,23 @@ def _remove_dangling_messages(self, messages: Messages) -> None: | |
| """ | ||
| # remove any dangling user messages with no ToolResult | ||
| if len(messages) > 0 and is_user_message(messages[-1]): | ||
| if not any("toolResult" in content for content in messages[-1]["content"]): | ||
| if not has_tool_result(messages[-1]): | ||
| messages.pop() | ||
|
|
||
| # remove any dangling assistant messages with ToolUse | ||
| if len(messages) > 0 and is_assistant_message(messages[-1]): | ||
| if any("toolUse" in content for content in messages[-1]["content"]): | ||
| if has_tool_use(messages[-1]): | ||
| messages.pop() | ||
| # remove remaining dangling user messages with no ToolResult after we popped off an assistant message | ||
| if len(messages) > 0 and is_user_message(messages[-1]): | ||
| if not any("toolResult" in content for content in messages[-1]["content"]): | ||
| if not has_tool_result(messages[-1]): | ||
| messages.pop() | ||
|
|
||
| def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: | ||
| """Trim the oldest messages to reduce the conversation context size. | ||
|
|
||
| The method handles special cases where tool results need to be converted to regular content blocks to maintain | ||
| conversation coherence after trimming. | ||
| The method ensures that tool use/result pairs are preserved together. If a cut would separate | ||
| a toolUse from its corresponding toolResult, it adjusts the cut point to include both. | ||
|
|
||
| Args: | ||
| messages: The messages to reduce. | ||
|
|
@@ -120,58 +146,66 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N | |
|
|
||
| Raises: | ||
| ContextWindowOverflowException: If the context cannot be reduced further. | ||
| Such as when the conversation is already minimal or when tool result messages cannot be properly | ||
| converted. | ||
| """ | ||
| # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size | ||
| # Calculate basic trim index | ||
| trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size | ||
|
|
||
| # Throw if we cannot trim any messages from the conversation | ||
| if trim_index >= len(messages): | ||
| raise ContextWindowOverflowException("Unable to trim conversation context!") from e | ||
|
|
||
| # If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the | ||
| # limitation of needing ToolUse and ToolResults to be paired. | ||
| if any("toolResult" in content for content in messages[trim_index]["content"]): | ||
| if len(messages[trim_index]["content"]) == 1: | ||
| messages[trim_index]["content"] = self._map_tool_result_content( | ||
| cast(ToolResult, messages[trim_index]["content"][0]["toolResult"]) | ||
| ) | ||
| # Find a safe cutting point that preserves tool use/result pairs | ||
| safe_trim_index = self._find_safe_trim_index(messages, trim_index) | ||
|
|
||
| # If there is more content than just one ToolResultContent, then we cannot cut at this index. | ||
| else: | ||
| raise ContextWindowOverflowException("Unable to trim conversation context!") from e | ||
| # If we couldn't find a safe trim point within bounds, fall back to basic trim | ||
| if safe_trim_index >= len(messages): | ||
| logger.warning( | ||
| "safe_trim_index=<%d>, messages_length=<%d> | could not find safe trim point | " | ||
| "falling back to basic trim index", | ||
| safe_trim_index, | ||
| len(messages), | ||
| ) | ||
| safe_trim_index = trim_index | ||
|
|
||
| # Overwrite message history | ||
| messages[:] = messages[trim_index:] | ||
| messages[:] = messages[safe_trim_index:] | ||
|
|
||
| def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]: | ||
| """Convert a ToolResult to a list of standard ContentBlocks. | ||
| def _find_safe_trim_index(self, messages: Messages, initial_trim_index: int) -> int: | ||
| """Find a safe cutting point that preserves tool use/result pairs. | ||
|
|
||
| This method transforms tool result content into standard content blocks that can be preserved when trimming the | ||
| conversation history. | ||
| This method ensures that tool use/result pairs are not separated by the trim. | ||
| It adjusts the trim index to keep related tool interactions together. | ||
|
|
||
| Args: | ||
| tool_result: The ToolResult to convert. | ||
| messages: The complete message history | ||
| initial_trim_index: The initial trim index based on window size | ||
|
|
||
| Returns: | ||
| A list of content blocks representing the tool result. | ||
| A safe trim index that preserves tool use/result pairs | ||
| """ | ||
| contents = [] | ||
| text_content = "Tool Result Status: " + tool_result["status"] if tool_result["status"] else "" | ||
|
|
||
| for tool_result_content in tool_result["content"]: | ||
| if "text" in tool_result_content: | ||
| text_content = "\nTool Result Text Content: " + tool_result_content["text"] + f"\n{text_content}" | ||
| elif "json" in tool_result_content: | ||
| text_content = ( | ||
| "\nTool Result JSON Content: " + json.dumps(tool_result_content["json"]) + f"\n{text_content}" | ||
| ) | ||
| elif "image" in tool_result_content: | ||
| contents.append(ContentBlock(image=tool_result_content["image"])) | ||
| elif "document" in tool_result_content: | ||
| contents.append(ContentBlock(document=tool_result_content["document"])) | ||
| else: | ||
| logger.warning("unsupported content type") | ||
| contents.append(ContentBlock(text=text_content)) | ||
| return contents | ||
| # Build a map of tool IDs to their message indices | ||
| tool_use_indices = {} # toolUseId -> message index | ||
| tool_result_indices = {} # toolUseId -> message index | ||
|
|
||
| for i, message in enumerate(messages): | ||
| for tool_id in get_tool_use_ids(message): | ||
| tool_use_indices[tool_id] = i | ||
| for tool_id in get_tool_result_ids(message): | ||
| tool_result_indices[tool_id] = i | ||
|
|
||
| # Start from the initial trim index | ||
| safe_index = initial_trim_index | ||
|
|
||
| # Adjust if we would cut in the middle of a tool use/result pair | ||
| for tool_id, use_idx in tool_use_indices.items(): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you would like to avoid so much nesting for tool_id, use_idx in tool_use_indices.items():
if tool_id not in tool_result_indices:
continue
result_idx = tool_result_indices[tool_id]
# If the pair would be split by the cut, move it earlier to keep them together
if use_idx < safe_index <= result_idx:
safe_index = min(safe_index, use_idx)
continue
# Invalid ordering: toolResult appears before toolUse
if result_idx < safe_index < use_idx:
logger.warning("tool_id=<%s> | found toolResult before toolUse", tool_id) |
||
| if tool_id in tool_result_indices: | ||
| result_idx = tool_result_indices[tool_id] | ||
| # If the pair would be split by the cut | ||
| if use_idx < safe_index <= result_idx: | ||
| # Move the cut to before the tool use to keep the pair together | ||
| safe_index = min(safe_index, use_idx) | ||
| elif result_idx < safe_index < use_idx: | ||
| # This shouldn't happen in valid conversations | ||
| logger.warning("tool_id=<%s> | found toolResult before toolUse", tool_id) | ||
|
|
||
| return safe_index | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_tool_use_ids() and get_tool_result_ids() here, consider inlining the logic to avoid scanning message["content"] twice per message and allocating temporary lists.
A single-pass loop over message["content"] can extract both toolUseId and toolResultId efficiently:
This avoids unnecessary allocations and keeps the loop flat