-
Notifications
You must be signed in to change notification settings - Fork 6.6k
feat: add code samples for continuous tuning #13579
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright 2025 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def create_continuous_tuning_job(tuned_model_name: str, checkpoint_id: str) -> str:
    """Start a continuous tuning job from a previously tuned model checkpoint.

    Args:
        tuned_model_name: Resource name of the pre-tuned model to continue
            tuning, e.g.
            "projects/123456789012/locations/us-central1/models/1234567890@1".
            It is passed to the API as the ``base_model``.
        checkpoint_id: Checkpoint of the pre-tuned model to continue from.

    Returns:
        The resource name of the created tuning job.
    """
    # [START googlegenaisdk_continuous_tuning_create]
    import time

    from google import genai
    from google.genai.types import CreateTuningJobConfig, HttpOptions, TuningDataset

    # TODO(developer): Update and un-comment below lines
    # tuned_model_name = "projects/123456789012/locations/us-central1/models/1234567890@1"
    # checkpoint_id = "1"

    # Continuous tuning requires the v1beta1 API surface.
    client = genai.Client(http_options=HttpOptions(api_version="v1beta1"))

    training_dataset = TuningDataset(
        gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_train_data.jsonl",
    )
    validation_dataset = TuningDataset(
        gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_validation_data.jsonl",
    )

    tuning_job = client.tunings.tune(
        base_model=tuned_model_name,
        training_dataset=training_dataset,
        config=CreateTuningJobConfig(
            tuned_model_display_name="Example tuning job",
            validation_dataset=validation_dataset,
            pre_tuned_model_checkpoint_id=checkpoint_id,
        ),
    )

    # Non-terminal states: keep polling while the job is in one of these.
    running_states = {
        "JOB_STATE_PENDING",
        "JOB_STATE_RUNNING",
    }

    while tuning_job.state in running_states:
        print(tuning_job.state)
        tuning_job = client.tunings.get(name=tuning_job.name)
        time.sleep(60)

    print(tuning_job.tuned_model.model)
    print(tuning_job.tuned_model.endpoint)
    print(tuning_job.experiment)
    # Example response:
    # projects/123456789012/locations/us-central1/models/1234567890@2
    # projects/123456789012/locations/us-central1/endpoints/123456789012345
    # projects/123456789012/locations/us-central1/metadataStores/default/contexts/tuning-experiment-2025010112345678

    if tuning_job.tuned_model.checkpoints:
        for i, checkpoint in enumerate(tuning_job.tuned_model.checkpoints):
            print(f"Checkpoint {i + 1}: ", checkpoint)
        # Example response:
        # Checkpoint 1: checkpoint_id='1' epoch=1 step=10 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789000000'
        # Checkpoint 2: checkpoint_id='2' epoch=2 step=20 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789012345'

    # [END googlegenaisdk_continuous_tuning_create]
    return tuning_job.name


if __name__ == "__main__":
    pre_tuned_model_name = input("Pre-tuned model name: ")
    pre_tuned_model_checkpoint_id = input("Pre-tuned model checkpoint id: ")
    create_continuous_tuning_job(pre_tuned_model_name, pre_tuned_model_checkpoint_id)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
google-genai==1.30.0 | ||
google-genai==1.39.1 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
from google.genai import types | ||
import pytest | ||
|
||
import continuous_tuning_create | ||
import tuning_job_create | ||
import tuning_job_get | ||
import tuning_job_list | ||
|
@@ -306,3 +307,23 @@ def test_tuning_with_checkpoints_textgen_with_txt(mock_genai_client: MagicMock) | |
call(model="test-endpoint-1", contents="Why is the sky blue?"), | ||
call(model="test-endpoint-2", contents="Why is the sky blue?"), | ||
] | ||
|
||
|
||
@patch("time.sleep")
@patch("google.genai.Client")
def test_continuous_tuning_create(
    mock_genai_client: MagicMock, mock_sleep: MagicMock
) -> None:
    """Verifies job creation, the polling loop, and the returned job name.

    Patches ``time.sleep`` directly (the sample imports ``time`` inside the
    function, so there is no module-level ``continuous_tuning_create.time``
    attribute to patch) to keep the test instant.
    """
    # Initial response from tunings.tune(): job still pending, so the
    # sample's polling loop is actually exercised.
    mock_initial_job = types.TuningJob(
        name="test-tuning-job",
        state="JOB_STATE_PENDING",
    )
    mock_genai_client.return_value.tunings.tune.return_value = mock_initial_job

    # Successive responses for the polling tunings.get() calls.
    mock_running_job = types.TuningJob(
        name="test-tuning-job",
        state="JOB_STATE_RUNNING",
    )
    mock_succeeded_job = types.TuningJob(
        name="test-tuning-job",
        state="JOB_STATE_SUCCEEDED",
        experiment="test-experiment",
        tuned_model=types.TunedModel(
            model="test-model-2",
            endpoint="test-endpoint",
            checkpoints=[],
        ),
    )
    mock_genai_client.return_value.tunings.get.side_effect = [
        mock_running_job,
        mock_succeeded_job,
    ]

    response = continuous_tuning_create.create_continuous_tuning_job(
        tuned_model_name="test-model", checkpoint_id="1"
    )

    mock_genai_client.assert_called_once_with(
        http_options=types.HttpOptions(api_version="v1beta1")
    )
    mock_genai_client.return_value.tunings.tune.assert_called_once()
    # Polled twice (PENDING -> RUNNING -> SUCCEEDED), sleeping after each poll.
    assert mock_genai_client.return_value.tunings.get.call_count == 2
    assert mock_sleep.call_count == 2
    mock_sleep.assert_called_with(60)
    assert response == "test-tuning-job"
Comment on lines
+312
to
+329
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test doesn't fully cover the functionality of To properly test the polling logic, you should:
Here is a more complete test case that replaces the current one: from unittest.mock import patch, MagicMock
@patch("continuous_tuning_create.time.sleep")
@patch("google.genai.Client")
def test_continuous_tuning_create(mock_genai_client: MagicMock, mock_sleep: MagicMock) -> None:
# Mock the API response for the initial tune call
mock_initial_job = types.TuningJob(
name="test-tuning-job",
state="JOB_STATE_PENDING",
)
mock_genai_client.return_value.tunings.tune.return_value = mock_initial_job
# Mock the responses for the polling `get` call
mock_running_job = types.TuningJob(
name="test-tuning-job",
state="JOB_STATE_RUNNING",
)
mock_succeeded_job = types.TuningJob(
name="test-tuning-job",
state="JOB_STATE_SUCCEEDED",
experiment="test-experiment",
tuned_model=types.TunedModel(
model="test-model-2",
endpoint="test-endpoint",
checkpoints=[]
)
)
mock_genai_client.return_value.tunings.get.side_effect = [
mock_running_job,
mock_succeeded_job
]
response = continuous_tuning_create.create_continuous_tuning_job(tuned_model_name="test-model", checkpoint_id="1")
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1beta1"))
mock_genai_client.return_value.tunings.tune.assert_called_once()
# Assert that polling happened
assert mock_genai_client.return_value.tunings.get.call_count == 2
mock_sleep.assert_called_with(60)
assert mock_sleep.call_count == 2
assert response == "test-tuning-job" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The parameter name `tuned_model_name` is confusing because it is used as the `base_model` argument on line 37. Renaming it to `base_model_name` would make the function's purpose clearer and improve readability. You would also need to update its usage on line 37.