Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ __pycache__*
.pytest_cache
.ruff_cache
*.bak
.vscode
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "strands-agents"
version = "0.1.4"
version = "0.1.5"
description = "A model-driven approach to building AI agents in just a few lines of code"
readme = "README.md"
requires-python = ">=3.10"
Expand Down
105 changes: 70 additions & 35 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,27 +328,17 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None

self.trace_span = self.tracer.start_agent_span(
prompt=prompt,
model_id=model_id,
tools=self.tool_names,
system_prompt=self.system_prompt,
custom_trace_attributes=self.trace_attributes,
)
self._start_agent_trace_span(prompt)

try:
# Run the event loop and get the result
result = self._run_loop(prompt, kwargs)

if self.trace_span:
self.tracer.end_agent_span(span=self.trace_span, response=result)
self._end_agent_trace_span(response=result)

return result
except Exception as e:
if self.trace_span:
self.tracer.end_agent_span(span=self.trace_span, error=e)
self._end_agent_trace_span(error=e)

# Re-raise the exception to preserve original behavior
raise
Expand Down Expand Up @@ -383,6 +373,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
yield event["data"]
```
"""
self._start_agent_trace_span(prompt)

_stop_event = uuid4()

queue = asyncio.Queue[Any]()
Expand All @@ -400,8 +392,10 @@ def target_callback() -> None:
nonlocal kwargs

try:
self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
except BaseException as e:
result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
self._end_agent_trace_span(response=result)
except Exception as e:
self._end_agent_trace_span(error=e)
enqueue(e)
finally:
enqueue(_stop_event)
Expand All @@ -414,7 +408,7 @@ def target_callback() -> None:
item = await queue.get()
if item == _stop_event:
break
if isinstance(item, BaseException):
if isinstance(item, Exception):
raise item
yield item
finally:
Expand Down Expand Up @@ -457,27 +451,28 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str
Returns:
The result of the event loop cycle.
"""
kwargs.pop("agent", None)
kwargs.pop("model", None)
kwargs.pop("system_prompt", None)
kwargs.pop("tool_execution_handler", None)
kwargs.pop("event_loop_metrics", None)
kwargs.pop("callback_handler", None)
kwargs.pop("tool_handler", None)
kwargs.pop("messages", None)
kwargs.pop("tool_config", None)
# Extract parameters with fallbacks to instance values
system_prompt = kwargs.pop("system_prompt", self.system_prompt)
model = kwargs.pop("model", self.model)
tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper)
event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics)
callback_handler_override = kwargs.pop("callback_handler", callback_handler)
tool_handler = kwargs.pop("tool_handler", self.tool_handler)
messages = kwargs.pop("messages", self.messages)
tool_config = kwargs.pop("tool_config", self.tool_config)
kwargs.pop("agent", None) # Remove agent to avoid conflicts

try:
# Execute the main event loop cycle
stop_reason, message, metrics, state = event_loop_cycle(
model=self.model,
system_prompt=self.system_prompt,
messages=self.messages, # will be modified by event_loop_cycle
tool_config=self.tool_config,
callback_handler=callback_handler,
tool_handler=self.tool_handler,
tool_execution_handler=self.thread_pool_wrapper,
event_loop_metrics=self.event_loop_metrics,
model=model,
system_prompt=system_prompt,
messages=messages, # will be modified by event_loop_cycle
tool_config=tool_config,
callback_handler=callback_handler_override,
tool_handler=tool_handler,
tool_execution_handler=tool_execution_handler,
event_loop_metrics=event_loop_metrics,
agent=self,
event_loop_parent_span=self.trace_span,
**kwargs,
Expand All @@ -488,8 +483,8 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str
except ContextWindowOverflowException as e:
# Try reducing the context size and retrying

self.conversation_manager.reduce_context(self.messages, e=e)
return self._execute_event_loop_cycle(callback_handler, kwargs)
self.conversation_manager.reduce_context(messages, e=e)
return self._execute_event_loop_cycle(callback_handler_override, kwargs)

def _record_tool_execution(
self,
Expand Down Expand Up @@ -545,3 +540,43 @@ def _record_tool_execution(
messages.append(tool_use_msg)
messages.append(tool_result_msg)
messages.append(assistant_msg)

def _start_agent_trace_span(self, prompt: str) -> None:
"""Starts a trace span for the agent.

Args:
prompt: The natural language prompt from the user.
"""
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None

self.trace_span = self.tracer.start_agent_span(
prompt=prompt,
model_id=model_id,
tools=self.tool_names,
system_prompt=self.system_prompt,
custom_trace_attributes=self.trace_attributes,
)

def _end_agent_trace_span(
self,
response: Optional[AgentResult] = None,
error: Optional[Exception] = None,
) -> None:
"""Ends a trace span for the agent.

Args:
span: The span to end.
response: Response to record as a trace attribute.
error: Error to record as a trace attribute.
"""
if self.trace_span:
trace_attributes: Dict[str, Any] = {
"span": self.trace_span,
}

if response:
trace_attributes["response"] = response
if error:
trace_attributes["error"] = error

self.tracer.end_agent_span(**trace_attributes)
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Sliding window conversation history management."""

import json
import logging
from typing import List, Optional, cast
from typing import Optional

from ...types.content import ContentBlock, Message, Messages
from ...types.content import Message, Messages
from ...types.exceptions import ContextWindowOverflowException
from ...types.tools import ToolResult
from .conversation_manager import ConversationManager

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -110,8 +108,9 @@ def _remove_dangling_messages(self, messages: Messages) -> None:
def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None:
"""Trim the oldest messages to reduce the conversation context size.

The method handles special cases where tool results need to be converted to regular content blocks to maintain
conversation coherence after trimming.
The method handles special cases where trimming the messages leads to:
- toolResult with no corresponding toolUse
- toolUse with no corresponding toolResult

Args:
messages: The messages to reduce.
Expand All @@ -126,52 +125,24 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N
# If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size

# Throw if we cannot trim any messages from the conversation
if trim_index >= len(messages):
raise ContextWindowOverflowException("Unable to trim conversation context!") from e

# If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the
# limitation of needing ToolUse and ToolResults to be paired.
if any("toolResult" in content for content in messages[trim_index]["content"]):
if len(messages[trim_index]["content"]) == 1:
messages[trim_index]["content"] = self._map_tool_result_content(
cast(ToolResult, messages[trim_index]["content"][0]["toolResult"])
# Find the next valid trim_index
while trim_index < len(messages):
if (
# Oldest message cannot be a toolResult because it needs a toolUse preceding it
any("toolResult" in content for content in messages[trim_index]["content"])
or (
# Oldest message can be a toolUse only if a toolResult immediately follows it.
any("toolUse" in content for content in messages[trim_index]["content"])
and trim_index + 1 < len(messages)
and not any("toolResult" in content for content in messages[trim_index + 1]["content"])
)

# If there is more content than just one ToolResultContent, then we cannot cut at this index.
):
trim_index += 1
else:
raise ContextWindowOverflowException("Unable to trim conversation context!") from e
break
else:
# If we didn't find a valid trim_index, then we throw
raise ContextWindowOverflowException("Unable to trim conversation context!") from e

# Overwrite message history
messages[:] = messages[trim_index:]

def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]:
    """Convert a ToolResult to a list of standard ContentBlocks.

    This method transforms tool result content into standard content blocks
    that can be preserved when trimming the conversation history. Text and
    JSON entries are folded into a single text block (newest first), while
    image and document entries become their own content blocks.

    Args:
        tool_result: The ToolResult to convert.

    Returns:
        A list of content blocks representing the tool result.
    """
    blocks = []
    status = tool_result["status"]
    # Same parse as the original: the conditional covers the whole
    # concatenation, so an empty/falsy status yields "".
    summary = "Tool Result Status: " + status if status else ""

    for entry in tool_result["content"]:
        if "text" in entry:
            summary = "\nTool Result Text Content: " + entry["text"] + f"\n{summary}"
        elif "json" in entry:
            summary = "\nTool Result JSON Content: " + json.dumps(entry["json"]) + f"\n{summary}"
        elif "image" in entry:
            blocks.append(ContentBlock(image=entry["image"]))
        elif "document" in entry:
            blocks.append(ContentBlock(document=entry["document"]))
        else:
            logger.warning("unsupported content type")

    blocks.append(ContentBlock(text=summary))
    return blocks
Loading