Skip to content

Commit cb21348

Browse files
afarntrogVamil Gandhi
authored andcommitted
feat: improve structured output tool circular reference handling (#817)
* feat: improve structured output tool circular reference handling and optional field detection - Move circular reference detection earlier in schema flattening process - Simplify optional field detection using field.is_required() instead of Union type inspection - Add comprehensive test coverage for circular reference scenarios - Fix handling of fields with default values that make them optional
1 parent a647a57 commit cb21348

File tree

2 files changed

+84
-21
lines changed

2 files changed

+84
-21
lines changed

src/strands/tools/structured_output.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
2727
"properties": {},
2828
}
2929

30-
# Add title if present
3130
if "title" in schema:
3231
flattened["title"] = schema["title"]
3332

34-
# Add description from schema if present, or use model docstring
3533
if "description" in schema and schema["description"]:
3634
flattened["description"] = schema["description"]
3735

3836
# Process properties
3937
required_props: list[str] = []
38+
if "properties" not in schema and "$ref" in schema:
39+
raise ValueError("Circular reference detected and not supported.")
4040
if "properties" in schema:
4141
required_props = []
4242
for prop_name, prop_value in schema["properties"].items():
@@ -76,9 +76,6 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
7676

7777
if len(required_props) > 0:
7878
flattened["required"] = required_props
79-
else:
80-
raise ValueError("Circular reference detected and not supported")
81-
8279
return flattened
8380

8481

@@ -325,21 +322,7 @@ def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) ->
325322
continue
326323

327324
field_type = field.annotation
328-
329-
# Handle Optional types
330-
is_optional = False
331-
if (
332-
field_type is not None
333-
and hasattr(field_type, "__origin__")
334-
and field_type.__origin__ is Union
335-
and hasattr(field_type, "__args__")
336-
):
337-
# Look for Optional[BaseModel]
338-
for arg in field_type.__args__:
339-
if arg is type(None):
340-
is_optional = True
341-
elif isinstance(arg, type) and issubclass(arg, BaseModel):
342-
field_type = arg
325+
is_optional = not field.is_required()
343326

344327
# If this is a BaseModel field, expand its properties with full details
345328
if isinstance(field_type, type) and issubclass(field_type, BaseModel):

tests/strands/tools/test_structured_output.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal, Optional
1+
from typing import List, Literal, Optional
22

33
import pytest
44
from pydantic import BaseModel, Field
@@ -157,6 +157,7 @@ def test_convert_pydantic_to_tool_spec_multiple_same_type():
157157
"user2": {
158158
"type": ["object", "null"],
159159
"description": "The second user",
160+
"title": "UserWithPlanet",
160161
"properties": {
161162
"name": {"description": "The name of the user", "title": "Name", "type": "string"},
162163
"age": {
@@ -208,6 +209,85 @@ class NodeWithCircularRef(BaseModel):
208209
convert_pydantic_to_tool_spec(NodeWithCircularRef)
209210

210211

212+
def test_convert_pydantic_with_circular_required_dependency():
213+
"""Test that the tool handles circular dependencies gracefully."""
214+
215+
class NodeWithCircularRef(BaseModel):
216+
"""A node with a circular reference to itself."""
217+
218+
name: str = Field(description="The name of the node")
219+
parent: "NodeWithCircularRef"
220+
221+
with pytest.raises(ValueError, match="Circular reference detected and not supported"):
222+
convert_pydantic_to_tool_spec(NodeWithCircularRef)
223+
224+
225+
def test_convert_pydantic_with_circular_optional_dependency():
226+
"""Test that the tool handles circular dependencies gracefully."""
227+
228+
class NodeWithCircularRef(BaseModel):
229+
"""A node with a circular reference to itself."""
230+
231+
name: str = Field(description="The name of the node")
232+
parent: Optional["NodeWithCircularRef"] = None
233+
234+
with pytest.raises(ValueError, match="Circular reference detected and not supported"):
235+
convert_pydantic_to_tool_spec(NodeWithCircularRef)
236+
237+
238+
def test_convert_pydantic_with_circular_optional_dependenc_not_using_optional_typing():
239+
"""Test that the tool handles circular dependencies gracefully."""
240+
241+
class NodeWithCircularRef(BaseModel):
242+
"""A node with a circular reference to itself."""
243+
244+
name: str = Field(description="The name of the node")
245+
parent: "NodeWithCircularRef" = None
246+
247+
with pytest.raises(ValueError, match="Circular reference detected and not supported"):
248+
convert_pydantic_to_tool_spec(NodeWithCircularRef)
249+
250+
251+
def test_conversion_works_with_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501
252+
class Family(BaseModel):
253+
ages: List[str] = Field(default_factory=list)
254+
names: List[str] = Field(default_factory=list)
255+
256+
converted_output = convert_pydantic_to_tool_spec(Family)
257+
expected_output = {
258+
"name": "Family",
259+
"description": "Family structured output tool",
260+
"inputSchema": {
261+
"json": {
262+
"type": "object",
263+
"properties": {
264+
"ages": {
265+
"items": {"type": "string"},
266+
"title": "Ages",
267+
"type": ["array", "null"],
268+
},
269+
"names": {
270+
"items": {"type": "string"},
271+
"title": "Names",
272+
"type": ["array", "null"],
273+
},
274+
},
275+
"title": "Family",
276+
}
277+
},
278+
}
279+
assert converted_output == expected_output
280+
281+
282+
def test_marks_fields_as_optional_for_model_w_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501
283+
class Family(BaseModel):
284+
ages: List[str] = Field(default_factory=list)
285+
names: List[str] = Field(default_factory=list)
286+
287+
converted_output = convert_pydantic_to_tool_spec(Family)
288+
assert "null" in converted_output["inputSchema"]["json"]["properties"]["ages"]["type"]
289+
290+
211291
def test_convert_pydantic_with_custom_description():
212292
"""Test that custom descriptions override model docstrings."""
213293

0 commit comments

Comments
 (0)