Skip to content

Commit ed85738

Browse files
authored
Add return docstring to the function description (#1207)
1 parent 79faa27 commit ed85738

File tree

2 files changed

+177
-10
lines changed

2 files changed

+177
-10
lines changed

pydantic_ai_slim/pydantic_ai/_griffe.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,16 @@ def doc_descriptions(
2222
) -> tuple[str, dict[str, str]]:
2323
"""Extract the function description and parameter descriptions from a function's docstring.
2424
25+
The function parses the docstring using the specified format (or infers it if 'auto')
26+
and extracts both the main description and parameter descriptions. If a returns section
27+
is present in the docstring, the main description will be formatted as XML.
28+
2529
Returns:
26-
A tuple of (main function description, parameter descriptions).
30+
A tuple containing:
31+
- str: Main description string, which may be either:
32+
* Plain text if no returns section is present
33+
* XML-formatted if returns section exists, including <summary> and <returns> tags
34+
- dict[str, str]: Dictionary mapping parameter names to their descriptions
2735
"""
2836
doc = func.__doc__
2937
if doc is None:
@@ -33,7 +41,14 @@ def doc_descriptions(
3341
parent = cast(GriffeObject, sig)
3442

3543
docstring_style = _infer_docstring_style(doc) if docstring_format == 'auto' else docstring_format
36-
docstring = Docstring(doc, lineno=1, parser=docstring_style, parent=parent)
44+
docstring = Docstring(
45+
doc,
46+
lineno=1,
47+
parser=docstring_style,
48+
parent=parent,
49+
# https://mkdocstrings.github.io/griffe/reference/docstrings/#google-options
50+
parser_options={'returns_named_value': False, 'returns_multiple_items': False},
51+
)
3752
with _disable_griffe_logging():
3853
sections = docstring.parse()
3954

@@ -45,6 +60,18 @@ def doc_descriptions(
4560
if main := next((p for p in sections if p.kind == DocstringSectionKind.text), None):
4661
main_desc = main.value
4762

63+
if return_ := next((p for p in sections if p.kind == DocstringSectionKind.returns), None):
64+
return_statement = return_.value[0]
65+
return_desc = return_statement.description
66+
return_type = return_statement.annotation
67+
type_tag = f'<type>{return_type}</type>\n' if return_type else ''
68+
return_xml = f'<returns>\n{type_tag}<description>{return_desc}</description>\n</returns>'
69+
70+
if main_desc:
71+
main_desc = f'<summary>{main_desc}</summary>\n{return_xml}'
72+
else:
73+
main_desc = return_xml
74+
4875
return main_desc, params
4976

5077

tests/test_tools.py

+148-8
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def sphinx_style_docstring(foo: int, /) -> str: # pragma: no cover
122122
"""Sphinx style docstring.
123123
124124
:param foo: The foo thing.
125-
:return: The result.
126125
"""
127126
return str(foo)
128127

@@ -187,6 +186,152 @@ def test_docstring_numpy(docstring_format: Literal['numpy', 'auto']):
187186
)
188187

189188

189+
def test_google_style_with_returns():
190+
agent = Agent(FunctionModel(get_json_schema))
191+
192+
def my_tool(x: int) -> str: # pragma: no cover
193+
"""A function that does something.
194+
195+
Args:
196+
x: The input value.
197+
198+
Returns:
199+
str: The result as a string.
200+
"""
201+
return str(x)
202+
203+
agent.tool_plain(my_tool)
204+
result = agent.run_sync('Hello')
205+
json_schema = json.loads(result.data)
206+
assert json_schema == snapshot(
207+
{
208+
'name': 'my_tool',
209+
'description': """\
210+
<summary>A function that does something.</summary>
211+
<returns>
212+
<type>str</type>
213+
<description>The result as a string.</description>
214+
</returns>\
215+
""",
216+
'parameters_json_schema': {
217+
'additionalProperties': False,
218+
'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
219+
'required': ['x'],
220+
'type': 'object',
221+
},
222+
'outer_typed_dict_key': None,
223+
}
224+
)
225+
226+
227+
def test_sphinx_style_with_returns():
228+
agent = Agent(FunctionModel(get_json_schema))
229+
230+
def my_tool(x: int) -> str: # pragma: no cover
231+
"""A sphinx function with returns.
232+
233+
:param x: The input value.
234+
:rtype: str
235+
:return: The result as a string with type.
236+
"""
237+
return str(x)
238+
239+
agent.tool_plain(docstring_format='sphinx')(my_tool)
240+
result = agent.run_sync('Hello')
241+
json_schema = json.loads(result.data)
242+
assert json_schema == snapshot(
243+
{
244+
'name': 'my_tool',
245+
'description': """\
246+
<summary>A sphinx function with returns.</summary>
247+
<returns>
248+
<type>str</type>
249+
<description>The result as a string with type.</description>
250+
</returns>\
251+
""",
252+
'parameters_json_schema': {
253+
'additionalProperties': False,
254+
'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
255+
'required': ['x'],
256+
'type': 'object',
257+
},
258+
'outer_typed_dict_key': None,
259+
}
260+
)
261+
262+
263+
def test_numpy_style_with_returns():
264+
agent = Agent(FunctionModel(get_json_schema))
265+
266+
def my_tool(x: int) -> str: # pragma: no cover
267+
"""A numpy function with returns.
268+
269+
Parameters
270+
----------
271+
x : int
272+
The input value.
273+
274+
Returns
275+
-------
276+
str
277+
The result as a string with type.
278+
"""
279+
return str(x)
280+
281+
agent.tool_plain(docstring_format='numpy')(my_tool)
282+
result = agent.run_sync('Hello')
283+
json_schema = json.loads(result.data)
284+
assert json_schema == snapshot(
285+
{
286+
'name': 'my_tool',
287+
'description': """\
288+
<summary>A numpy function with returns.</summary>
289+
<returns>
290+
<type>str</type>
291+
<description>The result as a string with type.</description>
292+
</returns>\
293+
""",
294+
'parameters_json_schema': {
295+
'additionalProperties': False,
296+
'properties': {'x': {'description': 'The input value.', 'type': 'integer'}},
297+
'required': ['x'],
298+
'type': 'object',
299+
},
300+
'outer_typed_dict_key': None,
301+
}
302+
)
303+
304+
305+
def only_returns_type() -> str: # pragma: no cover
306+
"""
307+
308+
Returns:
309+
str: The result as a string.
310+
"""
311+
return 'foo'
312+
313+
314+
def test_only_returns_type():
315+
agent = Agent(FunctionModel(get_json_schema))
316+
agent.tool_plain(only_returns_type)
317+
318+
result = agent.run_sync('Hello')
319+
json_schema = json.loads(result.data)
320+
assert json_schema == snapshot(
321+
{
322+
'name': 'only_returns_type',
323+
'description': """\
324+
<returns>
325+
<type>str</type>
326+
<description>The result as a string.</description>
327+
</returns>\
328+
""",
329+
'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'},
330+
'outer_typed_dict_key': None,
331+
}
332+
)
333+
334+
190335
def unknown_docstring(**kwargs: int) -> str: # pragma: no cover
191336
"""Unknown style docstring."""
192337
return str(kwargs)
@@ -572,11 +717,7 @@ def ctx_tool(ctx: RunContext[int], x: int) -> int:
572717

573718

574719
async def tool_without_return_annotation_in_docstring() -> str: # pragma: no cover
575-
"""A tool that documents what it returns but doesn't have a return annotation in the docstring.
576-
577-
Returns:
578-
A value.
579-
"""
720+
"""A tool that documents what it returns but doesn't have a return annotation in the docstring."""
580721

581722
return ''
582723

@@ -591,8 +732,7 @@ def test_suppress_griffe_logging(caplog: LogCaptureFixture):
591732
json_schema = json.loads(result.data)
592733
assert json_schema == snapshot(
593734
{
594-
'description': "A tool that documents what it returns but doesn't have a "
595-
'return annotation in the docstring.',
735+
'description': "A tool that documents what it returns but doesn't have a return annotation in the docstring.",
596736
'name': 'tool_without_return_annotation_in_docstring',
597737
'outer_typed_dict_key': None,
598738
'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'},

0 commit comments

Comments
 (0)