Skip to content

iterative tools #345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def caller(
}

# Execute the tool
events = run_tool(agent=self._agent, tool=tool_use, kwargs=kwargs)
events = run_tool(self._agent, tool_use, kwargs)

try:
while True:
Expand Down
23 changes: 11 additions & 12 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
- event_loop_cycle_span: Current tracing Span for this cycle

Yields:
Model and tool invocation events. The last event is a tuple containing:
Model and tool stream events. The last event is a tuple containing:

- StopReason: Reason the model stopped generating (e.g., "tool_use")
- Message: The generated message from the model
Expand Down Expand Up @@ -254,14 +254,14 @@ async def recurse_event_loop(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGen
recursive_trace.end()


def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGenerator:
def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator:
"""Process a tool invocation.

Looks up the tool in the registry and invokes it with the provided parameters.
Looks up the tool in the registry and streams it with the provided parameters.

Args:
agent: The agent for which the tool is being executed.
tool: The tool object to process, containing name and parameters.
tool_use: The tool object to process, containing name and parameters.
kwargs: Additional keyword arguments passed to the tool.

Yields:
Expand All @@ -270,9 +270,9 @@ def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGener
Returns:
The final tool result or an error response if the tool fails or is not found.
"""
logger.debug("tool=<%s> | invoking", tool)
tool_use_id = tool["toolUseId"]
tool_name = tool["name"]
logger.debug("tool_use=<%s> | streaming", tool_use)
tool_use_id = tool_use["toolUseId"]
tool_name = tool_use["name"]

# Get the tool info
tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
Expand Down Expand Up @@ -301,8 +301,7 @@ def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGener
}
)

result = tool_func.invoke(tool, **kwargs)
yield {"result": result} # Placeholder until tool_func becomes a generator from which we can yield from
result = yield from tool_func.stream(tool_use, **kwargs)
return result

except Exception as e:
Expand Down Expand Up @@ -341,8 +340,8 @@ async def _handle_tool_execution(
kwargs: Additional keyword arguments, including request state.

Yields:
Tool invocation events along with events yielded from a recursive call to the event loop. The last event is a
tuple containing:
Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple
containing:
- The stop reason,
- The updated message,
- The updated event loop metrics,
Expand All @@ -355,7 +354,7 @@ async def _handle_tool_execution(
return

def tool_handler(tool_use: ToolUse) -> ToolGenerator:
return run_tool(agent=agent, kwargs=kwargs, tool=tool_use)
return run_tool(agent, tool_use, kwargs)

tool_events = run_tools(
handler=tool_handler,
Expand Down
53 changes: 27 additions & 26 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from typing import (
Any,
Callable,
Dict,
Generic,
Optional,
ParamSpec,
Expand All @@ -62,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
from pydantic import BaseModel, Field, create_model
from typing_extensions import override

from ..types.tools import AgentTool, JSONSchema, ToolResult, ToolSpec, ToolUse
from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolResult, ToolSpec, ToolUse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -119,7 +118,7 @@ def _create_input_model(self) -> Type[BaseModel]:
Returns:
A Pydantic BaseModel class customized for the function's parameters.
"""
field_definitions: Dict[str, Any] = {}
field_definitions: dict[str, Any] = {}

for name, param in self.signature.parameters.items():
# Skip special parameters
Expand Down Expand Up @@ -179,7 +178,7 @@ def extract_metadata(self) -> ToolSpec:

return tool_spec

def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None:
"""Clean up Pydantic schema to match Strands' expected format.

Pydantic's JSON schema output includes several elements that aren't needed for Strands Agent tools and could
Expand Down Expand Up @@ -227,7 +226,7 @@ def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
if key in prop_schema:
del prop_schema[key]

def validate_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
"""Validate input data using the Pydantic model.

This method ensures that the input data meets the expected schema before it's passed to the actual function. It
Expand Down Expand Up @@ -270,32 +269,32 @@ class DecoratedFunctionTool(AgentTool, Generic[P, R]):

_tool_name: str
_tool_spec: ToolSpec
_tool_func: Callable[P, R]
_metadata: FunctionToolMetadata
original_function: Callable[P, R]

def __init__(
self,
function: Callable[P, R],
tool_name: str,
tool_spec: ToolSpec,
tool_func: Callable[P, R],
metadata: FunctionToolMetadata,
):
"""Initialize the decorated function tool.

Args:
function: The original function being decorated.
tool_name: The name to use for the tool (usually the function name).
tool_spec: The tool specification containing metadata for Agent integration.
tool_func: The original function being decorated.
metadata: The FunctionToolMetadata object with extracted function information.
"""
super().__init__()

self.original_function = function
self._tool_name = tool_name
self._tool_spec = tool_spec
self._tool_func = tool_func
self._metadata = metadata
self._tool_name = tool_name

functools.update_wrapper(wrapper=self, wrapped=self.original_function)
functools.update_wrapper(wrapper=self, wrapped=self._tool_func)

def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]":
"""Descriptor protocol implementation for proper method binding.
Expand Down Expand Up @@ -323,12 +322,10 @@ def my_tool():
tool = instance.my_tool
```
"""
if instance is not None and not inspect.ismethod(self.original_function):
if instance is not None and not inspect.ismethod(self._tool_func):
# Create a bound method
new_callback = self.original_function.__get__(instance, instance.__class__)
return DecoratedFunctionTool(
function=new_callback, tool_name=self.tool_name, tool_spec=self.tool_spec, metadata=self._metadata
)
tool_func = self._tool_func.__get__(instance, instance.__class__)
return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata)

return self

Expand Down Expand Up @@ -360,7 +357,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:

return cast(R, self.invoke(tool_use, **kwargs))

return self.original_function(*args, **kwargs)
return self._tool_func(*args, **kwargs)

@property
def tool_name(self) -> str:
Expand Down Expand Up @@ -389,10 +386,11 @@ def tool_type(self) -> str:
"""
return "function"

def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
"""Invoke the tool with a tool use specification.
@override
def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolGenerator:
"""Stream the tool with a tool use specification.

This method handles tool use invocations from a Strands Agent. It validates the input,
This method handles tool use streams from a Strands Agent. It validates the input,
calls the function, and formats the result according to the expected tool result format.

Key operations:
Expand All @@ -404,15 +402,17 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
5. Handle and format any errors that occur

Args:
tool: The tool use specification from the Agent.
tool_use: The tool use specification from the Agent.
*args: Additional positional arguments (not typically used).
**kwargs: Additional keyword arguments, may include 'agent' reference.

Yields:
Events of the tool stream.

Returns:
A standardized tool result dictionary with status and content.
"""
# This is a tool use call - process accordingly
tool_use = tool
tool_use_id = tool_use.get("toolUseId", "unknown")
tool_input = tool_use.get("input", {})

Expand All @@ -424,8 +424,9 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
if "agent" in kwargs and "agent" in self._metadata.signature.parameters:
validated_input["agent"] = kwargs.get("agent")

# We get "too few arguments here" but because that's because fof the way we're calling it
result = self.original_function(**validated_input) # type: ignore
result = self._tool_func(**validated_input) # type: ignore # "Too few arguments" expected
if inspect.isgenerator(result):
result = yield from result

# FORMAT THE RESULT for Strands Agent
if isinstance(result, dict) and "status" in result and "content" in result:
Expand Down Expand Up @@ -476,7 +477,7 @@ def get_display_properties(self) -> dict[str, str]:
Function properties (e.g., function name).
"""
properties = super().get_display_properties()
properties["Function"] = self.original_function.__name__
properties["Function"] = self._tool_func.__name__
return properties


Expand Down Expand Up @@ -573,7 +574,7 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
if not isinstance(tool_name, str):
raise ValueError(f"Tool name must be a string, got {type(tool_name)}")

return DecoratedFunctionTool(function=f, tool_name=tool_name, tool_spec=tool_spec, metadata=tool_meta)
return DecoratedFunctionTool(tool_name, tool_spec, f, tool_meta)

# Handle both @tool and @tool() syntax
if func is None:
Expand Down
16 changes: 8 additions & 8 deletions src/strands/tools/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,23 @@ def run_tools(
thread_pool: Optional thread pool for parallel processing.

Yields:
Events of the tool invocations. Tool results are appended to `tool_results`.
Events of the tool stream. Tool results are appended to `tool_results`.
"""

def handle(tool: ToolUse) -> ToolGenerator:
def handle(tool_use: ToolUse) -> ToolGenerator:
tracer = get_tracer()
tool_call_span = tracer.start_tool_call_span(tool, parent_span)
tool_call_span = tracer.start_tool_call_span(tool_use, parent_span)

tool_name = tool["name"]
tool_name = tool_use["name"]
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
tool_start_time = time.time()

result = yield from handler(tool)
result = yield from handler(tool_use)

tool_success = result.get("status") == "success"
tool_duration = time.time() - tool_start_time
message = Message(role="user", content=[{"toolResult": result}])
event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
cycle_trace.add_child(tool_trace)

if tool_call_span:
Expand All @@ -66,12 +66,12 @@ def handle(tool: ToolUse) -> ToolGenerator:
return result

def work(
tool: ToolUse,
tool_use: ToolUse,
worker_id: int,
worker_queue: queue.Queue,
worker_event: threading.Event,
) -> ToolResult:
events = handle(tool)
events = handle(tool_use)

try:
while True:
Expand Down
2 changes: 1 addition & 1 deletion src/strands/tools/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
if not callable(tool_func):
raise TypeError(f"Tool {tool_name} function is not callable")

return PythonAgentTool(tool_name, tool_spec, callback=tool_func)
return PythonAgentTool(tool_name, tool_spec, tool_func)

except Exception:
logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool", tool_name, sys.path)
Expand Down
24 changes: 17 additions & 7 deletions src/strands/tools/mcp/mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from typing import TYPE_CHECKING, Any

from mcp.types import Tool as MCPTool
from typing_extensions import override

from ...types.tools import AgentTool, ToolResult, ToolSpec, ToolUse
from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse

if TYPE_CHECKING:
from .mcp_client import MCPClient
Expand Down Expand Up @@ -73,13 +74,22 @@ def tool_type(self) -> str:
"""
return "python"

def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
"""Invoke the MCP tool.
@override
def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolGenerator:
"""Stream the MCP tool.

This method delegates the tool invocation to the MCP server connection,
passing the tool use ID, tool name, and input arguments.
This method delegates the tool stream to the MCP server connection, passing the tool use ID, tool name, and
input arguments.

Yields:
No events.

Returns:
A standardized tool result dictionary with status and content.
"""
logger.debug("invoking MCP tool '%s' with tool_use_id=%s", self.tool_name, tool["toolUseId"])
logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"])

return self.mcp_client.call_tool_sync(
tool_use_id=tool["toolUseId"], name=self.tool_name, arguments=tool["input"]
tool_use_id=tool_use["toolUseId"], name=self.tool_name, arguments=tool_use["input"]
)
yield # type: ignore # Need yield to create generator, but left unreachable as we have no events
18 changes: 3 additions & 15 deletions src/strands/tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,7 @@ def reload_tool(self, tool_name: str) -> None:
# Validate tool spec
self.validate_tool_spec(module.TOOL_SPEC)

new_tool = PythonAgentTool(
tool_name=tool_name,
tool_spec=module.TOOL_SPEC,
callback=tool_function,
)
new_tool = PythonAgentTool(tool_name, module.TOOL_SPEC, tool_function)

# Register the tool
self.register_tool(new_tool)
Expand Down Expand Up @@ -431,11 +427,7 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None:
continue

tool_spec = module.TOOL_SPEC
tool = PythonAgentTool(
tool_name=tool_name,
tool_spec=tool_spec,
callback=tool_function,
)
tool = PythonAgentTool(tool_name, tool_spec, tool_function)
self.register_tool(tool)
successful_loads += 1

Expand Down Expand Up @@ -463,11 +455,7 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None:
continue

tool_spec = module.TOOL_SPEC
tool = PythonAgentTool(
tool_name=tool_name,
tool_spec=tool_spec,
callback=tool_function,
)
tool = PythonAgentTool(tool_name, tool_spec, tool_function)
self.register_tool(tool)
successful_loads += 1

Expand Down
Loading
Loading