@@ -66,6 +66,98 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6666logger = logging .getLogger (__name__ )
6767
6868
69+ def _resolve_json_schema_references (schema : dict [str , Any ]) -> dict [str , Any ]:
70+ """Resolve all $ref references in a JSON schema by inlining definitions.
71+
72+ Some model providers (e.g., Bedrock via LiteLLM) don't support JSON Schema
73+ $ref references. This function flattens the schema by replacing all $ref
74+ occurrences with their actual definitions from the $defs section.
75+
76+ This is particularly important for Pydantic-generated schemas that use $defs
77+ for enum types, as these would otherwise cause validation errors with certain
78+ model providers.
79+
80+ Args:
81+ schema: A JSON schema dict that may contain $ref references and a $defs section
82+
83+ Returns:
84+ A new schema dict with all $ref references replaced by their definitions.
85+ The $defs section is removed from the result.
86+
87+ Example:
88+ Input schema with $ref:
89+ {
90+ "$defs": {"Color": {"type": "string", "enum": ["red", "blue"]}},
91+ "properties": {"color": {"$ref": "#/$defs/Color"}}
92+ }
93+
94+ Output schema with resolved reference:
95+ {
96+ "properties": {"color": {"type": "string", "enum": ["red", "blue"]}}
97+ }
98+ """
99+ # Get definitions if they exist
100+ defs = schema .get ("$defs" , {})
101+ if not defs :
102+ return schema
103+
104+ def resolve_node (node : Any ) -> Any :
105+ """Recursively process a schema node, replacing any $ref with actual definitions.
106+
107+ Args:
108+ node: Any value from the schema (dict, list, or primitive)
109+
110+ Returns:
111+ The node with all $ref references resolved
112+ """
113+ if not isinstance (node , dict ):
114+ return node
115+
116+ # If this node is a $ref, replace it with the referenced definition
117+ if "$ref" in node :
118+ # Extract the definition name from the reference (e.g., "#/$defs/Color" -> "Color")
119+ ref_name = node ["$ref" ].split ("/" )[- 1 ]
120+ if ref_name in defs :
121+ # Copy the referenced definition to avoid modifying the original
122+ resolved = defs [ref_name ].copy ()
123+ # Preserve any additional properties from the $ref node (e.g., "default", "description")
124+ for key , value in node .items ():
125+ if key != "$ref" :
126+ resolved [key ] = value
127+ # Recursively resolve in case the definition itself contains references
128+ return resolve_node (resolved )
129+ # If reference not found, return as-is (shouldn't happen with valid schemas)
130+ return node
131+
132+ # For dict nodes, recursively process all values
133+ result : dict [str , Any ] = {}
134+ for key , value in node .items ():
135+ if isinstance (value , list ):
136+ # For arrays, resolve each item
137+ result [key ] = [resolve_node (item ) for item in value ]
138+ elif isinstance (value , dict ):
139+ # For objects, check if this is a properties dict that needs special handling
140+ if key == "properties" and isinstance (value , dict ):
141+ # Ensure all property definitions are fully resolved
142+ result [key ] = {
143+ prop_name : resolve_node (prop_schema )
144+ for prop_name , prop_schema in value .items ()
145+ }
146+ else :
147+ result [key ] = resolve_node (value )
148+ else :
149+ # Primitive values are copied as-is
150+ result [key ] = value
151+ return result
152+
153+ # Process the entire schema, excluding the $defs section from the result
154+ result = {
155+ key : resolve_node (value ) for key , value in schema .items () if key != "$defs"
156+ }
157+
158+ return result
159+
160+
69161# Type for wrapped function
70162T = TypeVar ("T" , bound = Callable [..., Any ])
71163
@@ -101,7 +193,8 @@ def __init__(self, func: Callable[..., Any]) -> None:
101193
102194 # Get parameter descriptions from parsed docstring
103195 self .param_descriptions = {
104- param .arg_name : param .description or f"Parameter { param .arg_name } " for param in self .doc .params
196+ param .arg_name : param .description or f"Parameter { param .arg_name } "
197+ for param in self .doc .params
105198 }
106199
107200 # Create a Pydantic model for validation
@@ -131,7 +224,10 @@ def _create_input_model(self) -> Type[BaseModel]:
131224 description = self .param_descriptions .get (name , f"Parameter { name } " )
132225
133226 # Create Field with description and default
134- field_definitions [name ] = (param_type , Field (default = default , description = description ))
227+ field_definitions [name ] = (
228+ param_type ,
229+ Field (default = default , description = description ),
230+ )
135231
136232 # Create model name based on function name
137233 model_name = f"{ self .func .__name__ .capitalize ()} Tool"
@@ -173,8 +269,17 @@ def extract_metadata(self) -> ToolSpec:
173269 # Clean up Pydantic-specific schema elements
174270 self ._clean_pydantic_schema (input_schema )
175271
272+ # Flatten schema by resolving $ref references to their definitions
273+ # This is required for compatibility with model providers that don't support
274+ # JSON Schema $ref (e.g., Bedrock/Anthropic via LiteLLM)
275+ input_schema = _resolve_json_schema_references (input_schema )
276+
176277 # Create tool specification
177- tool_spec : ToolSpec = {"name" : func_name , "description" : description , "inputSchema" : {"json" : input_schema }}
278+ tool_spec : ToolSpec = {
279+ "name" : func_name ,
280+ "description" : description ,
281+ "inputSchema" : {"json" : input_schema },
282+ }
178283
179284 return tool_spec
180285
@@ -206,7 +311,9 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None:
206311 if "anyOf" in prop_schema :
207312 any_of = prop_schema ["anyOf" ]
208313 # Handle Optional[Type] case (represented as anyOf[Type, null])
209- if len (any_of ) == 2 and any (item .get ("type" ) == "null" for item in any_of ):
314+ if len (any_of ) == 2 and any (
315+ item .get ("type" ) == "null" for item in any_of
316+ ):
210317 # Find the non-null type
211318 for item in any_of :
212319 if item .get ("type" ) != "null" :
@@ -250,7 +357,9 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
250357 except Exception as e :
251358 # Re-raise with more detailed error message
252359 error_msg = str (e )
253- raise ValueError (f"Validation failed for input parameters: { error_msg } " ) from e
360+ raise ValueError (
361+ f"Validation failed for input parameters: { error_msg } "
362+ ) from e
254363
255364
256365P = ParamSpec ("P" ) # Captures all parameters
@@ -296,7 +405,9 @@ def __init__(
296405
297406 functools .update_wrapper (wrapper = self , wrapped = self ._tool_func )
298407
299- def __get__ (self , instance : Any , obj_type : Optional [Type ] = None ) -> "DecoratedFunctionTool[P, R]" :
408+ def __get__ (
409+ self , instance : Any , obj_type : Optional [Type ] = None
410+ ) -> "DecoratedFunctionTool[P, R]" :
300411 """Descriptor protocol implementation for proper method binding.
301412
302413 This method enables the decorated function to work correctly when used as a class method.
@@ -325,7 +436,9 @@ def my_tool():
325436 if instance is not None and not inspect .ismethod (self ._tool_func ):
326437 # Create a bound method
327438 tool_func = self ._tool_func .__get__ (instance , instance .__class__ )
328- return DecoratedFunctionTool (self ._tool_name , self ._tool_spec , tool_func , self ._metadata )
439+ return DecoratedFunctionTool (
440+ self ._tool_name , self ._tool_spec , tool_func , self ._metadata
441+ )
329442
330443 return self
331444
@@ -372,7 +485,9 @@ def tool_type(self) -> str:
372485 return "function"
373486
374487 @override
375- async def stream (self , tool_use : ToolUse , invocation_state : dict [str , Any ], ** kwargs : Any ) -> ToolGenerator :
488+ async def stream (
489+ self , tool_use : ToolUse , invocation_state : dict [str , Any ], ** kwargs : Any
490+ ) -> ToolGenerator :
376491 """Stream the tool with a tool use specification.
377492
378493 This method handles tool use streams from a Strands Agent. It validates the input,
@@ -403,7 +518,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
403518 validated_input = self ._metadata .validate_input (tool_input )
404519
405520 # Pass along the agent if provided and expected by the function
406- if "agent" in invocation_state and "agent" in self ._metadata .signature .parameters :
521+ if (
522+ "agent" in invocation_state
523+ and "agent" in self ._metadata .signature .parameters
524+ ):
407525 validated_input ["agent" ] = invocation_state .get ("agent" )
408526
409527 # "Too few arguments" expected, hence the type ignore
@@ -468,21 +586,27 @@ def get_display_properties(self) -> dict[str, str]:
468586# Handle @decorator
469587@overload
470588def tool (__func : Callable [P , R ]) -> DecoratedFunctionTool [P , R ]: ...
589+
590+
471591# Handle @decorator()
472592@overload
473593def tool (
474594 description : Optional [str ] = None ,
475595 inputSchema : Optional [JSONSchema ] = None ,
476596 name : Optional [str ] = None ,
477597) -> Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]: ...
598+
599+
478600# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
479601# call site, but the actual implementation handles that and it's not representable via the type-system
480602def tool ( # type: ignore
481603 func : Optional [Callable [P , R ]] = None ,
482604 description : Optional [str ] = None ,
483605 inputSchema : Optional [JSONSchema ] = None ,
484606 name : Optional [str ] = None ,
485- ) -> Union [DecoratedFunctionTool [P , R ], Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]]:
607+ ) -> Union [
608+ DecoratedFunctionTool [P , R ], Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]
609+ ]:
486610 """Decorator that transforms a Python function into a Strands tool.
487611
488612 This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool.
0 commit comments