Skip to content

Commit a4837d4

Browse files
authored
move tool caller definition out of agent module (#1215)
1 parent f554cca commit a4837d4

File tree

4 files changed

+535
-488
lines changed

4 files changed

+535
-488
lines changed

src/strands/agent/agent.py

Lines changed: 6 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010
"""
1111

12-
import json
1312
import logging
14-
import random
1513
import warnings
1614
from typing import (
1715
TYPE_CHECKING,
@@ -52,16 +50,16 @@
5250
from ..session.session_manager import SessionManager
5351
from ..telemetry.metrics import EventLoopMetrics
5452
from ..telemetry.tracer import get_tracer, serialize
53+
from ..tools._caller import _ToolCaller
5554
from ..tools.executors import ConcurrentToolExecutor
5655
from ..tools.executors._executor import ToolExecutor
5756
from ..tools.registry import ToolRegistry
5857
from ..tools.structured_output._structured_output_context import StructuredOutputContext
5958
from ..tools.watcher import ToolWatcher
60-
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, ToolInterruptEvent, TypedEvent
59+
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
6160
from ..types.agent import AgentInput
6261
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
6362
from ..types.exceptions import ContextWindowOverflowException
64-
from ..types.tools import ToolResult, ToolUse
6563
from ..types.traces import AttributeValue
6664
from .agent_result import AgentResult
6765
from .conversation_manager import (
@@ -101,114 +99,8 @@ class Agent:
10199
6. Produces a final response
102100
"""
103101

104-
class ToolCaller:
105-
"""Call tool as a function."""
106-
107-
def __init__(self, agent: "Agent") -> None:
108-
"""Initialize instance.
109-
110-
Args:
111-
agent: Agent reference that will accept tool results.
112-
"""
113-
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
114-
# agent tools and thus break their execution.
115-
self._agent = agent
116-
117-
def __getattr__(self, name: str) -> Callable[..., Any]:
118-
"""Call tool as a function.
119-
120-
This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
121-
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').
122-
123-
Args:
124-
name: The name of the attribute (tool) being accessed.
125-
126-
Returns:
127-
A function that when called will execute the named tool.
128-
129-
Raises:
130-
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
131-
"""
132-
133-
def caller(
134-
user_message_override: Optional[str] = None,
135-
record_direct_tool_call: Optional[bool] = None,
136-
**kwargs: Any,
137-
) -> Any:
138-
"""Call a tool directly by name.
139-
140-
Args:
141-
user_message_override: Optional custom message to record instead of default
142-
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
143-
attribute if provided.
144-
**kwargs: Keyword arguments to pass to the tool.
145-
146-
Returns:
147-
The result returned by the tool.
148-
149-
Raises:
150-
AttributeError: If the tool doesn't exist.
151-
"""
152-
if self._agent._interrupt_state.activated:
153-
raise RuntimeError("cannot directly call tool during interrupt")
154-
155-
normalized_name = self._find_normalized_tool_name(name)
156-
157-
# Create unique tool ID and set up the tool request
158-
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
159-
tool_use: ToolUse = {
160-
"toolUseId": tool_id,
161-
"name": normalized_name,
162-
"input": kwargs.copy(),
163-
}
164-
tool_results: list[ToolResult] = []
165-
invocation_state = kwargs
166-
167-
async def acall() -> ToolResult:
168-
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
169-
if isinstance(event, ToolInterruptEvent):
170-
self._agent._interrupt_state.deactivate()
171-
raise RuntimeError("cannot raise interrupt in direct tool call")
172-
173-
tool_result = tool_results[0]
174-
175-
if record_direct_tool_call is not None:
176-
should_record_direct_tool_call = record_direct_tool_call
177-
else:
178-
should_record_direct_tool_call = self._agent.record_direct_tool_call
179-
180-
if should_record_direct_tool_call:
181-
# Create a record of this tool execution in the message history
182-
await self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
183-
184-
return tool_result
185-
186-
tool_result = run_async(acall)
187-
self._agent.conversation_manager.apply_management(self._agent)
188-
return tool_result
189-
190-
return caller
191-
192-
def _find_normalized_tool_name(self, name: str) -> str:
193-
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
194-
tool_registry = self._agent.tool_registry.registry
195-
196-
if tool_registry.get(name, None):
197-
return name
198-
199-
# If the desired name contains underscores, it might be a placeholder for characters that can't be
200-
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
201-
# all tools that can be represented with the normalized name
202-
if "_" in name:
203-
filtered_tools = [
204-
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
205-
]
206-
207-
# The registry itself defends against similar names, so we can just take the first match
208-
if filtered_tools:
209-
return filtered_tools[0]
210-
211-
raise AttributeError(f"Tool '{name}' not found")
102+
# For backwards compatibility
103+
ToolCaller = _ToolCaller
212104

213105
def __init__(
214106
self,
@@ -347,7 +239,7 @@ def __init__(
347239
else:
348240
self.state = AgentState()
349241

350-
self.tool_caller = Agent.ToolCaller(self)
242+
self.tool_caller = _ToolCaller(self)
351243

352244
self.hooks = HookRegistry()
353245

@@ -395,7 +287,7 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None:
395287
self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value)
396288

397289
@property
398-
def tool(self) -> ToolCaller:
290+
def tool(self) -> _ToolCaller:
399291
"""Call tool as a function.
400292
401293
Returns:
@@ -854,71 +746,6 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
854746
raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.")
855747
return messages
856748

857-
async def _record_tool_execution(
858-
self,
859-
tool: ToolUse,
860-
tool_result: ToolResult,
861-
user_message_override: Optional[str],
862-
) -> None:
863-
"""Record a tool execution in the message history.
864-
865-
Creates a sequence of messages that represent the tool execution:
866-
867-
1. A user message describing the tool call
868-
2. An assistant message with the tool use
869-
3. A user message with the tool result
870-
4. An assistant message acknowledging the tool call
871-
872-
Args:
873-
tool: The tool call information.
874-
tool_result: The result returned by the tool.
875-
user_message_override: Optional custom message to include.
876-
"""
877-
# Filter tool input parameters to only include those defined in tool spec
878-
filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"])
879-
880-
# Create user message describing the tool call
881-
input_parameters = json.dumps(filtered_input, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
882-
883-
user_msg_content: list[ContentBlock] = [
884-
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")}
885-
]
886-
887-
# Add override message if provided
888-
if user_message_override:
889-
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})
890-
891-
# Create filtered tool use for message history
892-
filtered_tool: ToolUse = {
893-
"toolUseId": tool["toolUseId"],
894-
"name": tool["name"],
895-
"input": filtered_input,
896-
}
897-
898-
# Create the message sequence
899-
user_msg: Message = {
900-
"role": "user",
901-
"content": user_msg_content,
902-
}
903-
tool_use_msg: Message = {
904-
"role": "assistant",
905-
"content": [{"toolUse": filtered_tool}],
906-
}
907-
tool_result_msg: Message = {
908-
"role": "user",
909-
"content": [{"toolResult": tool_result}],
910-
}
911-
assistant_msg: Message = {
912-
"role": "assistant",
913-
"content": [{"text": f"agent.tool.{tool['name']} was called."}],
914-
}
915-
916-
# Add to message history
917-
await self._append_message(user_msg)
918-
await self._append_message(tool_use_msg)
919-
await self._append_message(tool_result_msg)
920-
await self._append_message(assistant_msg)
921-
922749
def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
923750
"""Starts a trace span for the agent.
924751
@@ -960,25 +787,6 @@ def _end_agent_trace_span(
960787

961788
self.tracer.end_agent_span(**trace_attributes)
962789

963-
def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]:
964-
"""Filter input parameters to only include those defined in the tool specification.
965-
966-
Args:
967-
tool_name: Name of the tool to get specification for
968-
input_params: Original input parameters
969-
970-
Returns:
971-
Filtered parameters containing only those defined in tool spec
972-
"""
973-
all_tools_config = self.tool_registry.get_all_tools_config()
974-
tool_spec = all_tools_config.get(tool_name)
975-
976-
if not tool_spec or "inputSchema" not in tool_spec:
977-
return input_params.copy()
978-
979-
properties = tool_spec["inputSchema"]["json"]["properties"]
980-
return {k: v for k, v in input_params.items() if k in properties}
981-
982790
def _initialize_system_prompt(
983791
self, system_prompt: str | list[SystemContentBlock] | None
984792
) -> tuple[str | None, list[SystemContentBlock] | None]:

0 commit comments

Comments
 (0)