Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ test-lint = [
"hatch fmt --linter --check"
]
test = [
"hatch test --cover --cov-report html --cov-report xml {args}"
"hatch test --cover --cov-report term-missing --cov-report html --cov-report xml {args}"
]
test-integ = [
"hatch test tests-integ {args}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Sliding window conversation history management."""

import json
import logging
from typing import List, Optional, cast
from typing import List, Optional

from ...types.content import ContentBlock, Message, Messages
from ...types.content import Message, Messages
from ...types.exceptions import ContextWindowOverflowException
from ...types.tools import ToolResult
from .conversation_manager import ConversationManager

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -36,6 +34,34 @@ def is_assistant_message(message: Message) -> bool:
return message["role"] == "assistant"


def has_tool_use(message: Message) -> bool:
"""Check if a message contains toolUse content."""
return any("toolUse" in content for content in message["content"])


def has_tool_result(message: Message) -> bool:
"""Check if a message contains toolResult content."""
return any("toolResult" in content for content in message["content"])


def get_tool_use_ids(message: Message) -> List[str]:
"""Get all toolUse IDs from a message."""
ids = []
for content in message["content"]:
if "toolUse" in content:
ids.append(content["toolUse"]["toolUseId"])
return ids


def get_tool_result_ids(message: Message) -> List[str]:
"""Get all toolResult IDs from a message."""
ids = []
for content in message["content"]:
if "toolResult" in content:
ids.append(content["toolResult"]["toolUseId"])
return ids


class SlidingWindowConversationManager(ConversationManager):
"""Implements a sliding window strategy for managing conversation history.

Expand Down Expand Up @@ -95,23 +121,23 @@ def _remove_dangling_messages(self, messages: Messages) -> None:
"""
# remove any dangling user messages with no ToolResult
if len(messages) > 0 and is_user_message(messages[-1]):
if not any("toolResult" in content for content in messages[-1]["content"]):
if not has_tool_result(messages[-1]):
messages.pop()

# remove any dangling assistant messages with ToolUse
if len(messages) > 0 and is_assistant_message(messages[-1]):
if any("toolUse" in content for content in messages[-1]["content"]):
if has_tool_use(messages[-1]):
messages.pop()
# remove remaining dangling user messages with no ToolResult after we popped off an assistant message
if len(messages) > 0 and is_user_message(messages[-1]):
if not any("toolResult" in content for content in messages[-1]["content"]):
if not has_tool_result(messages[-1]):
messages.pop()

def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None:
"""Trim the oldest messages to reduce the conversation context size.

The method handles special cases where tool results need to be converted to regular content blocks to maintain
conversation coherence after trimming.
The method ensures that tool use/result pairs are preserved together. If a cut would separate
a toolUse from its corresponding toolResult, it adjusts the cut point to include both.

Args:
messages: The messages to reduce.
Expand All @@ -120,58 +146,66 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N

Raises:
ContextWindowOverflowException: If the context cannot be reduced further.
Such as when the conversation is already minimal or when tool result messages cannot be properly
converted.
"""
# If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
# Calculate basic trim index
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size

# Throw if we cannot trim any messages from the conversation
if trim_index >= len(messages):
raise ContextWindowOverflowException("Unable to trim conversation context!") from e

# If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the
# limitation of needing ToolUse and ToolResults to be paired.
if any("toolResult" in content for content in messages[trim_index]["content"]):
if len(messages[trim_index]["content"]) == 1:
messages[trim_index]["content"] = self._map_tool_result_content(
cast(ToolResult, messages[trim_index]["content"][0]["toolResult"])
)
# Find a safe cutting point that preserves tool use/result pairs
safe_trim_index = self._find_safe_trim_index(messages, trim_index)

# If there is more content than just one ToolResultContent, then we cannot cut at this index.
else:
raise ContextWindowOverflowException("Unable to trim conversation context!") from e
# If we couldn't find a safe trim point within bounds, fall back to basic trim
if safe_trim_index >= len(messages):
logger.warning(
"safe_trim_index=<%d>, messages_length=<%d> | could not find safe trim point | "
"falling back to basic trim index",
safe_trim_index,
len(messages),
)
safe_trim_index = trim_index

# Overwrite message history
messages[:] = messages[trim_index:]
messages[:] = messages[safe_trim_index:]

def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]:
"""Convert a ToolResult to a list of standard ContentBlocks.
def _find_safe_trim_index(self, messages: Messages, initial_trim_index: int) -> int:
"""Find a safe cutting point that preserves tool use/result pairs.

This method transforms tool result content into standard content blocks that can be preserved when trimming the
conversation history.
This method ensures that tool use/result pairs are not separated by the trim.
It adjusts the trim index to keep related tool interactions together.

Args:
tool_result: The ToolResult to convert.
messages: The complete message history
initial_trim_index: The initial trim index based on window size

Returns:
A list of content blocks representing the tool result.
A safe trim index that preserves tool use/result pairs
"""
contents = []
text_content = "Tool Result Status: " + tool_result["status"] if tool_result["status"] else ""

for tool_result_content in tool_result["content"]:
if "text" in tool_result_content:
text_content = "\nTool Result Text Content: " + tool_result_content["text"] + f"\n{text_content}"
elif "json" in tool_result_content:
text_content = (
"\nTool Result JSON Content: " + json.dumps(tool_result_content["json"]) + f"\n{text_content}"
)
elif "image" in tool_result_content:
contents.append(ContentBlock(image=tool_result_content["image"]))
elif "document" in tool_result_content:
contents.append(ContentBlock(document=tool_result_content["document"]))
else:
logger.warning("unsupported content type")
contents.append(ContentBlock(text=text_content))
return contents
# Build a map of tool IDs to their message indices
tool_use_indices = {} # toolUseId -> message index
tool_result_indices = {} # toolUseId -> message index

for i, message in enumerate(messages):
Copy link

@fede-kamel fede-kamel May 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_tool_use_ids() and get_tool_result_ids() here, consider inlining the logic to avoid scanning message["content"] twice per message and allocating temporary lists.

A single-pass loop over message["content"] can extract both toolUseId and toolResultId efficiently:

for content in message.get("content", []):
    if "toolUse" in content:
        tool_id = content["toolUse"].get("toolUseId")
        if tool_id:
            tool_use_indices[tool_id] = i
    if "toolResult" in content:
        tool_id = content["toolResult"].get("toolUseId")
        if tool_id:
            tool_result_indices[tool_id] = i

This avoids unnecessary allocations and keeps the loop flat

for tool_id in get_tool_use_ids(message):
tool_use_indices[tool_id] = i
for tool_id in get_tool_result_ids(message):
tool_result_indices[tool_id] = i

# Start from the initial trim index
safe_index = initial_trim_index

# Adjust if we would cut in the middle of a tool use/result pair
for tool_id, use_idx in tool_use_indices.items():

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you would like to avoid so much nesting

for tool_id, use_idx in tool_use_indices.items():
    if tool_id not in tool_result_indices:
        continue

    result_idx = tool_result_indices[tool_id]

    # If the pair would be split by the cut, move it earlier to keep them together
    if use_idx < safe_index <= result_idx:
        safe_index = min(safe_index, use_idx)
        continue

    # Invalid ordering: toolResult appears before toolUse
    if result_idx < safe_index < use_idx:
        logger.warning("tool_id=<%s> | found toolResult before toolUse", tool_id)

if tool_id in tool_result_indices:
result_idx = tool_result_indices[tool_id]
# If the pair would be split by the cut
if use_idx < safe_index <= result_idx:
# Move the cut to before the tool use to keep the pair together
safe_index = min(safe_index, use_idx)
elif result_idx < safe_index < use_idx:
# This shouldn't happen in valid conversations
logger.warning("tool_id=<%s> | found toolResult before toolUse", tool_id)

return safe_index
Loading