Skip to content

Support field labels for GeminiModel and GoogleModel on Vertex AI #1056

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
6901001
Support field `labels` for Gemini models
vricciardulli Mar 4, 2025
3de85dc
test that GLA model returns 400 if labels provided
vricciardulli Mar 4, 2025
1301b9e
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 5, 2025
5bfd3dc
adapt to latest provider-based changes
vricciardulli Mar 5, 2025
492c378
refactor tests
vricciardulli Mar 5, 2025
dd12b26
remove wrong deprecation skip line
vricciardulli Mar 5, 2025
f59754f
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 5, 2025
28d03b6
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 6, 2025
39b96eb
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 11, 2025
e8f541d
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 12, 2025
3f60d30
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 13, 2025
d8ef023
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 13, 2025
07e0950
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 13, 2025
3e94e88
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 14, 2025
b6fe08c
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 15, 2025
1521813
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 15, 2025
46ad619
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 18, 2025
894a1f8
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 21, 2025
ea7f910
fix tests
vricciardulli Mar 21, 2025
39dd436
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 27, 2025
202cde8
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 27, 2025
6e02f4d
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 28, 2025
d8bfab9
allow reporting deprecation warnings and private usage in `test_gemin…
vricciardulli Mar 28, 2025
042733f
remove unnecessary import
vricciardulli Mar 28, 2025
1cd7c1f
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 31, 2025
bd2e05a
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Mar 31, 2025
2486fcb
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli Apr 1, 2025
2df29fa
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli May 23, 2025
0be9334
support labels for google model
vricciardulli May 23, 2025
e7a681c
use vcr for google gla tests
vricciardulli May 23, 2025
b30f0c7
use vcr for gemini vertex tests
vricciardulli May 23, 2025
344f82b
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli May 23, 2025
c34ef9f
try to use string as provider
vricciardulli May 23, 2025
97a9df9
Revert "try to use string as provider"
vricciardulli May 23, 2025
7dcfda8
try to use a fixture to patch `google-auth` package
vricciardulli May 23, 2025
df6da39
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli May 24, 2025
fb1b5aa
Merge remote-tracking branch 'upstream/main' into support-labels-fiel…
vricciardulli May 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class GeminiModelSettings(ModelSettings, total=False):
"""

gemini_safety_settings: list[GeminiSafetySettings]
"""Safety settings options for Gemini model request."""

gemini_thinking_config: ThinkingConfig
"""Thinking is "on" by default in both the API and AI Studio.
Expand All @@ -93,6 +94,12 @@ class GeminiModelSettings(ModelSettings, total=False):
See more about it on <https://ai.google.dev/gemini-api/docs/thinking>.
"""

gemini_labels: dict[str, str]
"""User-defined metadata to break down billed charges. Only supported by the Vertex AI provider.

See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations.
"""


@dataclass(init=False)
class GeminiModel(Model):
Expand Down Expand Up @@ -223,25 +230,17 @@ async def _make_request(
if tool_config is not None:
request_data['toolConfig'] = tool_config

generation_config: _GeminiGenerationConfig = {}
if model_settings:
if (max_tokens := model_settings.get('max_tokens')) is not None:
generation_config['max_output_tokens'] = max_tokens
if (temperature := model_settings.get('temperature')) is not None:
generation_config['temperature'] = temperature
if (top_p := model_settings.get('top_p')) is not None:
generation_config['top_p'] = top_p
if (presence_penalty := model_settings.get('presence_penalty')) is not None:
generation_config['presence_penalty'] = presence_penalty
if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
generation_config['frequency_penalty'] = frequency_penalty
if (thinkingConfig := model_settings.get('gemini_thinking_config')) is not None:
generation_config['thinking_config'] = thinkingConfig # pragma: no cover
if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) is not None:
request_data['safetySettings'] = gemini_safety_settings
generation_config = _settings_to_generation_config(model_settings)
if generation_config:
request_data['generationConfig'] = generation_config

if gemini_safety_settings := model_settings.get('gemini_safety_settings'):
request_data['safetySettings'] = gemini_safety_settings

if gemini_labels := model_settings.get('gemini_labels'):
if self._system == 'google-vertex':
request_data['labels'] = gemini_labels

headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'

Expand Down Expand Up @@ -362,6 +361,23 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]
return content


def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig:
config: _GeminiGenerationConfig = {}
if (max_tokens := model_settings.get('max_tokens')) is not None:
config['max_output_tokens'] = max_tokens
if (temperature := model_settings.get('temperature')) is not None:
config['temperature'] = temperature
if (top_p := model_settings.get('top_p')) is not None:
config['top_p'] = top_p
if (presence_penalty := model_settings.get('presence_penalty')) is not None:
config['presence_penalty'] = presence_penalty
if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
config['frequency_penalty'] = frequency_penalty
if (thinkingConfig := model_settings.get('gemini_thinking_config')) is not None:
config['thinking_config'] = thinkingConfig # pragma: no cover
return config


class AuthProtocol(Protocol):
"""Abstract definition for Gemini authentication."""

Expand Down Expand Up @@ -483,6 +499,7 @@ class _GeminiRequest(TypedDict):
<https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
"""
generationConfig: NotRequired[_GeminiGenerationConfig]
labels: NotRequired[dict[str, str]]


class GeminiSafetySettings(TypedDict):
Expand Down
7 changes: 7 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ class GoogleModelSettings(ModelSettings, total=False):
See <https://ai.google.dev/gemini-api/docs/thinking> for more information.
"""

google_labels: dict[str, str]
"""User-defined metadata to break down billed charges. Only supported by the Vertex AI API.

See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations.
"""


@dataclass(init=False)
class GoogleModel(Model):
Expand Down Expand Up @@ -269,6 +275,7 @@ async def _generate_content(
frequency_penalty=model_settings.get('frequency_penalty'),
safety_settings=model_settings.get('google_safety_settings'),
thinking_config=model_settings.get('google_thinking_config'),
labels=model_settings.get('google_labels'),
tools=cast(ToolListUnionDict, tools),
tool_config=tool_config,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
interactions:
- request:
headers:
accept:
- "*/*"
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "82"
content-type:
- application/json
host:
- generativelanguage.googleapis.com
method: POST
parsed_body:
contents:
- parts:
- text: What is the capital of France?
role: user
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent
response:
headers:
alt-svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
content-length:
- "637"
content-type:
- application/json; charset=UTF-8
server-timing:
- gfet4t7; dur=426
transfer-encoding:
- chunked
vary:
- Origin
- X-Origin
- Referer
parsed_body:
candidates:
- avgLogprobs: -0.02703852951526642
content:
parts:
- text: |
The capital of France is **Paris**.
role: model
finishReason: STOP
modelVersion: gemini-2.0-flash
usageMetadata:
candidatesTokenCount: 9
candidatesTokensDetails:
- modality: TEXT
tokenCount: 9
promptTokenCount: 7
promptTokensDetails:
- modality: TEXT
tokenCount: 7
totalTokenCount: 16
status:
code: 200
message: OK
version: 1
110 changes: 110 additions & 0 deletions tests/models/cassettes/test_gemini_vertexai/test_labels.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
interactions:
- request:
body: grant_type=%5B%27refresh_token%27%5D&client_id=%5B%27764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com%27%5D&client_secret=%5B%27scrubbed%27%5D&refresh_token=%5B%27scrubbed%27%5D
headers:
accept:
- "*/*"
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "268"
content-type:
- application/x-www-form-urlencoded
method: POST
uri: https://oauth2.googleapis.com/token
response:
headers:
alt-svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
cache-control:
- no-cache, no-store, max-age=0, must-revalidate
content-length:
- "1419"
content-type:
- application/json; charset=utf-8
expires:
- Mon, 01 Jan 1990 00:00:00 GMT
pragma:
- no-cache
transfer-encoding:
- chunked
vary:
- Origin
- X-Origin
- Referer
parsed_body:
access_token: scrubbed
expires_in: 3599
id_token: eyJhbGciOiJSUzI1NiIsImtpZCI6IjgyMWYzYmM2NmYwNzUxZjc4NDA2MDY3OTliMWFkZjllOWZiNjBkZmIiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI3NjQwODYwNTE4NTAtNnFyNHA2Z3BpNmhuNTA2cHQ4ZWp1cTgzZGkzNDFodXIuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI3NjQwODYwNTE4NTAtNnFyNHA2Z3BpNmhuNTA2cHQ4ZWp1cTgzZGkzNDFodXIuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMDY1Njg0NzQzMTU3NzkyMTI1NTkiLCJoZCI6InB5ZGFudGljLmRldiIsImVtYWlsIjoibWFyY2Vsb0BweWRhbnRpYy5kZXYiLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6ImlyckNRNE00c0Z0Z2dfS2hRTVNjekEiLCJpYXQiOjE3NDM0MTM3NzcsImV4cCI6MTc0MzQxNzM3N30.BAvb4TlcIoYcQODNLFqwtUQoSNJJbpAR0lk2OyFxXK9rSZ7m1e1_Dp1O4ApxPUS7f_NX34eSCuDJN2IXgh8VBv4k3IhI7CbMydYeqXuwlbgOOp1Z0farGEKneU1M7TvdngigAJ9wT-2LHjKd_GEcGau-CUvzXpcT1IOnNNyXGVqtuGmEfcw5jjPkKJNECUheeNHE3zeImatTstOLuKmI1ZK-etl41l3poSNuQkZkrbQ80Vst8BdT-b1tnJnXP1KGATBIamDy99OOiB9a7a9m_ikXYEyN91yR76DYot3hpDPlOX0H9hF-BOSqoOtlSS2TMBkMvFiiYWjID1e_9VlNUg
scope: https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email openid https://www.googleapis.com/auth/sqlservice.login
token_type: Bearer
status:
code: 200
message: OK
- request:
headers:
accept:
- "*/*"
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "133"
content-type:
- application/json
host:
- us-central1-aiplatform.googleapis.com
method: POST
parsed_body:
contents:
- parts:
- text: What is the capital of France?
role: user
labels:
environment: test
team: analytics
uri: https://us-central1-aiplatform.googleapis.com/v1/projects/pydantic-ai/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent
response:
headers:
alt-svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
content-length:
- "759"
content-type:
- application/json; charset=UTF-8
transfer-encoding:
- chunked
vary:
- Origin
- X-Origin
- Referer
parsed_body:
candidates:
- avgLogprobs: -0.02703852951526642
content:
parts:
- text: |
The capital of France is **Paris**.
role: model
finishReason: STOP
createTime: "2025-05-23T07:53:55.494386Z"
modelVersion: gemini-2.0-flash
responseId: kykwaLKWHti5nvgPmN2T8AE
usageMetadata:
candidatesTokenCount: 9
candidatesTokensDetails:
- modality: TEXT
tokenCount: 9
promptTokenCount: 7
promptTokensDetails:
- modality: TEXT
tokenCount: 7
totalTokenCount: 16
trafficType: ON_DEMAND
status:
code: 200
message: OK
version: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
interactions:
- request:
body: grant_type=%5B%27refresh_token%27%5D&client_id=%5B%27764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com%27%5D&client_secret=%5B%27scrubbed%27%5D&refresh_token=%5B%27scrubbed%27%5D
headers:
accept:
- "*/*"
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "268"
content-type:
- application/x-www-form-urlencoded
method: POST
uri: https://oauth2.googleapis.com/token
response:
headers:
alt-svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
cache-control:
- no-cache, no-store, max-age=0, must-revalidate
content-length:
- "1420"
content-type:
- application/json; charset=utf-8
expires:
- Mon, 01 Jan 1990 00:00:00 GMT
pragma:
- no-cache
transfer-encoding:
- chunked
vary:
- Origin
- X-Origin
- Referer
parsed_body:
access_token: scrubbed
expires_in: 3599
id_token: eyJhbGciOiJSUzI1NiIsImtpZCI6IjY2MGVmM2I5Nzg0YmRmNTZlYmU4NTlmNTc3ZjdmYjJlOGMxY2VmZmIiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI3NjQwODYwNTE4NTAtNnFyNHA2Z3BpNmhuNTA2cHQ4ZWp1cTgzZGkzNDFodXIuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI3NjQwODYwNTE4NTAtNnFyNHA2Z3BpNmhuNTA2cHQ4ZWp1cTgzZGkzNDFodXIuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMDY1Njg0NzQzMTU3NzkyMTI1NTkiLCJoZCI6InB5ZGFudGljLmRldiIsImVtYWlsIjoibWFyY2Vsb0BweWRhbnRpYy5kZXYiLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6Ii1CeV9XOWwtRHg1ekg0YTVOV25fV3ciLCJpYXQiOjE3NDc1NzQxOTEsImV4cCI6MTc0NzU3Nzc5MX0.dHg3qRlYoQ8WyIml7-kGqsuefvkl5deuZ0yTQM-RvKuuqtF_t6p8TrWbndEuSbZpRn9JhVPnsoEAYVPexbGy-pon4gu1aHH0dJNq3ghhdim7qp5JWpegLaZqvNvELvEHjj2VNLWXQ70-5wEaI_HCtAWTjlROAHQxvoWHJAdeH0Yf9zoljEBQvx3VLDLEpdCcMd-UGNCBucpQlFHcCJs5Qq8yj8R8f27BCEmRo7z9K3Axuedj_wcJ_tWV1x1tWxojUloJaKsIfztFOPFxzOdNPOlTHXsE47d4v43v87a8LhdDGloD72xN_kLapfIqyTIwRTj4cQvQp5H0u7As49fvMA
scope: https://www.googleapis.com/auth/userinfo.email openid https://www.googleapis.com/auth/sqlservice.login https://www.googleapis.com/auth/cloud-platform
token_type: Bearer
status:
code: 200
message: OK
- request:
headers:
accept:
- "*/*"
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "257"
content-type:
- application/json
host:
- aiplatform.googleapis.com
method: POST
parsed_body:
contents:
- parts:
- text: What is the capital of France?
role: user
generationConfig: {}
labels:
environment: test
team: analytics
systemInstruction:
parts:
- text: You are a helpful chatbot.
role: user
uri: https://aiplatform.googleapis.com/v1beta1/projects/pydantic-ai/locations/global/publishers/google/models/gemini-2.0-flash:generateContent
response:
headers:
alt-svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
content-length:
- "759"
content-type:
- application/json; charset=UTF-8
transfer-encoding:
- chunked
vary:
- Origin
- X-Origin
- Referer
parsed_body:
candidates:
- avgLogprobs: -0.0005532301729544997
content:
parts:
- text: |
The capital of France is Paris.
role: model
finishReason: STOP
createTime: "2025-05-23T07:09:59.524624Z"
modelVersion: gemini-2.0-flash
responseId: sN0paKOZFtmtyOgPqMyL6AE
usageMetadata:
candidatesTokenCount: 8
candidatesTokensDetails:
- modality: TEXT
tokenCount: 8
promptTokenCount: 13
promptTokensDetails:
- modality: TEXT
tokenCount: 13
totalTokenCount: 21
trafficType: ON_DEMAND
status:
code: 200
message: OK
version: 1
Loading