Skip to content

Commit 76e9b13

Browse files
sararobcopybara-github
authored andcommitted
feat: enable more types in FunctionDeclaration schema
PiperOrigin-RevId: 791668482
1 parent 987ccc8 commit 76e9b13

File tree

4 files changed

+392
-179
lines changed

4 files changed

+392
-179
lines changed

google/genai/_automatic_function_calling_util.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,39 @@
4141
}
4242

4343

44+
def _raise_for_unsupported_param(
45+
param: inspect.Parameter, func_name: str, exception: Union[Exception, type[Exception]]
46+
) -> None:
47+
raise ValueError(
48+
f'Failed to parse the parameter {param} of function {func_name} for'
49+
' automatic function calling.Automatic function calling works best with'
50+
' simpler function signature schema, consider manually parsing your'
51+
f' function declaration for function {func_name}.'
52+
) from exception
53+
54+
55+
def _handle_params_as_deferred_annotations(param: inspect.Parameter, annotation_under_future: dict[str, Any], name: str) -> inspect.Parameter:
56+
"""Catches the case when type hints are stored as strings."""
57+
if isinstance(param.annotation, str):
58+
param = param.replace(annotation=annotation_under_future[name])
59+
return param
60+
61+
62+
def _add_unevaluated_items_to_fixed_len_tuple_schema(
63+
json_schema: dict[str, Any]
64+
) -> dict[str, Any]:
65+
if (
66+
json_schema.get('maxItems')
67+
and (
68+
json_schema.get('prefixItems')
69+
and len(json_schema['prefixItems']) == json_schema['maxItems']
70+
)
71+
and json_schema.get('type') == 'array'
72+
):
73+
json_schema['unevaluatedItems'] = False
74+
return json_schema
75+
76+
4477
def _is_builtin_primitive_or_compound(
4578
annotation: inspect.Parameter.annotation, # type: ignore[valid-type]
4679
) -> bool:
@@ -92,7 +125,7 @@ def _is_default_value_compatible(
92125
return False
93126

94127

95-
def _parse_schema_from_parameter(
128+
def _parse_schema_from_parameter( # type: ignore[return]
96129
api_option: Literal['VERTEX_AI', 'GEMINI_API'],
97130
param: inspect.Parameter,
98131
func_name: str,
@@ -267,12 +300,7 @@ def _parse_schema_from_parameter(
267300
)
268301
schema.required = _get_required_fields(schema)
269302
return schema
270-
raise ValueError(
271-
f'Failed to parse the parameter {param} of function {func_name} for'
272-
' automatic function calling.Automatic function calling works best with'
273-
' simpler function signature schema, consider manually parsing your'
274-
f' function declaration for function {func_name}.'
275-
)
303+
_raise_for_unsupported_param(param, func_name, ValueError)
276304

277305

278306
def _get_required_fields(schema: types.Schema) -> Optional[list[str]]:

google/genai/tests/models/test_generate_content_tools.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
#
1515

16+
import collections
1617
import logging
1718
import sys
1819
import typing
@@ -826,6 +827,89 @@ def get_information(
826827
assert 'cat' in response.text
827828

828829

830+
def test_automatic_function_calling_with_union_operator(client):
831+
class AnimalObject(pydantic.BaseModel):
832+
name: str
833+
age: int
834+
species: str
835+
836+
def get_information(
837+
object_of_interest: str | AnimalObject,
838+
) -> str:
839+
if isinstance(object_of_interest, AnimalObject):
840+
return (
841+
f'The animal is of {object_of_interest.species} species and is named'
842+
f' {object_of_interest.name} is {object_of_interest.age} years old'
843+
)
844+
else:
845+
return f'The object of interest is {object_of_interest}'
846+
847+
response = client.models.generate_content(
848+
model='gemini-1.5-flash',
849+
contents=(
850+
'I have a one year old cat named Sundae, can you get the'
851+
' information of the cat for me?'
852+
),
853+
config={
854+
'tools': [get_information],
855+
'automatic_function_calling': {'ignore_call_history': True},
856+
},
857+
)
858+
assert response.text
859+
860+
861+
def test_automatic_function_calling_with_tuple_param(client):
862+
def output_latlng(
863+
latlng: tuple[float, float],
864+
) -> str:
865+
return f'The latitude is {latlng[0]} and the longitude is {latlng[1]}'
866+
867+
response = client.models.generate_content(
868+
model='gemini-1.5-flash',
869+
contents=(
870+
'The coordinates are (51.509, -0.118). What is the latitude and longitude?'
871+
),
872+
config={
873+
'tools': [output_latlng],
874+
'automatic_function_calling': {'ignore_call_history': True},
875+
},
876+
)
877+
assert response.text
878+
879+
880+
@pytest.mark.skipif(
881+
sys.version_info < (3, 10),
882+
reason='| is only supported in Python 3.10 and above.',
883+
)
884+
def test_automatic_function_calling_with_union_operator_return_type(client):
885+
def get_cheese_age(cheese: int) -> int | float:
886+
"""
887+
Retrieves data about the age of the cheese given its ID.
888+
889+
Args:
890+
cheese_id: The ID of the cheese.
891+
892+
Returns:
893+
An int or float of the age of the cheese.
894+
"""
895+
if cheese == 1:
896+
return 2.5
897+
elif cheese == 2:
898+
return 3
899+
else:
900+
return 0.0
901+
902+
response = client.models.generate_content(
903+
model='gemini-2.5-flash',
904+
contents='How old is the cheese with id 2?',
905+
config={
906+
'tools': [get_cheese_age],
907+
'automatic_function_calling': {'ignore_call_history': True},
908+
},
909+
)
910+
assert '3' in response.text
911+
912+
829913
def test_automatic_function_calling_with_parameterized_generic_union_type(
830914
client,
831915
):

0 commit comments

Comments
 (0)