|
9 | 9 | 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` |
10 | 10 | """ |
11 | 11 |
|
12 | | -import json |
13 | 12 | import logging |
14 | | -import random |
15 | 13 | import warnings |
16 | 14 | from typing import ( |
17 | 15 | TYPE_CHECKING, |
|
52 | 50 | from ..session.session_manager import SessionManager |
53 | 51 | from ..telemetry.metrics import EventLoopMetrics |
54 | 52 | from ..telemetry.tracer import get_tracer, serialize |
| 53 | +from ..tools._caller import _ToolCaller |
55 | 54 | from ..tools.executors import ConcurrentToolExecutor |
56 | 55 | from ..tools.executors._executor import ToolExecutor |
57 | 56 | from ..tools.registry import ToolRegistry |
58 | 57 | from ..tools.structured_output._structured_output_context import StructuredOutputContext |
59 | 58 | from ..tools.watcher import ToolWatcher |
60 | | -from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, ToolInterruptEvent, TypedEvent |
| 59 | +from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent |
61 | 60 | from ..types.agent import AgentInput |
62 | 61 | from ..types.content import ContentBlock, Message, Messages, SystemContentBlock |
63 | 62 | from ..types.exceptions import ContextWindowOverflowException |
64 | | -from ..types.tools import ToolResult, ToolUse |
65 | 63 | from ..types.traces import AttributeValue |
66 | 64 | from .agent_result import AgentResult |
67 | 65 | from .conversation_manager import ( |
@@ -101,114 +99,8 @@ class Agent: |
101 | 99 | 6. Produces a final response |
102 | 100 | """ |
103 | 101 |
|
104 | | - class ToolCaller: |
105 | | - """Call tool as a function.""" |
106 | | - |
107 | | - def __init__(self, agent: "Agent") -> None: |
108 | | - """Initialize instance. |
109 | | -
|
110 | | - Args: |
111 | | - agent: Agent reference that will accept tool results. |
112 | | - """ |
113 | | - # WARNING: Do not add any other member variables or methods as this could result in a name conflict with |
114 | | - # agent tools and thus break their execution. |
115 | | - self._agent = agent |
116 | | - |
117 | | - def __getattr__(self, name: str) -> Callable[..., Any]: |
118 | | - """Call tool as a function. |
119 | | -
|
120 | | - This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). |
121 | | - It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). |
122 | | -
|
123 | | - Args: |
124 | | - name: The name of the attribute (tool) being accessed. |
125 | | -
|
126 | | - Returns: |
127 | | - A function that when called will execute the named tool. |
128 | | -
|
129 | | - Raises: |
130 | | - AttributeError: If no tool with the given name exists or if multiple tools match the given name. |
131 | | - """ |
132 | | - |
133 | | - def caller( |
134 | | - user_message_override: Optional[str] = None, |
135 | | - record_direct_tool_call: Optional[bool] = None, |
136 | | - **kwargs: Any, |
137 | | - ) -> Any: |
138 | | - """Call a tool directly by name. |
139 | | -
|
140 | | - Args: |
141 | | - user_message_override: Optional custom message to record instead of default |
142 | | - record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class |
143 | | - attribute if provided. |
144 | | - **kwargs: Keyword arguments to pass to the tool. |
145 | | -
|
146 | | - Returns: |
147 | | - The result returned by the tool. |
148 | | -
|
149 | | - Raises: |
150 | | - AttributeError: If the tool doesn't exist. |
151 | | - """ |
152 | | - if self._agent._interrupt_state.activated: |
153 | | - raise RuntimeError("cannot directly call tool during interrupt") |
154 | | - |
155 | | - normalized_name = self._find_normalized_tool_name(name) |
156 | | - |
157 | | - # Create unique tool ID and set up the tool request |
158 | | - tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" |
159 | | - tool_use: ToolUse = { |
160 | | - "toolUseId": tool_id, |
161 | | - "name": normalized_name, |
162 | | - "input": kwargs.copy(), |
163 | | - } |
164 | | - tool_results: list[ToolResult] = [] |
165 | | - invocation_state = kwargs |
166 | | - |
167 | | - async def acall() -> ToolResult: |
168 | | - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): |
169 | | - if isinstance(event, ToolInterruptEvent): |
170 | | - self._agent._interrupt_state.deactivate() |
171 | | - raise RuntimeError("cannot raise interrupt in direct tool call") |
172 | | - |
173 | | - tool_result = tool_results[0] |
174 | | - |
175 | | - if record_direct_tool_call is not None: |
176 | | - should_record_direct_tool_call = record_direct_tool_call |
177 | | - else: |
178 | | - should_record_direct_tool_call = self._agent.record_direct_tool_call |
179 | | - |
180 | | - if should_record_direct_tool_call: |
181 | | - # Create a record of this tool execution in the message history |
182 | | - await self._agent._record_tool_execution(tool_use, tool_result, user_message_override) |
183 | | - |
184 | | - return tool_result |
185 | | - |
186 | | - tool_result = run_async(acall) |
187 | | - self._agent.conversation_manager.apply_management(self._agent) |
188 | | - return tool_result |
189 | | - |
190 | | - return caller |
191 | | - |
192 | | - def _find_normalized_tool_name(self, name: str) -> str: |
193 | | - """Lookup the tool represented by name, replacing characters with underscores as necessary.""" |
194 | | - tool_registry = self._agent.tool_registry.registry |
195 | | - |
196 | | - if tool_registry.get(name, None): |
197 | | - return name |
198 | | - |
199 | | - # If the desired name contains underscores, it might be a placeholder for characters that can't be |
200 | | - # represented as python identifiers but are valid as tool names, such as dashes. In that case, find |
201 | | - # all tools that can be represented with the normalized name |
202 | | - if "_" in name: |
203 | | - filtered_tools = [ |
204 | | - tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name |
205 | | - ] |
206 | | - |
207 | | - # The registry itself defends against similar names, so we can just take the first match |
208 | | - if filtered_tools: |
209 | | - return filtered_tools[0] |
210 | | - |
211 | | - raise AttributeError(f"Tool '{name}' not found") |
| 102 | + # For backwards compatibility |
| 103 | + ToolCaller = _ToolCaller |
212 | 104 |
|
213 | 105 | def __init__( |
214 | 106 | self, |
@@ -347,7 +239,7 @@ def __init__( |
347 | 239 | else: |
348 | 240 | self.state = AgentState() |
349 | 241 |
|
350 | | - self.tool_caller = Agent.ToolCaller(self) |
| 242 | + self.tool_caller = _ToolCaller(self) |
351 | 243 |
|
352 | 244 | self.hooks = HookRegistry() |
353 | 245 |
|
@@ -395,7 +287,7 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: |
395 | 287 | self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) |
396 | 288 |
|
397 | 289 | @property |
398 | | - def tool(self) -> ToolCaller: |
| 290 | + def tool(self) -> _ToolCaller: |
399 | 291 | """Call tool as a function. |
400 | 292 |
|
401 | 293 | Returns: |
@@ -854,71 +746,6 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: |
854 | 746 | raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.") |
855 | 747 | return messages |
856 | 748 |
|
857 | | - async def _record_tool_execution( |
858 | | - self, |
859 | | - tool: ToolUse, |
860 | | - tool_result: ToolResult, |
861 | | - user_message_override: Optional[str], |
862 | | - ) -> None: |
863 | | - """Record a tool execution in the message history. |
864 | | -
|
865 | | - Creates a sequence of messages that represent the tool execution: |
866 | | -
|
867 | | - 1. A user message describing the tool call |
868 | | - 2. An assistant message with the tool use |
869 | | - 3. A user message with the tool result |
870 | | - 4. An assistant message acknowledging the tool call |
871 | | -
|
872 | | - Args: |
873 | | - tool: The tool call information. |
874 | | - tool_result: The result returned by the tool. |
875 | | - user_message_override: Optional custom message to include. |
876 | | - """ |
877 | | - # Filter tool input parameters to only include those defined in tool spec |
878 | | - filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) |
879 | | - |
880 | | - # Create user message describing the tool call |
881 | | - input_parameters = json.dumps(filtered_input, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>") |
882 | | - |
883 | | - user_msg_content: list[ContentBlock] = [ |
884 | | - {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} |
885 | | - ] |
886 | | - |
887 | | - # Add override message if provided |
888 | | - if user_message_override: |
889 | | - user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) |
890 | | - |
891 | | - # Create filtered tool use for message history |
892 | | - filtered_tool: ToolUse = { |
893 | | - "toolUseId": tool["toolUseId"], |
894 | | - "name": tool["name"], |
895 | | - "input": filtered_input, |
896 | | - } |
897 | | - |
898 | | - # Create the message sequence |
899 | | - user_msg: Message = { |
900 | | - "role": "user", |
901 | | - "content": user_msg_content, |
902 | | - } |
903 | | - tool_use_msg: Message = { |
904 | | - "role": "assistant", |
905 | | - "content": [{"toolUse": filtered_tool}], |
906 | | - } |
907 | | - tool_result_msg: Message = { |
908 | | - "role": "user", |
909 | | - "content": [{"toolResult": tool_result}], |
910 | | - } |
911 | | - assistant_msg: Message = { |
912 | | - "role": "assistant", |
913 | | - "content": [{"text": f"agent.tool.{tool['name']} was called."}], |
914 | | - } |
915 | | - |
916 | | - # Add to message history |
917 | | - await self._append_message(user_msg) |
918 | | - await self._append_message(tool_use_msg) |
919 | | - await self._append_message(tool_result_msg) |
920 | | - await self._append_message(assistant_msg) |
921 | | - |
922 | 749 | def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: |
923 | 750 | """Starts a trace span for the agent. |
924 | 751 |
|
@@ -960,25 +787,6 @@ def _end_agent_trace_span( |
960 | 787 |
|
961 | 788 | self.tracer.end_agent_span(**trace_attributes) |
962 | 789 |
|
963 | | - def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: |
964 | | - """Filter input parameters to only include those defined in the tool specification. |
965 | | -
|
966 | | - Args: |
967 | | - tool_name: Name of the tool to get specification for |
968 | | - input_params: Original input parameters |
969 | | -
|
970 | | - Returns: |
971 | | - Filtered parameters containing only those defined in tool spec |
972 | | - """ |
973 | | - all_tools_config = self.tool_registry.get_all_tools_config() |
974 | | - tool_spec = all_tools_config.get(tool_name) |
975 | | - |
976 | | - if not tool_spec or "inputSchema" not in tool_spec: |
977 | | - return input_params.copy() |
978 | | - |
979 | | - properties = tool_spec["inputSchema"]["json"]["properties"] |
980 | | - return {k: v for k, v in input_params.items() if k in properties} |
981 | | - |
982 | 790 | def _initialize_system_prompt( |
983 | 791 | self, system_prompt: str | list[SystemContentBlock] | None |
984 | 792 | ) -> tuple[str | None, list[SystemContentBlock] | None]: |
|
0 commit comments