Skip to content

Commit a7b8074

Browse files
authored
Add new _GeminiThoughtPart to adhere to Gemini response (#1897)
1 parent b5f26c3 commit a7b8074

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
431431
if maybe_event is not None: # pragma: no branch
432432
yield maybe_event
433433
else:
434-
assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}' # pragma: no cover
434+
if not any([key in gemini_part for key in ['function_response', 'thought']]):
435+
raise AssertionError(f'Unexpected part: {gemini_part}') # pragma: no cover
435436

436437
async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
437438
# This method exists to ensure we only yield completed items, so we don't need to worry about
@@ -605,6 +606,11 @@ class _GeminiFileDataPart(TypedDict):
605606
file_data: Annotated[_GeminiFileData, pydantic.Field(alias='fileData')]
606607

607608

609+
class _GeminiThoughtPart(TypedDict):
610+
thought: bool
611+
thought_signature: Annotated[str, pydantic.Field(alias='thoughtSignature')]
612+
613+
608614
class _GeminiFunctionCallPart(TypedDict):
609615
function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
610616

@@ -665,6 +671,8 @@ def _part_discriminator(v: Any) -> str:
665671
return 'inline_data' # pragma: no cover
666672
elif 'fileData' in v:
667673
return 'file_data' # pragma: no cover
674+
elif 'thought' in v:
675+
return 'thought'
668676
elif 'functionCall' in v or 'function_call' in v:
669677
return 'function_call'
670678
elif 'functionResponse' in v or 'function_response' in v:
@@ -682,6 +690,7 @@ def _part_discriminator(v: Any) -> str:
682690
Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
683691
Annotated[_GeminiInlineDataPart, pydantic.Tag('inline_data')],
684692
Annotated[_GeminiFileDataPart, pydantic.Tag('file_data')],
693+
Annotated[_GeminiThoughtPart, pydantic.Tag('thought')],
685694
],
686695
pydantic.Discriminator(_part_discriminator),
687696
]

tests/models/test_gemini.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,13 @@
4141
_GeminiCandidates,
4242
_GeminiContent,
4343
_GeminiFunction,
44+
_GeminiFunctionCall,
4445
_GeminiFunctionCallingConfig,
46+
_GeminiFunctionCallPart,
4547
_GeminiResponse,
4648
_GeminiSafetyRating,
49+
_GeminiTextPart,
50+
_GeminiThoughtPart,
4751
_GeminiToolConfig,
4852
_GeminiTools,
4953
_GeminiUsageMetaData,
@@ -898,8 +902,11 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient):
898902
_GeminiContent(
899903
role='model',
900904
parts=[
901-
{'text': 'foo'},
902-
{'function_call': {'name': 'get_location', 'args': {'loc_name': 'San Fransisco'}}},
905+
_GeminiThoughtPart(thought=True, thought_signature='test-signature-value'),
906+
_GeminiTextPart(text='foo'),
907+
_GeminiFunctionCallPart(
908+
function_call=_GeminiFunctionCall(name='get_location', args={'loc_name': 'San Fransisco'})
909+
),
903910
],
904911
)
905912
),
@@ -1327,3 +1334,23 @@ async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient):
13271334
for message in result.all_messages():
13281335
if isinstance(message, ModelResponse):
13291336
assert message.vendor_details is None
1337+
1338+
1339+
async def test_response_with_thought_part(get_gemini_client: GetGeminiClient):
1340+
"""Tests that a response containing a 'thought' part can be parsed."""
1341+
content_with_thought = _GeminiContent(
1342+
role='model',
1343+
parts=[
1344+
_GeminiThoughtPart(thought=True, thought_signature='test-signature-value'),
1345+
_GeminiTextPart(text='Hello from thought test'),
1346+
],
1347+
)
1348+
response = gemini_response(content_with_thought)
1349+
gemini_client = get_gemini_client(response)
1350+
m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
1351+
agent = Agent(m)
1352+
1353+
result = await agent.run('Test with thought')
1354+
1355+
assert result.output == 'Hello from thought test'
1356+
assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))

0 commit comments

Comments
 (0)