|
13 | 13 | # limitations under the License.
|
14 | 14 | #
|
15 | 15 |
|
| 16 | +import collections |
16 | 17 | import logging
|
17 | 18 | import sys
|
18 | 19 | import typing
|
@@ -826,6 +827,89 @@ def get_information(
|
826 | 827 | assert 'cat' in response.text
|
827 | 828 |
|
828 | 829 |
|
| 830 | +def test_automatic_function_calling_with_union_operator(client): |
| 831 | + class AnimalObject(pydantic.BaseModel): |
| 832 | + name: str |
| 833 | + age: int |
| 834 | + species: str |
| 835 | + |
| 836 | + def get_information( |
| 837 | + object_of_interest: str | AnimalObject, |
| 838 | + ) -> str: |
| 839 | + if isinstance(object_of_interest, AnimalObject): |
| 840 | + return ( |
| 841 | + f'The animal is of {object_of_interest.species} species and is named' |
| 842 | + f' {object_of_interest.name} is {object_of_interest.age} years old' |
| 843 | + ) |
| 844 | + else: |
| 845 | + return f'The object of interest is {object_of_interest}' |
| 846 | + |
| 847 | + response = client.models.generate_content( |
| 848 | + model='gemini-1.5-flash', |
| 849 | + contents=( |
| 850 | + 'I have a one year old cat named Sundae, can you get the' |
| 851 | + ' information of the cat for me?' |
| 852 | + ), |
| 853 | + config={ |
| 854 | + 'tools': [get_information], |
| 855 | + 'automatic_function_calling': {'ignore_call_history': True}, |
| 856 | + }, |
| 857 | + ) |
| 858 | + assert response.text |
| 859 | + |
| 860 | + |
| 861 | +def test_automatic_function_calling_with_tuple_param(client): |
| 862 | + def output_latlng( |
| 863 | + latlng: tuple[float, float], |
| 864 | + ) -> str: |
| 865 | + return f'The latitude is {latlng[0]} and the longitude is {latlng[1]}' |
| 866 | + |
| 867 | + response = client.models.generate_content( |
| 868 | + model='gemini-1.5-flash', |
| 869 | + contents=( |
| 870 | + 'The coordinates are (51.509, -0.118). What is the latitude and longitude?' |
| 871 | + ), |
| 872 | + config={ |
| 873 | + 'tools': [output_latlng], |
| 874 | + 'automatic_function_calling': {'ignore_call_history': True}, |
| 875 | + }, |
| 876 | + ) |
| 877 | + assert response.text |
| 878 | + |
| 879 | + |
| 880 | +@pytest.mark.skipif( |
| 881 | + sys.version_info < (3, 10), |
| 882 | + reason='| is only supported in Python 3.10 and above.', |
| 883 | +) |
| 884 | +def test_automatic_function_calling_with_union_operator_return_type(client): |
| 885 | + def get_cheese_age(cheese: int) -> int | float: |
| 886 | + """ |
| 887 | + Retrieves data about the age of the cheese given its ID. |
| 888 | +
|
| 889 | + Args: |
| 890 | + cheese_id: The ID of the cheese. |
| 891 | +
|
| 892 | + Returns: |
| 893 | + An int or float of the age of the cheese. |
| 894 | + """ |
| 895 | + if cheese == 1: |
| 896 | + return 2.5 |
| 897 | + elif cheese == 2: |
| 898 | + return 3 |
| 899 | + else: |
| 900 | + return 0.0 |
| 901 | + |
| 902 | + response = client.models.generate_content( |
| 903 | + model='gemini-2.5-flash', |
| 904 | + contents='How old is the cheese with id 2?', |
| 905 | + config={ |
| 906 | + 'tools': [get_cheese_age], |
| 907 | + 'automatic_function_calling': {'ignore_call_history': True}, |
| 908 | + }, |
| 909 | + ) |
| 910 | + assert '3' in response.text |
| 911 | + |
| 912 | + |
829 | 913 | def test_automatic_function_calling_with_parameterized_generic_union_type(
|
830 | 914 | client,
|
831 | 915 | ):
|
|
0 commit comments