Skip to content

Commit 79f6a52

Browse files
pgrayyjsamuel1
authored andcommitted
iterative tools (strands-agents#345)
1 parent d3a5240 commit 79f6a52

File tree

15 files changed

+422
-281
lines changed

15 files changed

+422
-281
lines changed

src/strands/agent/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def caller(
130130
}
131131

132132
# Execute the tool
133-
events = run_tool(agent=self._agent, tool=tool_use, kwargs=kwargs)
133+
events = run_tool(self._agent, tool_use, kwargs)
134134

135135
try:
136136
while True:

src/strands/event_loop/event_loop.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
5656
- event_loop_cycle_span: Current tracing Span for this cycle
5757
5858
Yields:
59-
Model and tool invocation events. The last event is a tuple containing:
59+
Model and tool stream events. The last event is a tuple containing:
6060
6161
- StopReason: Reason the model stopped generating (e.g., "tool_use")
6262
- Message: The generated message from the model
@@ -254,14 +254,14 @@ async def recurse_event_loop(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGen
254254
recursive_trace.end()
255255

256256

257-
def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGenerator:
257+
def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerator:
258258
"""Process a tool invocation.
259259
260-
Looks up the tool in the registry and invokes it with the provided parameters.
260+
Looks up the tool in the registry and streams it with the provided parameters.
261261
262262
Args:
263263
agent: The agent for which the tool is being executed.
264-
tool: The tool object to process, containing name and parameters.
264+
tool_use: The tool object to process, containing name and parameters.
265265
kwargs: Additional keyword arguments passed to the tool.
266266
267267
Yields:
@@ -270,9 +270,9 @@ def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGener
270270
Returns:
271271
The final tool result or an error response if the tool fails or is not found.
272272
"""
273-
logger.debug("tool=<%s> | invoking", tool)
274-
tool_use_id = tool["toolUseId"]
275-
tool_name = tool["name"]
273+
logger.debug("tool_use=<%s> | streaming", tool_use)
274+
tool_use_id = tool_use["toolUseId"]
275+
tool_name = tool_use["name"]
276276

277277
# Get the tool info
278278
tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
@@ -301,8 +301,7 @@ def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGener
301301
}
302302
)
303303

304-
result = tool_func.invoke(tool, **kwargs)
305-
yield {"result": result} # Placeholder until tool_func becomes a generator from which we can yield from
304+
result = yield from tool_func.stream(tool_use, **kwargs)
306305
return result
307306

308307
except Exception as e:
@@ -341,8 +340,8 @@ async def _handle_tool_execution(
341340
kwargs: Additional keyword arguments, including request state.
342341
343342
Yields:
344-
Tool invocation events along with events yielded from a recursive call to the event loop. The last event is a
345-
tuple containing:
343+
Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple
344+
containing:
346345
- The stop reason,
347346
- The updated message,
348347
- The updated event loop metrics,
@@ -355,7 +354,7 @@ async def _handle_tool_execution(
355354
return
356355

357356
def tool_handler(tool_use: ToolUse) -> ToolGenerator:
358-
return run_tool(agent=agent, kwargs=kwargs, tool=tool_use)
357+
return run_tool(agent, tool_use, kwargs)
359358

360359
tool_events = run_tools(
361360
handler=tool_handler,

src/strands/tools/decorator.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
4646
from typing import (
4747
Any,
4848
Callable,
49-
Dict,
5049
Generic,
5150
Optional,
5251
ParamSpec,
@@ -62,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6261
from pydantic import BaseModel, Field, create_model
6362
from typing_extensions import override
6463

65-
from ..types.tools import AgentTool, JSONSchema, ToolResult, ToolSpec, ToolUse
64+
from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolResult, ToolSpec, ToolUse
6665

6766
logger = logging.getLogger(__name__)
6867

@@ -119,7 +118,7 @@ def _create_input_model(self) -> Type[BaseModel]:
119118
Returns:
120119
A Pydantic BaseModel class customized for the function's parameters.
121120
"""
122-
field_definitions: Dict[str, Any] = {}
121+
field_definitions: dict[str, Any] = {}
123122

124123
for name, param in self.signature.parameters.items():
125124
# Skip special parameters
@@ -179,7 +178,7 @@ def extract_metadata(self) -> ToolSpec:
179178

180179
return tool_spec
181180

182-
def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
181+
def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None:
183182
"""Clean up Pydantic schema to match Strands' expected format.
184183
185184
Pydantic's JSON schema output includes several elements that aren't needed for Strands Agent tools and could
@@ -227,7 +226,7 @@ def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
227226
if key in prop_schema:
228227
del prop_schema[key]
229228

230-
def validate_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
229+
def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
231230
"""Validate input data using the Pydantic model.
232231
233232
This method ensures that the input data meets the expected schema before it's passed to the actual function. It
@@ -270,32 +269,32 @@ class DecoratedFunctionTool(AgentTool, Generic[P, R]):
270269

271270
_tool_name: str
272271
_tool_spec: ToolSpec
272+
_tool_func: Callable[P, R]
273273
_metadata: FunctionToolMetadata
274-
original_function: Callable[P, R]
275274

276275
def __init__(
277276
self,
278-
function: Callable[P, R],
279277
tool_name: str,
280278
tool_spec: ToolSpec,
279+
tool_func: Callable[P, R],
281280
metadata: FunctionToolMetadata,
282281
):
283282
"""Initialize the decorated function tool.
284283
285284
Args:
286-
function: The original function being decorated.
287285
tool_name: The name to use for the tool (usually the function name).
288286
tool_spec: The tool specification containing metadata for Agent integration.
287+
tool_func: The original function being decorated.
289288
metadata: The FunctionToolMetadata object with extracted function information.
290289
"""
291290
super().__init__()
292291

293-
self.original_function = function
292+
self._tool_name = tool_name
294293
self._tool_spec = tool_spec
294+
self._tool_func = tool_func
295295
self._metadata = metadata
296-
self._tool_name = tool_name
297296

298-
functools.update_wrapper(wrapper=self, wrapped=self.original_function)
297+
functools.update_wrapper(wrapper=self, wrapped=self._tool_func)
299298

300299
def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]":
301300
"""Descriptor protocol implementation for proper method binding.
@@ -323,12 +322,10 @@ def my_tool():
323322
tool = instance.my_tool
324323
```
325324
"""
326-
if instance is not None and not inspect.ismethod(self.original_function):
325+
if instance is not None and not inspect.ismethod(self._tool_func):
327326
# Create a bound method
328-
new_callback = self.original_function.__get__(instance, instance.__class__)
329-
return DecoratedFunctionTool(
330-
function=new_callback, tool_name=self.tool_name, tool_spec=self.tool_spec, metadata=self._metadata
331-
)
327+
tool_func = self._tool_func.__get__(instance, instance.__class__)
328+
return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata)
332329

333330
return self
334331

@@ -360,7 +357,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
360357

361358
return cast(R, self.invoke(tool_use, **kwargs))
362359

363-
return self.original_function(*args, **kwargs)
360+
return self._tool_func(*args, **kwargs)
364361

365362
@property
366363
def tool_name(self) -> str:
@@ -389,10 +386,11 @@ def tool_type(self) -> str:
389386
"""
390387
return "function"
391388

392-
def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
393-
"""Invoke the tool with a tool use specification.
389+
@override
390+
def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolGenerator:
391+
"""Stream the tool with a tool use specification.
394392
395-
This method handles tool use invocations from a Strands Agent. It validates the input,
393+
This method handles tool use streams from a Strands Agent. It validates the input,
396394
calls the function, and formats the result according to the expected tool result format.
397395
398396
Key operations:
@@ -404,15 +402,17 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
404402
5. Handle and format any errors that occur
405403
406404
Args:
407-
tool: The tool use specification from the Agent.
405+
tool_use: The tool use specification from the Agent.
408406
*args: Additional positional arguments (not typically used).
409407
**kwargs: Additional keyword arguments, may include 'agent' reference.
410408
409+
Yields:
410+
Events of the tool stream.
411+
411412
Returns:
412413
A standardized tool result dictionary with status and content.
413414
"""
414415
# This is a tool use call - process accordingly
415-
tool_use = tool
416416
tool_use_id = tool_use.get("toolUseId", "unknown")
417417
tool_input = tool_use.get("input", {})
418418

@@ -424,8 +424,9 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
424424
if "agent" in kwargs and "agent" in self._metadata.signature.parameters:
425425
validated_input["agent"] = kwargs.get("agent")
426426

427-
# We get "too few arguments here" but because that's because fof the way we're calling it
428-
result = self.original_function(**validated_input) # type: ignore
427+
result = self._tool_func(**validated_input) # type: ignore # "Too few arguments" expected
428+
if inspect.isgenerator(result):
429+
result = yield from result
429430

430431
# FORMAT THE RESULT for Strands Agent
431432
if isinstance(result, dict) and "status" in result and "content" in result:
@@ -476,7 +477,7 @@ def get_display_properties(self) -> dict[str, str]:
476477
Function properties (e.g., function name).
477478
"""
478479
properties = super().get_display_properties()
479-
properties["Function"] = self.original_function.__name__
480+
properties["Function"] = self._tool_func.__name__
480481
return properties
481482

482483

@@ -573,7 +574,7 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
573574
if not isinstance(tool_name, str):
574575
raise ValueError(f"Tool name must be a string, got {type(tool_name)}")
575576

576-
return DecoratedFunctionTool(function=f, tool_name=tool_name, tool_spec=tool_spec, metadata=tool_meta)
577+
return DecoratedFunctionTool(tool_name, tool_spec, f, tool_meta)
577578

578579
# Handle both @tool and @tool() syntax
579580
if func is None:

src/strands/tools/executor.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,23 @@ def run_tools(
4141
thread_pool: Optional thread pool for parallel processing.
4242
4343
Yields:
44-
Events of the tool invocations. Tool results are appended to `tool_results`.
44+
Events of the tool stream. Tool results are appended to `tool_results`.
4545
"""
4646

47-
def handle(tool: ToolUse) -> ToolGenerator:
47+
def handle(tool_use: ToolUse) -> ToolGenerator:
4848
tracer = get_tracer()
49-
tool_call_span = tracer.start_tool_call_span(tool, parent_span)
49+
tool_call_span = tracer.start_tool_call_span(tool_use, parent_span)
5050

51-
tool_name = tool["name"]
51+
tool_name = tool_use["name"]
5252
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
5353
tool_start_time = time.time()
5454

55-
result = yield from handler(tool)
55+
result = yield from handler(tool_use)
5656

5757
tool_success = result.get("status") == "success"
5858
tool_duration = time.time() - tool_start_time
5959
message = Message(role="user", content=[{"toolResult": result}])
60-
event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
60+
event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
6161
cycle_trace.add_child(tool_trace)
6262

6363
if tool_call_span:
@@ -66,12 +66,12 @@ def handle(tool: ToolUse) -> ToolGenerator:
6666
return result
6767

6868
def work(
69-
tool: ToolUse,
69+
tool_use: ToolUse,
7070
worker_id: int,
7171
worker_queue: queue.Queue,
7272
worker_event: threading.Event,
7373
) -> ToolResult:
74-
events = handle(tool)
74+
events = handle(tool_use)
7575

7676
try:
7777
while True:

src/strands/tools/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
108108
if not callable(tool_func):
109109
raise TypeError(f"Tool {tool_name} function is not callable")
110110

111-
return PythonAgentTool(tool_name, tool_spec, callback=tool_func)
111+
return PythonAgentTool(tool_name, tool_spec, tool_func)
112112

113113
except Exception:
114114
logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool", tool_name, sys.path)

src/strands/tools/mcp/mcp_agent_tool.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from typing import TYPE_CHECKING, Any
1010

1111
from mcp.types import Tool as MCPTool
12+
from typing_extensions import override
1213

13-
from ...types.tools import AgentTool, ToolResult, ToolSpec, ToolUse
14+
from ...types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse
1415

1516
if TYPE_CHECKING:
1617
from .mcp_client import MCPClient
@@ -73,13 +74,22 @@ def tool_type(self) -> str:
7374
"""
7475
return "python"
7576

76-
def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
77-
"""Invoke the MCP tool.
77+
@override
78+
def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolGenerator:
79+
"""Stream the MCP tool.
7880
79-
This method delegates the tool invocation to the MCP server connection,
80-
passing the tool use ID, tool name, and input arguments.
81+
This method delegates the tool stream to the MCP server connection, passing the tool use ID, tool name, and
82+
input arguments.
83+
84+
Yields:
85+
No events.
86+
87+
Returns:
88+
A standardized tool result dictionary with status and content.
8189
"""
82-
logger.debug("invoking MCP tool '%s' with tool_use_id=%s", self.tool_name, tool["toolUseId"])
90+
logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"])
91+
8392
return self.mcp_client.call_tool_sync(
84-
tool_use_id=tool["toolUseId"], name=self.tool_name, arguments=tool["input"]
93+
tool_use_id=tool_use["toolUseId"], name=self.tool_name, arguments=tool_use["input"]
8594
)
95+
yield # type: ignore # Need yield to create generator, but left unreachable as we have no events

src/strands/tools/registry.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -347,11 +347,7 @@ def reload_tool(self, tool_name: str) -> None:
347347
# Validate tool spec
348348
self.validate_tool_spec(module.TOOL_SPEC)
349349

350-
new_tool = PythonAgentTool(
351-
tool_name=tool_name,
352-
tool_spec=module.TOOL_SPEC,
353-
callback=tool_function,
354-
)
350+
new_tool = PythonAgentTool(tool_name, module.TOOL_SPEC, tool_function)
355351

356352
# Register the tool
357353
self.register_tool(new_tool)
@@ -431,11 +427,7 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None:
431427
continue
432428

433429
tool_spec = module.TOOL_SPEC
434-
tool = PythonAgentTool(
435-
tool_name=tool_name,
436-
tool_spec=tool_spec,
437-
callback=tool_function,
438-
)
430+
tool = PythonAgentTool(tool_name, tool_spec, tool_function)
439431
self.register_tool(tool)
440432
successful_loads += 1
441433

@@ -463,11 +455,7 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None:
463455
continue
464456

465457
tool_spec = module.TOOL_SPEC
466-
tool = PythonAgentTool(
467-
tool_name=tool_name,
468-
tool_spec=tool_spec,
469-
callback=tool_function,
470-
)
458+
tool = PythonAgentTool(tool_name, tool_spec, tool_function)
471459
self.register_tool(tool)
472460
successful_loads += 1
473461

0 commit comments

Comments
 (0)