Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 6 additions & 198 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
"""

import json
import logging
import random
import warnings
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -52,16 +50,16 @@
from ..session.session_manager import SessionManager
from ..telemetry.metrics import EventLoopMetrics
from ..telemetry.tracer import get_tracer, serialize
from ..tools._caller import _ToolCaller
from ..tools.executors import ConcurrentToolExecutor
from ..tools.executors._executor import ToolExecutor
from ..tools.registry import ToolRegistry
from ..tools.structured_output._structured_output_context import StructuredOutputContext
from ..tools.watcher import ToolWatcher
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, ToolInterruptEvent, TypedEvent
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
from ..types.agent import AgentInput
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
from ..types.exceptions import ContextWindowOverflowException
from ..types.tools import ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .conversation_manager import (
Expand Down Expand Up @@ -101,114 +99,8 @@ class Agent:
6. Produces a final response
"""

class ToolCaller:
"""Call tool as a function."""

def __init__(self, agent: "Agent") -> None:
"""Initialize instance.

Args:
agent: Agent reference that will accept tool results.
"""
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
# agent tools and thus break their execution.
self._agent = agent

def __getattr__(self, name: str) -> Callable[..., Any]:
"""Call tool as a function.

This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').

Args:
name: The name of the attribute (tool) being accessed.

Returns:
A function that when called will execute the named tool.

Raises:
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
"""

def caller(
user_message_override: Optional[str] = None,
record_direct_tool_call: Optional[bool] = None,
**kwargs: Any,
) -> Any:
"""Call a tool directly by name.

Args:
user_message_override: Optional custom message to record instead of default
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
attribute if provided.
**kwargs: Keyword arguments to pass to the tool.

Returns:
The result returned by the tool.

Raises:
AttributeError: If the tool doesn't exist.
"""
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")

normalized_name = self._find_normalized_tool_name(name)

# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}
tool_results: list[ToolResult] = []
invocation_state = kwargs

async def acall() -> ToolResult:
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")

tool_result = tool_results[0]

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
await self._agent._record_tool_execution(tool_use, tool_result, user_message_override)

return tool_result

tool_result = run_async(acall)
self._agent.conversation_manager.apply_management(self._agent)
return tool_result

return caller

def _find_normalized_tool_name(self, name: str) -> str:
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
tool_registry = self._agent.tool_registry.registry

if tool_registry.get(name, None):
return name

# If the desired name contains underscores, it might be a placeholder for characters that can't be
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
# all tools that can be represented with the normalized name
if "_" in name:
filtered_tools = [
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
]

# The registry itself defends against similar names, so we can just take the first match
if filtered_tools:
return filtered_tools[0]

raise AttributeError(f"Tool '{name}' not found")
# For backwards compatibility
ToolCaller = _ToolCaller

def __init__(
self,
Expand Down Expand Up @@ -347,7 +239,7 @@ def __init__(
else:
self.state = AgentState()

self.tool_caller = Agent.ToolCaller(self)
self.tool_caller = _ToolCaller(self)

self.hooks = HookRegistry()

Expand Down Expand Up @@ -395,7 +287,7 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None:
self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value)

@property
def tool(self) -> ToolCaller:
def tool(self) -> _ToolCaller:
"""Call tool as a function.

Returns:
Expand Down Expand Up @@ -854,71 +746,6 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.")
return messages

async def _record_tool_execution(
self,
tool: ToolUse,
tool_result: ToolResult,
user_message_override: Optional[str],
) -> None:
"""Record a tool execution in the message history.

Creates a sequence of messages that represent the tool execution:

1. A user message describing the tool call
2. An assistant message with the tool use
3. A user message with the tool result
4. An assistant message acknowledging the tool call

Args:
tool: The tool call information.
tool_result: The result returned by the tool.
user_message_override: Optional custom message to include.
"""
# Filter tool input parameters to only include those defined in tool spec
filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"])

# Create user message describing the tool call
input_parameters = json.dumps(filtered_input, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")

user_msg_content: list[ContentBlock] = [
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")}
]

# Add override message if provided
if user_message_override:
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})

# Create filtered tool use for message history
filtered_tool: ToolUse = {
"toolUseId": tool["toolUseId"],
"name": tool["name"],
"input": filtered_input,
}

# Create the message sequence
user_msg: Message = {
"role": "user",
"content": user_msg_content,
}
tool_use_msg: Message = {
"role": "assistant",
"content": [{"toolUse": filtered_tool}],
}
tool_result_msg: Message = {
"role": "user",
"content": [{"toolResult": tool_result}],
}
assistant_msg: Message = {
"role": "assistant",
"content": [{"text": f"agent.tool.{tool['name']} was called."}],
}

# Add to message history
await self._append_message(user_msg)
await self._append_message(tool_use_msg)
await self._append_message(tool_result_msg)
await self._append_message(assistant_msg)

def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
"""Starts a trace span for the agent.

Expand Down Expand Up @@ -960,25 +787,6 @@ def _end_agent_trace_span(

self.tracer.end_agent_span(**trace_attributes)

def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]:
"""Filter input parameters to only include those defined in the tool specification.

Args:
tool_name: Name of the tool to get specification for
input_params: Original input parameters

Returns:
Filtered parameters containing only those defined in tool spec
"""
all_tools_config = self.tool_registry.get_all_tools_config()
tool_spec = all_tools_config.get(tool_name)

if not tool_spec or "inputSchema" not in tool_spec:
return input_params.copy()

properties = tool_spec["inputSchema"]["json"]["properties"]
return {k: v for k, v in input_params.items() if k in properties}

def _initialize_system_prompt(
self, system_prompt: str | list[SystemContentBlock] | None
) -> tuple[str | None, list[SystemContentBlock] | None]:
Expand Down
Loading
Loading