@@ -46,7 +46,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
4646from typing import (
4747 Any ,
4848 Callable ,
49- Dict ,
5049 Generic ,
5150 Optional ,
5251 ParamSpec ,
@@ -62,7 +61,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6261from pydantic import BaseModel , Field , create_model
6362from typing_extensions import override
6463
65- from ..types .tools import AgentTool , JSONSchema , ToolResult , ToolSpec , ToolUse
64+ from ..types .tools import AgentTool , JSONSchema , ToolGenerator , ToolResult , ToolSpec , ToolUse
6665
6766logger = logging .getLogger (__name__ )
6867
@@ -119,7 +118,7 @@ def _create_input_model(self) -> Type[BaseModel]:
119118 Returns:
120119 A Pydantic BaseModel class customized for the function's parameters.
121120 """
122- field_definitions : Dict [str , Any ] = {}
121+ field_definitions : dict [str , Any ] = {}
123122
124123 for name , param in self .signature .parameters .items ():
125124 # Skip special parameters
@@ -179,7 +178,7 @@ def extract_metadata(self) -> ToolSpec:
179178
180179 return tool_spec
181180
182- def _clean_pydantic_schema (self , schema : Dict [str , Any ]) -> None :
181+ def _clean_pydantic_schema (self , schema : dict [str , Any ]) -> None :
183182 """Clean up Pydantic schema to match Strands' expected format.
184183
185184 Pydantic's JSON schema output includes several elements that aren't needed for Strands Agent tools and could
@@ -227,7 +226,7 @@ def _clean_pydantic_schema(self, schema: Dict[str, Any]) -> None:
227226 if key in prop_schema :
228227 del prop_schema [key ]
229228
230- def validate_input (self , input_data : Dict [str , Any ]) -> Dict [str , Any ]:
229+ def validate_input (self , input_data : dict [str , Any ]) -> dict [str , Any ]:
231230 """Validate input data using the Pydantic model.
232231
233232 This method ensures that the input data meets the expected schema before it's passed to the actual function. It
@@ -270,32 +269,32 @@ class DecoratedFunctionTool(AgentTool, Generic[P, R]):
270269
271270 _tool_name : str
272271 _tool_spec : ToolSpec
272+ _tool_func : Callable [P , R ]
273273 _metadata : FunctionToolMetadata
274- original_function : Callable [P , R ]
275274
276275 def __init__ (
277276 self ,
278- function : Callable [P , R ],
279277 tool_name : str ,
280278 tool_spec : ToolSpec ,
279+ tool_func : Callable [P , R ],
281280 metadata : FunctionToolMetadata ,
282281 ):
283282 """Initialize the decorated function tool.
284283
285284 Args:
286- function: The original function being decorated.
287285 tool_name: The name to use for the tool (usually the function name).
288286 tool_spec: The tool specification containing metadata for Agent integration.
287+ tool_func: The original function being decorated.
289288 metadata: The FunctionToolMetadata object with extracted function information.
290289 """
291290 super ().__init__ ()
292291
293- self .original_function = function
292+ self ._tool_name = tool_name
294293 self ._tool_spec = tool_spec
294+ self ._tool_func = tool_func
295295 self ._metadata = metadata
296- self ._tool_name = tool_name
297296
298- functools .update_wrapper (wrapper = self , wrapped = self .original_function )
297+ functools .update_wrapper (wrapper = self , wrapped = self ._tool_func )
299298
300299 def __get__ (self , instance : Any , obj_type : Optional [Type ] = None ) -> "DecoratedFunctionTool[P, R]" :
301300 """Descriptor protocol implementation for proper method binding.
@@ -323,12 +322,10 @@ def my_tool():
323322 tool = instance.my_tool
324323 ```
325324 """
326- if instance is not None and not inspect .ismethod (self .original_function ):
325+ if instance is not None and not inspect .ismethod (self ._tool_func ):
327326 # Create a bound method
328- new_callback = self .original_function .__get__ (instance , instance .__class__ )
329- return DecoratedFunctionTool (
330- function = new_callback , tool_name = self .tool_name , tool_spec = self .tool_spec , metadata = self ._metadata
331- )
327+ tool_func = self ._tool_func .__get__ (instance , instance .__class__ )
328+ return DecoratedFunctionTool (self ._tool_name , self ._tool_spec , tool_func , self ._metadata )
332329
333330 return self
334331
@@ -360,7 +357,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
360357
361358 return cast (R , self .invoke (tool_use , ** kwargs ))
362359
363- return self .original_function (* args , ** kwargs )
360+ return self ._tool_func (* args , ** kwargs )
364361
365362 @property
366363 def tool_name (self ) -> str :
@@ -389,10 +386,20 @@ def tool_type(self) -> str:
389386 """
390387 return "function"
391388
392- def invoke (self , tool : ToolUse , * args : Any , ** kwargs : dict [str , Any ]) -> ToolResult :
393- """Invoke the tool with a tool use specification.
389+ @property
390+ def tool_func (self ) -> Callable [P , R ]:
391+ """Get the undecorated tool function.
392+
393+ Returns:
394+ Undecorated tool function.
395+ """
396+ return self ._tool_func
394397
395- This method handles tool use invocations from a Strands Agent. It validates the input,
398+ @override
399+ def stream (self , tool_use : ToolUse , * args : Any , ** kwargs : dict [str , Any ]) -> ToolGenerator :
400+ """Stream the tool with a tool use specification.
401+
402+ This method handles tool use streams from a Strands Agent. It validates the input,
396403 calls the function, and formats the result according to the expected tool result format.
397404
398405 Key operations:
@@ -404,15 +411,17 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
404411 5. Handle and format any errors that occur
405412
406413 Args:
407- tool : The tool use specification from the Agent.
414+ tool_use : The tool use specification from the Agent.
408415 *args: Additional positional arguments (not typically used).
409416 **kwargs: Additional keyword arguments, may include 'agent' reference.
410417
418+ Yields:
419+ Events of the tool stream.
420+
411421 Returns:
412422 A standardized tool result dictionary with status and content.
413423 """
414424 # This is a tool use call - process accordingly
415- tool_use = tool
416425 tool_use_id = tool_use .get ("toolUseId" , "unknown" )
417426 tool_input = tool_use .get ("input" , {})
418427
@@ -424,8 +433,9 @@ def invoke(self, tool: ToolUse, *args: Any, **kwargs: dict[str, Any]) -> ToolRes
424433 if "agent" in kwargs and "agent" in self ._metadata .signature .parameters :
425434 validated_input ["agent" ] = kwargs .get ("agent" )
426435
427- # We get "too few arguments here" but because that's because fof the way we're calling it
428- result = self .original_function (** validated_input ) # type: ignore
436+ result = self ._tool_func (** validated_input ) # type: ignore # "Too few arguments" expected
437+ if inspect .isgenerator (result ):
438+ result = yield from result
429439
430440 # FORMAT THE RESULT for Strands Agent
431441 if isinstance (result , dict ) and "status" in result and "content" in result :
@@ -476,7 +486,7 @@ def get_display_properties(self) -> dict[str, str]:
476486 Function properties (e.g., function name).
477487 """
478488 properties = super ().get_display_properties ()
479- properties ["Function" ] = self .original_function .__name__
489+ properties ["Function" ] = self ._tool_func .__name__
480490 return properties
481491
482492
@@ -573,7 +583,7 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]":
573583 if not isinstance (tool_name , str ):
574584 raise ValueError (f"Tool name must be a string, got { type (tool_name )} " )
575585
576- return DecoratedFunctionTool (function = f , tool_name = tool_name , tool_spec = tool_spec , metadata = tool_meta )
586+ return DecoratedFunctionTool (tool_name , tool_spec , f , tool_meta )
577587
578588 # Handle both @tool and @tool() syntax
579589 if func is None :
0 commit comments