Skip to content

refactor: remove kwargs spread after agent call #289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 67 additions & 103 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Callable, Generator, Mapping, Optional, Type, TypeVar, Union, cast
from typing import Any, AsyncIterator, Callable, Generator, List, Mapping, Optional, Type, TypeVar, Union, cast

from opentelemetry import trace
from pydantic import BaseModel
Expand All @@ -31,7 +31,7 @@
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.models import Model
from ..types.tools import ToolConfig
from ..types.tools import ToolConfig, ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .conversation_manager import (
Expand Down Expand Up @@ -97,104 +97,56 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
"""

def find_normalized_tool_name() -> Optional[str]:
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
tool_registry = self._agent.tool_registry.registry

if tool_registry.get(name, None):
return name

# If the desired name contains underscores, it might be a placeholder for characters that can't be
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
# all tools that can be represented with the normalized name
if "_" in name:
filtered_tools = [
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
]

# The registry itself defends against similar names, so we can just take the first match
if filtered_tools:
return filtered_tools[0]

raise AttributeError(f"Tool '{name}' not found")

def caller(**kwargs: Any) -> Any:
def caller(
user_message_override: Optional[str] = None,
record_direct_tool_call: Optional[bool] = None,
**kwargs: Any,
) -> Any:
"""Call a tool directly by name.

Args:
user_message_override: Optional custom message to record instead of default
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
attribute if provided.
**kwargs: Keyword arguments to pass to the tool.

- user_message_override: Custom message to record instead of default
- tool_execution_handler: Custom handler for tool execution
- event_loop_metrics: Custom metrics collector
- messages: Custom message history to use
- tool_config: Custom tool configuration
- callback_handler: Custom callback handler
- record_direct_tool_call: Whether to record this call in history

Returns:
The result returned by the tool.

Raises:
AttributeError: If the tool doesn't exist.
"""
normalized_name = find_normalized_tool_name()
normalized_name = self._find_normalized_tool_name(name)

# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
tool_use = {
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}

# Extract tool execution parameters
user_message_override = kwargs.get("user_message_override", None)
tool_execution_handler = kwargs.get("tool_execution_handler", self._agent.thread_pool_wrapper)
event_loop_metrics = kwargs.get("event_loop_metrics", self._agent.event_loop_metrics)
messages = kwargs.get("messages", self._agent.messages)
tool_config = kwargs.get("tool_config", self._agent.tool_config)
callback_handler = kwargs.get("callback_handler", self._agent.callback_handler)
record_direct_tool_call = kwargs.get("record_direct_tool_call", self._agent.record_direct_tool_call)

# Process tool call
handler_kwargs = {
k: v
for k, v in kwargs.items()
if k
not in [
"tool_execution_handler",
"event_loop_metrics",
"messages",
"tool_config",
"callback_handler",
"tool_handler",
"system_prompt",
"model",
"model_id",
"user_message_override",
"agent",
"record_direct_tool_call",
]
}

# Execute the tool
tool_result = self._agent.tool_handler.process(
tool=tool_use,
model=self._agent.model,
system_prompt=self._agent.system_prompt,
messages=messages,
tool_config=tool_config,
callback_handler=callback_handler,
tool_execution_handler=tool_execution_handler,
event_loop_metrics=event_loop_metrics,
agent=self._agent,
**handler_kwargs,
messages=self._agent.messages,
tool_config=self._agent.tool_config,
callback_handler=self._agent.callback_handler,
kwargs=kwargs,
)

if record_direct_tool_call:
if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
self._agent._record_tool_execution(tool_use, tool_result, user_message_override, messages)
self._agent._record_tool_execution(
tool_use, tool_result, user_message_override, self._agent.messages
)

# Apply window management
self._agent.conversation_manager.apply_management(self._agent)
Expand All @@ -203,6 +155,27 @@ def caller(**kwargs: Any) -> Any:

return caller

def _find_normalized_tool_name(self, name: str) -> str:
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
tool_registry = self._agent.tool_registry.registry

if tool_registry.get(name, None):
return name

# If the desired name contains underscores, it might be a placeholder for characters that can't be
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
# all tools that can be represented with the normalized name
if "_" in name:
filtered_tools = [
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
]

# The registry itself defends against similar names, so we can just take the first match
if filtered_tools:
return filtered_tools[0]

raise AttributeError(f"Tool '{name}' not found")

def __init__(
self,
model: Union[Model, str, None] = None,
Expand Down Expand Up @@ -371,7 +344,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:

Args:
prompt: The natural language prompt from the user.
**kwargs: Additional parameters to pass to the event loop.
**kwargs: Additional parameters to pass through the event loop.

Returns:
Result object containing:
Expand Down Expand Up @@ -514,44 +487,35 @@ def _execute_event_loop_cycle(
Yields:
Events of the loop cycle.
"""
# Extract parameters with fallbacks to instance values
system_prompt = kwargs.pop("system_prompt", self.system_prompt)
model = kwargs.pop("model", self.model)
tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper)
event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics)
callback_handler_override = kwargs.pop("callback_handler", callback_handler)
tool_handler = kwargs.pop("tool_handler", self.tool_handler)
messages = kwargs.pop("messages", self.messages)
tool_config = kwargs.pop("tool_config", self.tool_config)
kwargs.pop("agent", None) # Remove agent to avoid conflicts
# Add `Agent` to kwargs to keep backwards-compatibility
kwargs["agent"] = self

try:
# Execute the main event loop cycle
yield from event_loop_cycle(
model=model,
system_prompt=system_prompt,
messages=messages, # will be modified by event_loop_cycle
tool_config=tool_config,
callback_handler=callback_handler_override,
tool_handler=tool_handler,
tool_execution_handler=tool_execution_handler,
event_loop_metrics=event_loop_metrics,
agent=self,
model=self.model,
system_prompt=self.system_prompt,
messages=self.messages, # will be modified by event_loop_cycle
tool_config=self.tool_config,
callback_handler=callback_handler,
tool_handler=self.tool_handler,
tool_execution_handler=self.thread_pool_wrapper,
event_loop_metrics=self.event_loop_metrics,
event_loop_parent_span=self.trace_span,
**kwargs,
kwargs=kwargs,
)

except ContextWindowOverflowException as e:
# Try reducing the context size and retrying
self.conversation_manager.reduce_context(self, e=e)
yield from self._execute_event_loop_cycle(callback_handler_override, kwargs)
yield from self._execute_event_loop_cycle(callback_handler, kwargs)

def _record_tool_execution(
self,
tool: dict[str, Any],
tool_result: dict[str, Any],
tool: ToolUse,
tool_result: ToolResult,
user_message_override: Optional[str],
messages: list[dict[str, Any]],
messages: Messages,
) -> None:
"""Record a tool execution in the message history.

Expand All @@ -569,7 +533,7 @@ def _record_tool_execution(
messages: The message history to append to.
"""
# Create user message describing the tool call
user_msg_content = [
user_msg_content: List[ContentBlock] = [
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
]

Expand All @@ -578,19 +542,19 @@ def _record_tool_execution(
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})

# Create the message sequence
user_msg = {
user_msg: Message = {
"role": "user",
"content": user_msg_content,
}
tool_use_msg = {
tool_use_msg: Message = {
"role": "assistant",
"content": [{"toolUse": tool}],
}
tool_result_msg = {
tool_result_msg: Message = {
"role": "user",
"content": [{"toolResult": tool_result}],
}
assistant_msg = {
assistant_msg: Message = {
"role": "assistant",
"content": [{"text": f"agent.{tool['name']} was called"}],
}
Expand Down
Loading