Skip to content

Commit 01839f7

Browse files
Unshurejsamuel1
authored andcommitted
refactor: remove kwargs spread after agent call (strands-agents#289)
* refactor: remove kwargs spread after agent call * fix: Add local method override * fix: fix unit tests
1 parent 7f9349d commit 01839f7

File tree

8 files changed

+230
-241
lines changed

8 files changed

+230
-241
lines changed

src/strands/agent/agent.py

Lines changed: 67 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import os
1515
import random
1616
from concurrent.futures import ThreadPoolExecutor
17-
from typing import Any, AsyncIterator, Callable, Generator, Mapping, Optional, Type, TypeVar, Union, cast
17+
from typing import Any, AsyncIterator, Callable, Generator, List, Mapping, Optional, Type, TypeVar, Union, cast
1818

1919
from opentelemetry import trace
2020
from pydantic import BaseModel
@@ -31,7 +31,7 @@
3131
from ..types.content import ContentBlock, Message, Messages
3232
from ..types.exceptions import ContextWindowOverflowException
3333
from ..types.models import Model
34-
from ..types.tools import ToolConfig
34+
from ..types.tools import ToolConfig, ToolResult, ToolUse
3535
from ..types.traces import AttributeValue
3636
from .agent_result import AgentResult
3737
from .conversation_manager import (
@@ -98,104 +98,56 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
9898
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
9999
"""
100100

101-
def find_normalized_tool_name() -> Optional[str]:
102-
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
103-
tool_registry = self._agent.tool_registry.registry
104-
105-
if tool_registry.get(name, None):
106-
return name
107-
108-
# If the desired name contains underscores, it might be a placeholder for characters that can't be
109-
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
110-
# all tools that can be represented with the normalized name
111-
if "_" in name:
112-
filtered_tools = [
113-
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
114-
]
115-
116-
# The registry itself defends against similar names, so we can just take the first match
117-
if filtered_tools:
118-
return filtered_tools[0]
119-
120-
raise AttributeError(f"Tool '{name}' not found")
121-
122-
def caller(**kwargs: Any) -> Any:
101+
def caller(
102+
user_message_override: Optional[str] = None,
103+
record_direct_tool_call: Optional[bool] = None,
104+
**kwargs: Any,
105+
) -> Any:
123106
"""Call a tool directly by name.
124107
125108
Args:
109+
user_message_override: Optional custom message to record instead of default
110+
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
111+
attribute if provided.
126112
**kwargs: Keyword arguments to pass to the tool.
127113
128-
- user_message_override: Custom message to record instead of default
129-
- tool_execution_handler: Custom handler for tool execution
130-
- event_loop_metrics: Custom metrics collector
131-
- messages: Custom message history to use
132-
- tool_config: Custom tool configuration
133-
- callback_handler: Custom callback handler
134-
- record_direct_tool_call: Whether to record this call in history
135-
136114
Returns:
137115
The result returned by the tool.
138116
139117
Raises:
140118
AttributeError: If the tool doesn't exist.
141119
"""
142-
normalized_name = find_normalized_tool_name()
120+
normalized_name = self._find_normalized_tool_name(name)
143121

144122
# Create unique tool ID and set up the tool request
145123
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
146-
tool_use = {
124+
tool_use: ToolUse = {
147125
"toolUseId": tool_id,
148126
"name": normalized_name,
149127
"input": kwargs.copy(),
150128
}
151129

152-
# Extract tool execution parameters
153-
user_message_override = kwargs.get("user_message_override", None)
154-
tool_execution_handler = kwargs.get("tool_execution_handler", self._agent.thread_pool_wrapper)
155-
event_loop_metrics = kwargs.get("event_loop_metrics", self._agent.event_loop_metrics)
156-
messages = kwargs.get("messages", self._agent.messages)
157-
tool_config = kwargs.get("tool_config", self._agent.tool_config)
158-
callback_handler = kwargs.get("callback_handler", self._agent.callback_handler)
159-
record_direct_tool_call = kwargs.get("record_direct_tool_call", self._agent.record_direct_tool_call)
160-
161-
# Process tool call
162-
handler_kwargs = {
163-
k: v
164-
for k, v in kwargs.items()
165-
if k
166-
not in [
167-
"tool_execution_handler",
168-
"event_loop_metrics",
169-
"messages",
170-
"tool_config",
171-
"callback_handler",
172-
"tool_handler",
173-
"system_prompt",
174-
"model",
175-
"model_id",
176-
"user_message_override",
177-
"agent",
178-
"record_direct_tool_call",
179-
]
180-
}
181-
182130
# Execute the tool
183131
tool_result = self._agent.tool_handler.process(
184132
tool=tool_use,
185133
model=self._agent.model,
186134
system_prompt=self._agent.system_prompt,
187-
messages=messages,
188-
tool_config=tool_config,
189-
callback_handler=callback_handler,
190-
tool_execution_handler=tool_execution_handler,
191-
event_loop_metrics=event_loop_metrics,
192-
agent=self._agent,
193-
**handler_kwargs,
135+
messages=self._agent.messages,
136+
tool_config=self._agent.tool_config,
137+
callback_handler=self._agent.callback_handler,
138+
kwargs=kwargs,
194139
)
195140

196-
if record_direct_tool_call:
141+
if record_direct_tool_call is not None:
142+
should_record_direct_tool_call = record_direct_tool_call
143+
else:
144+
should_record_direct_tool_call = self._agent.record_direct_tool_call
145+
146+
if should_record_direct_tool_call:
197147
# Create a record of this tool execution in the message history
198-
self._agent._record_tool_execution(tool_use, tool_result, user_message_override, messages)
148+
self._agent._record_tool_execution(
149+
tool_use, tool_result, user_message_override, self._agent.messages
150+
)
199151

200152
# Apply window management
201153
self._agent.conversation_manager.apply_management(self._agent)
@@ -204,6 +156,27 @@ def caller(**kwargs: Any) -> Any:
204156

205157
return caller
206158

159+
def _find_normalized_tool_name(self, name: str) -> str:
160+
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
161+
tool_registry = self._agent.tool_registry.registry
162+
163+
if tool_registry.get(name, None):
164+
return name
165+
166+
# If the desired name contains underscores, it might be a placeholder for characters that can't be
167+
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
168+
# all tools that can be represented with the normalized name
169+
if "_" in name:
170+
filtered_tools = [
171+
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
172+
]
173+
174+
# The registry itself defends against similar names, so we can just take the first match
175+
if filtered_tools:
176+
return filtered_tools[0]
177+
178+
raise AttributeError(f"Tool '{name}' not found")
179+
207180
def __init__(
208181
self,
209182
model: Union[Model, str, None] = None,
@@ -411,7 +384,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
411384
412385
Args:
413386
prompt: The natural language prompt from the user.
414-
**kwargs: Additional parameters to pass to the event loop.
387+
**kwargs: Additional parameters to pass through the event loop.
415388
416389
Returns:
417390
Result object containing:
@@ -577,35 +550,35 @@ def _execute_event_loop_cycle(
577550
# Get dynamic tool config (now simple and sync!)
578551
tool_config = self._select_tools_for_context(current_prompt, messages)
579552

580-
kwargs.pop("agent", None) # Remove agent to avoid conflicts
553+
# Add `Agent` to kwargs to keep backwards-compatibility
554+
kwargs["agent"] = self
581555

582556
try:
583557
# Execute the main event loop cycle
584558
yield from event_loop_cycle(
585-
model=model,
586-
system_prompt=system_prompt,
587-
messages=messages, # will be modified by event_loop_cycle
588-
tool_config=tool_config,
589-
callback_handler=callback_handler_override,
590-
tool_handler=tool_handler,
591-
tool_execution_handler=tool_execution_handler,
592-
event_loop_metrics=event_loop_metrics,
593-
agent=self,
559+
model=self.model,
560+
system_prompt=self.system_prompt,
561+
messages=self.messages, # will be modified by event_loop_cycle
562+
tool_config=self.tool_config,
563+
callback_handler=callback_handler,
564+
tool_handler=self.tool_handler,
565+
tool_execution_handler=self.thread_pool_wrapper,
566+
event_loop_metrics=self.event_loop_metrics,
594567
event_loop_parent_span=self.trace_span,
595-
**kwargs,
568+
kwargs=kwargs,
596569
)
597570

598571
except ContextWindowOverflowException as e:
599572
# Try reducing the context size and retrying
600573
self.conversation_manager.reduce_context(self, e=e)
601-
yield from self._execute_event_loop_cycle(callback_handler_override, kwargs)
574+
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
602575

603576
def _record_tool_execution(
604577
self,
605-
tool: dict[str, Any],
606-
tool_result: dict[str, Any],
578+
tool: ToolUse,
579+
tool_result: ToolResult,
607580
user_message_override: Optional[str],
608-
messages: list[dict[str, Any]],
581+
messages: Messages,
609582
) -> None:
610583
"""Record a tool execution in the message history.
611584
@@ -623,7 +596,7 @@ def _record_tool_execution(
623596
messages: The message history to append to.
624597
"""
625598
# Create user message describing the tool call
626-
user_msg_content = [
599+
user_msg_content: List[ContentBlock] = [
627600
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
628601
]
629602

@@ -632,19 +605,19 @@ def _record_tool_execution(
632605
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})
633606

634607
# Create the message sequence
635-
user_msg = {
608+
user_msg: Message = {
636609
"role": "user",
637610
"content": user_msg_content,
638611
}
639-
tool_use_msg = {
612+
tool_use_msg: Message = {
640613
"role": "assistant",
641614
"content": [{"toolUse": tool}],
642615
}
643-
tool_result_msg = {
616+
tool_result_msg: Message = {
644617
"role": "user",
645618
"content": [{"toolResult": tool_result}],
646619
}
647-
assistant_msg = {
620+
assistant_msg: Message = {
648621
"role": "assistant",
649622
"content": [{"text": f"agent.{tool['name']} was called"}],
650623
}

0 commit comments

Comments
 (0)