Skip to content

Commit 999df44

Browse files
committed
iterative tools
1 parent e8cf208 commit 999df44

File tree

16 files changed

+435
-276
lines changed

16 files changed

+435
-276
lines changed

src/strands/agent/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def caller(
131131

132132
# Execute the tool
133133
events = self._agent.tool_handler.process(
134-
tool=tool_use,
134+
tool_use,
135135
model=self._agent.model,
136136
system_prompt=self._agent.system_prompt,
137137
messages=self._agent.messages,

src/strands/event_loop/event_loop.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async def event_loop_cycle(
7575
- event_loop_cycle_span: Current tracing Span for this cycle
7676
7777
Yields:
78-
Model and tool invocation events. The last event is a tuple containing:
78+
Model and tool stream events. The last event is a tuple containing:
7979
8080
- StopReason: Reason the model stopped generating (e.g., "tool_use")
8181
- Message: The generated message from the model
@@ -353,8 +353,8 @@ async def _handle_tool_execution(
353353
kwargs: Additional keyword arguments, including request state.
354354
355355
Yields:
356-
Tool invocation events along with events yielded from a recursive call to the event loop. The last event is a
357-
tuple containing:
356+
Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple
357+
containing:
358358
- The stop reason,
359359
- The updated message,
360360
- The updated event loop metrics,

src/strands/handlers/tool_handler.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""This module provides handlers for managing tool invocations."""
1+
"""This module provides handlers for managing tool streams."""
22

33
import logging
44
from typing import Any, Optional
@@ -12,10 +12,10 @@
1212

1313

1414
class AgentToolHandler(ToolHandler):
15-
"""Handler for processing tool invocations in agent.
15+
"""Handler for processing tool streams.
1616
1717
This class implements the ToolHandler interface and provides functionality for looking up tools in a registry and
18-
invoking them with the appropriate parameters.
18+
streaming them with the appropriate parameters.
1919
"""
2020

2121
def __init__(self, tool_registry: ToolRegistry) -> None:
@@ -28,35 +28,35 @@ def __init__(self, tool_registry: ToolRegistry) -> None:
2828

2929
def process(
3030
self,
31-
tool: ToolUse,
31+
tool_use: ToolUse,
3232
*,
3333
model: Model,
3434
system_prompt: Optional[str],
3535
messages: Messages,
3636
tool_config: ToolConfig,
3737
kwargs: dict[str, Any],
3838
) -> ToolGenerator:
39-
"""Process a tool invocation.
39+
"""Process a tool stream.
4040
41-
Looks up the tool in the registry and invokes it with the provided parameters.
41+
Looks up the tool in the registry and streams it with the provided parameters.
4242
4343
Args:
44-
tool: The tool object to process, containing name and parameters.
44+
tool_use: The tool object to process, containing name and parameters.
4545
model: The model being used for the agent.
4646
system_prompt: The system prompt for the agent.
4747
messages: The conversation history.
4848
tool_config: Configuration for the tool.
4949
kwargs: Additional keyword arguments passed to the tool.
5050
5151
Yields:
52-
Events of the tool invocation.
52+
Events of the tool stream.
5353
5454
Returns:
5555
The final tool result or an error response if the tool fails or is not found.
5656
"""
57-
logger.debug("tool=<%s> | invoking", tool)
58-
tool_use_id = tool["toolUseId"]
59-
tool_name = tool["name"]
57+
logger.debug("tool_use=<%s> | streaming", tool_use)
58+
tool_use_id = tool_use["toolUseId"]
59+
tool_name = tool_use["name"]
6060

6161
# Get the tool info
6262
tool_info = self.tool_registry.dynamic_tools.get(tool_name)
@@ -85,9 +85,7 @@ def process(
8585
}
8686
)
8787

88-
result = tool_func.invoke(tool, **kwargs)
89-
yield {"result": result} # Placeholder until tool_func becomes a generator from which we can yield from
90-
return result
88+
return (yield from tool_func.stream(tool_use, **kwargs))
9189

9290
except Exception as e:
9391
logger.exception("tool_name=<%s> | failed to process tool", tool_name)

src/strands/tools/decorator.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
4646
from typing import (
4747
Any,
4848
Callable,
49-
Dict,
5049
Generic,
5150
Optional,
5251
ParamSpec,
@@ -62,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6261
from pydantic import BaseModel, Field, create_model
6362
from typing_extensions import override
6463

65-
from ..types.tools import AgentTool, JSONSchema, ToolResult, ToolSpec, ToolUse
64+
from ..types.tools import AgentTool, JSONSchema, ToolGenerator, ToolResult, ToolSpec, ToolUse
6665

6766
logger = logging.getLogger(__name__)
6867

@@ -119,7 +118,7 @@ def _create_input_model(self) -> Type[BaseModel]:
119118
Returns:
120119
A Pydantic BaseModel class customized for the function's parameters.
121120
"""
122-
field_definitions: Dict[str, Any] = {}
121+
field_definitions: dict[str, Any] = {}
123122

124123
for name, param in self.signature.parameters.items():
125124
# Skip special parameters
@@ -179,7 +178,7 @@ def extract_metadata(self) -> ToolSpec:
179178

180179
return tool_spec
181180

182-
def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
181+
def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None:
183182
"""Clean up Pydantic schema to match Strands' expected format.
184183
185184
Pydantic's JSON schema output includes several elements that aren't needed for Strands Agent tools and could
@@ -227,7 +226,7 @@ def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
227226
if key in prop_schema:
228227
del prop_schema[key]
229228

230-
def validate_input(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
229+
def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
231230
"""Validate input data using the Pydantic model.
232231
233232
This method ensures that the input data meets the expected schema before it's passed to the actual function. It
@@ -270,32 +269,32 @@ class DecoratedFunctionTool(AgentTool, Generic[P, R]):
270269

271270
_tool_name: str
272271
_tool_spec: ToolSpec
272+
_tool_func: Callable[P, R]
273273
_metadata: FunctionToolMetadata
274-
original_function: Callable[P, R]
275274

276275
def __init__(
277276
self,
278-
function: Callable[P, R],
279277
tool_name: str,
280278
tool_spec: ToolSpec,
279+
tool_func: Callable[P, R],
281280
metadata: FunctionToolMetadata,
282281
):
283282
"""Initialize the decorated function tool.
284283
285284
Args:
286-
function: The original function being decorated.
287285
tool_name: The name to use for the tool (usually the function name).
288286
tool_spec: The tool specification containing metadata for Agent integration.
287+
tool_func: The original function being decorated.
289288
metadata: The FunctionToolMetadata object with extracted function information.
290289
"""
291290
super().__init__()
292291

293-
self.original_function = function
292+
self._tool_name = tool_name
294293
self._tool_spec = tool_spec
294+
self._tool_func = tool_func
295295
self._metadata = metadata
296-
self._tool_name = tool_name
297296

298-
functools.update_wrapper(wrapper=self, wrapped=self.original_function)
297+
functools.update_wrapper(wrapper=self, wrapped=self._tool_func)
299298

300299
def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]":
301300
"""Descriptor protocol implementation for proper method binding.
@@ -323,12 +322,10 @@ def my_tool():
323322
tool = instance.my_tool
324323
```
325324
"""
326-
if instance is not None and not inspect.ismethod(self.original_function):
325+
if instance is not None and not inspect.ismethod(self._tool_func):
327326
# Create a bound method
328-
new_callback = self.original_function.__get__(instance, instance.__class__)
329-
return DecoratedFunctionTool(
330-
function=new_callback, tool_name=self.tool_name, tool_spec=self.tool_spec, metadata=self._metadata
331-
)
327+
tool_func = self._tool_func.__get__(instance, instance.__class__)
328+
return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata)
332329

333330
return self
334331

@@ -360,7 +357,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
360357

361358
return cast(R, self.invoke(tool_use, **kwargs))
362359

363-
return self.original_function(*args, **kwargs)
360+
return self._tool_func(*args, **kwargs)
364361

365362
@property
366363
def tool_name(self) -> str:
@@ -389,10 +386,20 @@ def tool_type(self) -> str:
389386
"""
390387
return "function"
391388

392-
def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
393-
"""Invoke the tool with a tool use specification.
389+
@property
390+
def tool_func(self) -> Callable[P, R]:
391+
"""Get the undecorated tool function.
392+
393+
Returns:
394+
Undecorated tool function.
395+
"""
396+
return self._tool_func
394397

395-
This method handles tool use invocations from a Strands Agent. It validates the input,
398+
@override
399+
def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolGenerator:
400+
"""Stream the tool with a tool use specification.
401+
402+
This method handles tool use streams from a Strands Agent. It validates the input,
396403
calls the function, and formats the result according to the expected tool result format.
397404
398405
Key operations:
@@ -404,15 +411,17 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
404411
5. Handle and format any errors that occur
405412
406413
Args:
407-
tool: The tool use specification from the Agent.
414+
tool_use: The tool use specification from the Agent.
408415
*args: Additional positional arguments (not typically used).
409416
**kwargs: Additional keyword arguments, may include 'agent' reference.
410417
418+
Yields:
419+
Events of the tool stream.
420+
411421
Returns:
412422
A standardized tool result dictionary with status and content.
413423
"""
414424
# This is a tool use call - process accordingly
415-
tool_use = tool
416425
tool_use_id = tool_use.get("toolUseId", "unknown")
417426
tool_input = tool_use.get("input", {})
418427

@@ -424,8 +433,9 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
424433
if "agent" in kwargs and "agent" in self._metadata.signature.parameters:
425434
validated_input["agent"] = kwargs.get("agent")
426435

427-
# We get "too few arguments here" but because that's because fof the way we're calling it
428-
result = self.original_function(**validated_input) # type: ignore
436+
result = self._tool_func(**validated_input) # type: ignore # "Too few arguments" expected
437+
if inspect.isgenerator(result):
438+
result = yield from result
429439

430440
# FORMAT THE RESULT for Strands Agent
431441
if isinstance(result, dict) and "status" in result and "content" in result:
@@ -476,7 +486,7 @@ def get_display_properties(self) -> dict[str, str]:
476486
Function properties (e.g., function name).
477487
"""
478488
properties = super().get_display_properties()
479-
properties["Function"] = self.original_function.__name__
489+
properties["Function"] = self._tool_func.__name__
480490
return properties
481491

482492

@@ -573,7 +583,7 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
573583
if not isinstance(tool_name, str):
574584
raise ValueError(f"Tool name must be a string, got {type(tool_name)}")
575585

576-
return DecoratedFunctionTool(function=f, tool_name=tool_name, tool_spec=tool_spec, metadata=tool_meta)
586+
return DecoratedFunctionTool(tool_name, tool_spec, f, tool_meta)
577587

578588
# Handle both @tool and @tool() syntax
579589
if func is None:

src/strands/tools/executor.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,23 @@ def run_tools(
4141
thread_pool: Optional thread pool for parallel processing.
4242
4343
Yields:
44-
Events of the tool invocations. Tool results are appended to `tool_results`.
44+
Events of the tool stream. Tool results are appended to `tool_results`.
4545
"""
4646

47-
def handle(tool: ToolUse) -> ToolGenerator:
47+
def handle(tool_use: ToolUse) -> ToolGenerator:
4848
tracer = get_tracer()
49-
tool_call_span = tracer.start_tool_call_span(tool, parent_span)
49+
tool_call_span = tracer.start_tool_call_span(tool_use, parent_span)
5050

51-
tool_name = tool["name"]
51+
tool_name = tool_use["name"]
5252
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
5353
tool_start_time = time.time()
5454

55-
result = yield from handler(tool)
55+
result = yield from handler(tool_use)
5656

5757
tool_success = result.get("status") == "success"
5858
tool_duration = time.time() - tool_start_time
5959
message = Message(role="user", content=[{"toolResult": result}])
60-
event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message)
60+
event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message)
6161
cycle_trace.add_child(tool_trace)
6262

6363
if tool_call_span:
@@ -66,12 +66,12 @@ def handle(tool: ToolUse) -> ToolGenerator:
6666
return result
6767

6868
def work(
69-
tool: ToolUse,
69+
tool_use: ToolUse,
7070
worker_id: int,
7171
worker_queue: queue.Queue,
7272
worker_event: threading.Event,
7373
) -> ToolResult:
74-
events = handle(tool)
74+
events = handle(tool_use)
7575

7676
try:
7777
while True:

src/strands/tools/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
108108
if not callable(tool_func):
109109
raise TypeError(f"Tool {tool_name} function is not callable")
110110

111-
return PythonAgentTool(tool_name, tool_spec, callback=tool_func)
111+
return PythonAgentTool(tool_name, tool_spec, tool_func)
112112

113113
except Exception:
114114
logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool", tool_name, sys.path)

src/strands/tools/mcp/mcp_agent_tool.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from typing import TYPE_CHECKING, Any
1010

1111
from mcp.types import Tool as MCPTool
12+
from typing_extensions import override
1213

13-
from ...types.tools import AgentTool, ToolResult, ToolSpec, ToolUse
14+
from ...types.tools import AgentTool, ToolGenerator, ToolResult, ToolSpec, ToolUse
1415

1516
if TYPE_CHECKING:
1617
from .mcp_client import MCPClient
@@ -73,13 +74,31 @@ def tool_type(self) -> str:
7374
"""
7475
return "python"
7576

76-
def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
77+
@override
78+
def invoke(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolResult:
7779
"""Invoke the MCP tool.
7880
7981
This method delegates the tool invocation to the MCP server connection,
8082
passing the tool use ID, tool name, and input arguments.
83+
84+
Returns:
85+
A standardized tool result dictionary with status and content.
8186
"""
82-
logger.debug("invoking MCP tool '%s' with tool_use_id=%s", self.tool_name, tool["toolUseId"])
87+
logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking", self.tool_name, tool_use["toolUseId"])
88+
8389
return self.mcp_client.call_tool_sync(
84-
tool_use_id=tool["toolUseId"], name=self.tool_name, arguments=tool["input"]
90+
tool_use_id=tool_use["toolUseId"], name=self.tool_name, arguments=tool_use["input"]
8591
)
92+
93+
@override
94+
def stream(self, tool_use: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolGenerator:
95+
"""Stream the MCP tool.
96+
97+
Yields:
98+
No events.
99+
100+
Returns:
101+
A standardized tool result dictionary with status and content.
102+
"""
103+
return self.invoke(tool_use, *args, **kwargs)
104+
yield # type: ignore

0 commit comments

Comments
 (0)