Skip to content

Commit 6cc2bdb

Browse files
Annhiluccopybara-github
authored andcommitted
feat: Support IntEnums when processing JSON schemas
PiperOrigin-RevId: 774980819
1 parent 714452f commit 6cc2bdb

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

google/genai/_transformers.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -787,15 +787,25 @@ def _recurse(sub_schema: dict[str, Any]) -> dict[str, Any]:
787787
def _process_enum(
788788
enum: EnumMeta, client: _api_client.BaseApiClient
789789
) -> types.Schema:
790+
is_integer_enum = False
791+
790792
for member in enum: # type: ignore
791-
if not isinstance(member.value, str):
793+
if isinstance(member.value, int):
794+
is_integer_enum = True
795+
elif not isinstance(member.value, str):
792796
raise TypeError(
793-
f'Enum member {member.name} value must be a string, got'
797+
f'Enum member {member.name} value must be a string or integer, got'
794798
f' {type(member.value)}'
795799
)
796800

801+
enum_to_process = enum
802+
if is_integer_enum:
803+
str_members = [str(member.value) for member in enum] # type: ignore
804+
str_enum = Enum(enum.__name__, str_members, type=str) # type: ignore
805+
enum_to_process = str_enum
806+
797807
class Placeholder(pydantic.BaseModel):
798-
placeholder: enum # type: ignore[valid-type]
808+
placeholder: enum_to_process # type: ignore[valid-type]
799809

800810
enum_schema = Placeholder.model_json_schema()
801811
process_schema(enum_schema, client)

google/genai/tests/models/test_generate_content.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1508,17 +1508,18 @@ class IntegerEnum(Enum):
15081508
BRASS = 4
15091509
KEYBOARD = 5
15101510

1511-
with pytest.raises(TypeError) as e:
1512-
client.models.generate_content(
1513-
model='gemini-1.5-flash',
1514-
contents='What instrument plays multiple notes at once?',
1515-
config={
1516-
'response_mime_type': 'text/x.enum',
1517-
'response_schema': IntegerEnum,
1518-
},
1519-
)
1511+
response =client.models.generate_content(
1512+
model='gemini-1.5-flash',
1513+
contents='What instrument plays multiple notes at once?',
1514+
config={
1515+
'response_mime_type': 'text/x.enum',
1516+
'response_schema': IntegerEnum,
1517+
},
1518+
)
1519+
1520+
instrument_values = {str(member.value) for member in IntegerEnum}
15201521

1521-
assert 'value must be a string' in str(e)
1522+
assert response.text in instrument_values
15221523

15231524

15241525
def test_json_schema(client):

0 commit comments

Comments
 (0)