@@ -46,7 +46,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
46
46
from typing import (
47
47
Any ,
48
48
Callable ,
49
- Dict ,
50
49
Generic ,
51
50
Optional ,
52
51
ParamSpec ,
@@ -62,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
62
61
from pydantic import BaseModel , Field , create_model
63
62
from typing_extensions import override
64
63
65
- from ..types .tools import AgentTool , JSONSchema , ToolResult , ToolSpec , ToolUse
64
+ from ..types .tools import AgentTool , JSONSchema , ToolGenerator , ToolResult , ToolSpec , ToolUse
66
65
67
66
logger = logging .getLogger (__name__ )
68
67
@@ -119,7 +118,7 @@ def _create_input_model(self) -> Type[BaseModel]:
119
118
Returns:
120
119
A Pydantic BaseModel class customized for the function's parameters.
121
120
"""
122
- field_definitions : Dict [str , Any ] = {}
121
+ field_definitions : dict [str , Any ] = {}
123
122
124
123
for name , param in self .signature .parameters .items ():
125
124
# Skip special parameters
@@ -179,7 +178,7 @@ def extract_metadata(self) -> ToolSpec:
179
178
180
179
return tool_spec
181
180
182
- def _clean_pydantic_schema (self , schema : Dict [str , Any ]) -> None :
181
+ def _clean_pydantic_schema (self , schema : dict [str , Any ]) -> None :
183
182
"""Clean up Pydantic schema to match Strands' expected format.
184
183
185
184
Pydantic's JSON schema output includes several elements that aren't needed for Strands Agent tools and could
@@ -227,7 +226,7 @@ def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
227
226
if key in prop_schema :
228
227
del prop_schema [key ]
229
228
230
- def validate_input (self , input_data : Dict [str , Any ]) -> Dict [str , Any ]:
229
+ def validate_input (self , input_data : dict [str , Any ]) -> dict [str , Any ]:
231
230
"""Validate input data using the Pydantic model.
232
231
233
232
This method ensures that the input data meets the expected schema before it's passed to the actual function. It
@@ -270,32 +269,32 @@ class DecoratedFunctionTool(AgentTool, Generic[P, R]):
270
269
271
270
_tool_name : str
272
271
_tool_spec : ToolSpec
272
+ _tool_func : Callable [P , R ]
273
273
_metadata : FunctionToolMetadata
274
- original_function : Callable [P , R ]
275
274
276
275
def __init__ (
277
276
self ,
278
- function : Callable [P , R ],
279
277
tool_name : str ,
280
278
tool_spec : ToolSpec ,
279
+ tool_func : Callable [P , R ],
281
280
metadata : FunctionToolMetadata ,
282
281
):
283
282
"""Initialize the decorated function tool.
284
283
285
284
Args:
286
- function: The original function being decorated.
287
285
tool_name: The name to use for the tool (usually the function name).
288
286
tool_spec: The tool specification containing metadata for Agent integration.
287
+ tool_func: The original function being decorated.
289
288
metadata: The FunctionToolMetadata object with extracted function information.
290
289
"""
291
290
super ().__init__ ()
292
291
293
- self .original_function = function
292
+ self ._tool_name = tool_name
294
293
self ._tool_spec = tool_spec
294
+ self ._tool_func = tool_func
295
295
self ._metadata = metadata
296
- self ._tool_name = tool_name
297
296
298
- functools .update_wrapper (wrapper = self , wrapped = self .original_function )
297
+ functools .update_wrapper (wrapper = self , wrapped = self ._tool_func )
299
298
300
299
def __get__ (self , instance : Any , obj_type : Optional [Type ] = None ) -> "DecoratedFunctionTool[P, R]" :
301
300
"""Descriptor protocol implementation for proper method binding.
@@ -323,12 +322,10 @@ def my_tool():
323
322
tool = instance.my_tool
324
323
```
325
324
"""
326
- if instance is not None and not inspect .ismethod (self .original_function ):
325
+ if instance is not None and not inspect .ismethod (self ._tool_func ):
327
326
# Create a bound method
328
- new_callback = self .original_function .__get__ (instance , instance .__class__ )
329
- return DecoratedFunctionTool (
330
- function = new_callback , tool_name = self .tool_name , tool_spec = self .tool_spec , metadata = self ._metadata
331
- )
327
+ tool_func = self ._tool_func .__get__ (instance , instance .__class__ )
328
+ return DecoratedFunctionTool (self ._tool_name , self ._tool_spec , tool_func , self ._metadata )
332
329
333
330
return self
334
331
@@ -360,7 +357,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
360
357
361
358
return cast (R , self .invoke (tool_use , ** kwargs ))
362
359
363
- return self .original_function (* args , ** kwargs )
360
+ return self ._tool_func (* args , ** kwargs )
364
361
365
362
@property
366
363
def tool_name (self ) -> str :
@@ -389,10 +386,20 @@ def tool_type(self) -> str:
389
386
"""
390
387
return "function"
391
388
392
- def invoke (self , tool : ToolUse , * args : Any , ** kwargs : dict [str , Any ]) -> ToolResult :
393
- """Invoke the tool with a tool use specification.
389
+ @property
390
+ def tool_func (self ) -> Callable [P , R ]:
391
+ """Get the undecorated tool function.
392
+
393
+ Returns:
394
+ Undecorated tool function.
395
+ """
396
+ return self ._tool_func
394
397
395
- This method handles tool use invocations from a Strands Agent. It validates the input,
398
+ @override
399
+ def stream (self , tool_use : ToolUse , * args : Any , ** kwargs : dict [str , Any ]) -> ToolGenerator :
400
+ """Stream the tool with a tool use specification.
401
+
402
+ This method handles tool use streams from a Strands Agent. It validates the input,
396
403
calls the function, and formats the result according to the expected tool result format.
397
404
398
405
Key operations:
@@ -404,15 +411,17 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
404
411
5. Handle and format any errors that occur
405
412
406
413
Args:
407
- tool : The tool use specification from the Agent.
414
+ tool_use : The tool use specification from the Agent.
408
415
*args: Additional positional arguments (not typically used).
409
416
**kwargs: Additional keyword arguments, may include 'agent' reference.
410
417
418
+ Yields:
419
+ Events of the tool stream.
420
+
411
421
Returns:
412
422
A standardized tool result dictionary with status and content.
413
423
"""
414
424
# This is a tool use call - process accordingly
415
- tool_use = tool
416
425
tool_use_id = tool_use .get ("toolUseId" , "unknown" )
417
426
tool_input = tool_use .get ("input" , {})
418
427
@@ -424,8 +433,9 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
424
433
if "agent" in kwargs and "agent" in self ._metadata .signature .parameters :
425
434
validated_input ["agent" ] = kwargs .get ("agent" )
426
435
427
- # We get "too few arguments here" but because that's because fof the way we're calling it
428
- result = self .original_function (** validated_input ) # type: ignore
436
+ result = self ._tool_func (** validated_input ) # type: ignore # "Too few arguments" expected
437
+ if inspect .isgenerator (result ):
438
+ result = yield from result
429
439
430
440
# FORMAT THE RESULT for Strands Agent
431
441
if isinstance (result , dict ) and "status" in result and "content" in result :
@@ -476,7 +486,7 @@ def get_display_properties(self) -> dict[str, str]:
476
486
Function properties (e.g., function name).
477
487
"""
478
488
properties = super ().get_display_properties ()
479
- properties ["Function" ] = self .original_function .__name__
489
+ properties ["Function" ] = self ._tool_func .__name__
480
490
return properties
481
491
482
492
@@ -573,7 +583,7 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
573
583
if not isinstance (tool_name , str ):
574
584
raise ValueError (f"Tool name must be a string, got { type (tool_name )} " )
575
585
576
- return DecoratedFunctionTool (function = f , tool_name = tool_name , tool_spec = tool_spec , metadata = tool_meta )
586
+ return DecoratedFunctionTool (tool_name , tool_spec , f , tool_meta )
577
587
578
588
# Handle both @tool and @tool() syntax
579
589
if func is None :
0 commit comments