|  | 
| 17 | 17 | from typing import Optional | 
| 18 | 18 | from unittest.mock import MagicMock | 
| 19 | 19 | 
 | 
|  | 20 | +import pydantic | 
|  | 21 | +import pytest | 
|  | 22 | + | 
| 20 | 23 | from google.adk.agents.invocation_context import InvocationContext | 
| 21 | 24 | from google.adk.sessions.session import Session | 
| 22 | 25 | from google.adk.tools.function_tool import FunctionTool | 
| 23 | 26 | from google.adk.tools.tool_context import ToolContext | 
| 24 |  | -import pydantic | 
| 25 |  | -import pytest | 
| 26 | 27 | 
 | 
| 27 | 28 | 
 | 
| 28 | 29 | class UserModel(pydantic.BaseModel): | 
| @@ -280,5 +281,121 @@ async def test_run_async_with_optional_pydantic_models(): | 
| 280 | 281 |   assert result["theme"] == "dark" | 
| 281 | 282 |   assert result["notifications"] is True | 
| 282 | 283 |   assert result["preferences_type"] == "PreferencesModel" | 
| 283 |  | -  assert result["preferences_type"] == "PreferencesModel" | 
| 284 |  | -  assert result["preferences_type"] == "PreferencesModel" | 
|  | 284 | + | 
|  | 285 | + | 
|  | 286 | +def function_with_list_of_pydantic_models(users: list[UserModel]) -> dict: | 
|  | 287 | +  """Function that takes a list of Pydantic models.""" | 
|  | 288 | +  return { | 
|  | 289 | +      "count": len(users), | 
|  | 290 | +      "names": [user.name for user in users], | 
|  | 291 | +      "ages": [user.age for user in users], | 
|  | 292 | +      "types": [type(user).__name__ for user in users], | 
|  | 293 | +  } | 
|  | 294 | + | 
|  | 295 | + | 
|  | 296 | +def function_with_optional_list_of_pydantic_models( | 
|  | 297 | +    users: Optional[list[UserModel]] = None, | 
|  | 298 | +) -> dict: | 
|  | 299 | +  """Function that takes an optional list of Pydantic models.""" | 
|  | 300 | +  if users is None: | 
|  | 301 | +    return {"count": 0, "names": []} | 
|  | 302 | +  return { | 
|  | 303 | +      "count": len(users), | 
|  | 304 | +      "names": [user.name for user in users], | 
|  | 305 | +  } | 
|  | 306 | + | 
|  | 307 | + | 
|  | 308 | +def test_preprocess_args_with_list_of_dicts_to_pydantic_models(): | 
|  | 309 | +  """Test _preprocess_args converts list of dicts to list of Pydantic models.""" | 
|  | 310 | +  tool = FunctionTool(function_with_list_of_pydantic_models) | 
|  | 311 | + | 
|  | 312 | +  input_args = { | 
|  | 313 | +      "users": [ | 
|  | 314 | +          {"name": "Alice", "age": 30, "email": "alice@example.com"}, | 
|  | 315 | +          {"name": "Bob", "age": 25}, | 
|  | 316 | +          {"name": "Charlie", "age": 35, "email": "charlie@example.com"}, | 
|  | 317 | +      ] | 
|  | 318 | +  } | 
|  | 319 | + | 
|  | 320 | +  processed_args = tool._preprocess_args(input_args) | 
|  | 321 | + | 
|  | 322 | +  # Check that the list of dicts was converted to a list of Pydantic models | 
|  | 323 | +  assert "users" in processed_args | 
|  | 324 | +  users = processed_args["users"] | 
|  | 325 | +  assert isinstance(users, list) | 
|  | 326 | +  assert len(users) == 3 | 
|  | 327 | + | 
|  | 328 | +  # Check each element is a Pydantic model with correct data | 
|  | 329 | +  assert isinstance(users[0], UserModel) | 
|  | 330 | +  assert users[0].name == "Alice" | 
|  | 331 | +  assert users[0].age == 30 | 
|  | 332 | +  assert users[0].email == "alice@example.com" | 
|  | 333 | + | 
|  | 334 | +  assert isinstance(users[1], UserModel) | 
|  | 335 | +  assert users[1].name == "Bob" | 
|  | 336 | +  assert users[1].age == 25 | 
|  | 337 | +  assert users[1].email is None | 
|  | 338 | + | 
|  | 339 | +  assert isinstance(users[2], UserModel) | 
|  | 340 | +  assert users[2].name == "Charlie" | 
|  | 341 | +  assert users[2].age == 35 | 
|  | 342 | +  assert users[2].email == "charlie@example.com" | 
|  | 343 | + | 
|  | 344 | + | 
|  | 345 | +def test_preprocess_args_with_optional_list_of_pydantic_models_none(): | 
|  | 346 | +  """Test _preprocess_args handles None for optional list parameter.""" | 
|  | 347 | +  tool = FunctionTool(function_with_optional_list_of_pydantic_models) | 
|  | 348 | + | 
|  | 349 | +  input_args = {"users": None} | 
|  | 350 | + | 
|  | 351 | +  processed_args = tool._preprocess_args(input_args) | 
|  | 352 | + | 
|  | 353 | +  # Check that None is preserved | 
|  | 354 | +  assert "users" in processed_args | 
|  | 355 | +  assert processed_args["users"] is None | 
|  | 356 | + | 
|  | 357 | + | 
|  | 358 | +def test_preprocess_args_with_optional_list_of_pydantic_models_with_data(): | 
|  | 359 | +  """Test _preprocess_args converts list for optional list parameter.""" | 
|  | 360 | +  tool = FunctionTool(function_with_optional_list_of_pydantic_models) | 
|  | 361 | + | 
|  | 362 | +  input_args = { | 
|  | 363 | +      "users": [ | 
|  | 364 | +          {"name": "Alice", "age": 30}, | 
|  | 365 | +          {"name": "Bob", "age": 25}, | 
|  | 366 | +      ] | 
|  | 367 | +  } | 
|  | 368 | + | 
|  | 369 | +  processed_args = tool._preprocess_args(input_args) | 
|  | 370 | + | 
|  | 371 | +  # Check conversion | 
|  | 372 | +  assert "users" in processed_args | 
|  | 373 | +  users = processed_args["users"] | 
|  | 374 | +  assert len(users) == 2 | 
|  | 375 | +  assert all(isinstance(user, UserModel) for user in users) | 
|  | 376 | +  assert users[0].name == "Alice" | 
|  | 377 | +  assert users[1].name == "Bob" | 
|  | 378 | + | 
|  | 379 | + | 
|  | 380 | +def test_preprocess_args_with_list_skips_invalid_items(): | 
|  | 381 | +  """Test _preprocess_args skips items that fail validation.""" | 
|  | 382 | +  tool = FunctionTool(function_with_list_of_pydantic_models) | 
|  | 383 | + | 
|  | 384 | +  input_args = { | 
|  | 385 | +      "users": [ | 
|  | 386 | +          {"name": "Alice", "age": 30}, | 
|  | 387 | +          {"name": "Invalid"},  # Missing required 'age' field | 
|  | 388 | +          {"name": "Bob", "age": 25}, | 
|  | 389 | +      ] | 
|  | 390 | +  } | 
|  | 391 | + | 
|  | 392 | +  processed_args = tool._preprocess_args(input_args) | 
|  | 393 | + | 
|  | 394 | +  # Check that invalid item was skipped | 
|  | 395 | +  assert "users" in processed_args | 
|  | 396 | +  users = processed_args["users"] | 
|  | 397 | +  assert len(users) == 2  # Only 2 valid items | 
|  | 398 | +  assert users[0].name == "Alice" | 
|  | 399 | +  assert users[0].age == 30 | 
|  | 400 | +  assert users[1].name == "Bob" | 
|  | 401 | +  assert users[1].age == 25 | 
0 commit comments