Skip to content

Commit 985e948

Browse files
committed
fix: support list of pydantic model to FunctionTool arg annotation
1 parent 78e74b5 commit 985e948

File tree

2 files changed

+154
-12
lines changed

2 files changed

+154
-12
lines changed

src/google/adk/tools/function_tool.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,13 @@
1616

1717
import inspect
1818
import logging
19-
from typing import Any
20-
from typing import Callable
21-
from typing import get_args
22-
from typing import get_origin
23-
from typing import Optional
24-
from typing import Union
19+
from typing import Any, Callable, Optional, Union, get_args, get_origin
2520

26-
from google.genai import types
2721
import pydantic
2822
from typing_extensions import override
2923

24+
from google.genai import types
25+
3026
from ..utils.context_utils import Aclosing
3127
from ._automatic_function_calling_util import build_function_declaration
3228
from .base_tool import BaseTool
@@ -102,6 +98,7 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
10298
10399
Currently handles:
104100
- Converting JSON dictionaries to Pydantic model instances where expected
101+
- Converting lists of JSON dictionaries to lists of Pydantic model instances
105102
106103
Future extensions could include:
107104
- Type coercion for other complex types
@@ -129,8 +126,36 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
129126
if len(non_none_types) == 1:
130127
target_type = non_none_types[0]
131128

129+
# Check if the target type is a list
130+
if get_origin(target_type) is list:
131+
list_args = get_args(target_type)
132+
if list_args:
133+
element_type = list_args[0]
134+
135+
# Check if the element type is a Pydantic model
136+
if inspect.isclass(element_type) and issubclass(
137+
element_type, pydantic.BaseModel
138+
):
139+
# Skip conversion if the value is None
140+
if args[param_name] is None:
141+
continue
142+
143+
# Convert list elements to Pydantic models
144+
if isinstance(args[param_name], list):
145+
converted_list = []
146+
for item in args[param_name]:
147+
try:
148+
converted_list.append(element_type.model_validate(item))
149+
except Exception as e:
150+
# Skip items that fail validation
151+
logger.warning(
152+
f"Skipping item in '{param_name}': "
153+
f'Failed to convert to {element_type.__name__}: {e}'
154+
)
155+
converted_args[param_name] = converted_list
156+
132157
# Check if the target type is a Pydantic model
133-
if inspect.isclass(target_type) and issubclass(
158+
elif inspect.isclass(target_type) and issubclass(
134159
target_type, pydantic.BaseModel
135160
):
136161
# Skip conversion if the value is None and the parameter is Optional

tests/unittests/tools/test_function_tool_pydantic.py

Lines changed: 121 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
from typing import Optional
1818
from unittest.mock import MagicMock
1919

20+
import pydantic
21+
import pytest
22+
2023
from google.adk.agents.invocation_context import InvocationContext
2124
from google.adk.sessions.session import Session
2225
from google.adk.tools.function_tool import FunctionTool
2326
from google.adk.tools.tool_context import ToolContext
24-
import pydantic
25-
import pytest
2627

2728

2829
class UserModel(pydantic.BaseModel):
@@ -280,5 +281,121 @@ async def test_run_async_with_optional_pydantic_models():
280281
assert result["theme"] == "dark"
281282
assert result["notifications"] is True
282283
assert result["preferences_type"] == "PreferencesModel"
283-
assert result["preferences_type"] == "PreferencesModel"
284-
assert result["preferences_type"] == "PreferencesModel"
284+
285+
286+
def function_with_list_of_pydantic_models(users: list[UserModel]) -> dict:
287+
"""Function that takes a list of Pydantic models."""
288+
return {
289+
"count": len(users),
290+
"names": [user.name for user in users],
291+
"ages": [user.age for user in users],
292+
"types": [type(user).__name__ for user in users],
293+
}
294+
295+
296+
def function_with_optional_list_of_pydantic_models(
297+
users: Optional[list[UserModel]] = None,
298+
) -> dict:
299+
"""Function that takes an optional list of Pydantic models."""
300+
if users is None:
301+
return {"count": 0, "names": []}
302+
return {
303+
"count": len(users),
304+
"names": [user.name for user in users],
305+
}
306+
307+
308+
def test_preprocess_args_with_list_of_dicts_to_pydantic_models():
309+
"""Test _preprocess_args converts list of dicts to list of Pydantic models."""
310+
tool = FunctionTool(function_with_list_of_pydantic_models)
311+
312+
input_args = {
313+
"users": [
314+
{"name": "Alice", "age": 30, "email": "alice@example.com"},
315+
{"name": "Bob", "age": 25},
316+
{"name": "Charlie", "age": 35, "email": "charlie@example.com"},
317+
]
318+
}
319+
320+
processed_args = tool._preprocess_args(input_args)
321+
322+
# Check that the list of dicts was converted to a list of Pydantic models
323+
assert "users" in processed_args
324+
users = processed_args["users"]
325+
assert isinstance(users, list)
326+
assert len(users) == 3
327+
328+
# Check each element is a Pydantic model with correct data
329+
assert isinstance(users[0], UserModel)
330+
assert users[0].name == "Alice"
331+
assert users[0].age == 30
332+
assert users[0].email == "alice@example.com"
333+
334+
assert isinstance(users[1], UserModel)
335+
assert users[1].name == "Bob"
336+
assert users[1].age == 25
337+
assert users[1].email is None
338+
339+
assert isinstance(users[2], UserModel)
340+
assert users[2].name == "Charlie"
341+
assert users[2].age == 35
342+
assert users[2].email == "charlie@example.com"
343+
344+
345+
def test_preprocess_args_with_optional_list_of_pydantic_models_none():
346+
"""Test _preprocess_args handles None for optional list parameter."""
347+
tool = FunctionTool(function_with_optional_list_of_pydantic_models)
348+
349+
input_args = {"users": None}
350+
351+
processed_args = tool._preprocess_args(input_args)
352+
353+
# Check that None is preserved
354+
assert "users" in processed_args
355+
assert processed_args["users"] is None
356+
357+
358+
def test_preprocess_args_with_optional_list_of_pydantic_models_with_data():
359+
"""Test _preprocess_args converts list for optional list parameter."""
360+
tool = FunctionTool(function_with_optional_list_of_pydantic_models)
361+
362+
input_args = {
363+
"users": [
364+
{"name": "Alice", "age": 30},
365+
{"name": "Bob", "age": 25},
366+
]
367+
}
368+
369+
processed_args = tool._preprocess_args(input_args)
370+
371+
# Check conversion
372+
assert "users" in processed_args
373+
users = processed_args["users"]
374+
assert len(users) == 2
375+
assert all(isinstance(user, UserModel) for user in users)
376+
assert users[0].name == "Alice"
377+
assert users[1].name == "Bob"
378+
379+
380+
def test_preprocess_args_with_list_skips_invalid_items():
381+
"""Test _preprocess_args skips items that fail validation."""
382+
tool = FunctionTool(function_with_list_of_pydantic_models)
383+
384+
input_args = {
385+
"users": [
386+
{"name": "Alice", "age": 30},
387+
{"name": "Invalid"}, # Missing required 'age' field
388+
{"name": "Bob", "age": 25},
389+
]
390+
}
391+
392+
processed_args = tool._preprocess_args(input_args)
393+
394+
# Check that invalid item was skipped
395+
assert "users" in processed_args
396+
users = processed_args["users"]
397+
assert len(users) == 2 # Only 2 valid items
398+
assert users[0].name == "Alice"
399+
assert users[0].age == 30
400+
assert users[1].name == "Bob"
401+
assert users[1].age == 25

0 commit comments

Comments
 (0)