Skip to content

Commit 72dc46b

Browse files
speedstorm1 and copybara-github
authored and committed
feat: enable continuous fine-tuning on a pre-tuned model in the SDK.
PiperOrigin-RevId: 793891314
1 parent 5f6746d commit 72dc46b

File tree

5 files changed

+182
-17
lines changed

5 files changed

+182
-17
lines changed

google/genai/tests/tunings/test_end_to_end.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
1-
"""Tests for create_sft_job."""
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
215

16+
17+
"""Tests for tunings.tune()."""
18+
19+
import time
20+
from ... import _replay_api_client
321
from ... import types as genai_types
422
from .. import pytest_helper
5-
from ... import _replay_api_client
623

724

825
test_table: list[pytest_helper.TestTableItem] = []
@@ -18,14 +35,15 @@
1835

1936

2037
def test_tune_until_success(client):
21-
import time
22-
2338
if client._api_client.vertexai:
2439
job = client.tunings.tune(
2540
base_model="gemini-2.0-flash-001",
2641
training_dataset=genai_types.TuningDataset(
2742
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl",
2843
),
44+
config=genai_types.CreateTuningJobConfig(
45+
epoch_count=1,
46+
),
2947
)
3048
else:
3149
# Remove GenAI SDK test since it is deprecated:
@@ -35,8 +53,54 @@ def test_tune_until_success(client):
3553
while not job.has_ended:
3654
# Skipping the sleep for when in replay mode.
3755
if client._api_client._mode not in ("replay", "auto"):
38-
time.sleep(60)
56+
time.sleep(300)
3957
job = client.tunings.get(name=job.name)
4058

4159
assert job.has_ended
4260
assert job.has_succeeded
61+
62+
63+
def test_continuous_tuning(client):
64+
if not client._api_client.vertexai:
65+
return
66+
67+
job = client.tunings.tune(
68+
base_model="gemini-2.5-flash",
69+
training_dataset=genai_types.TuningDataset(
70+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl",
71+
),
72+
config=genai_types.CreateTuningJobConfig(
73+
epoch_count=1,
74+
),
75+
)
76+
77+
while not job.has_ended:
78+
# Skipping the sleep for when in replay mode.
79+
if client._api_client._mode not in ("replay", "auto"):
80+
time.sleep(300)
81+
job = client.tunings.get(name=job.name)
82+
83+
assert job.has_ended
84+
assert job.has_succeeded
85+
86+
continuous_job = client.tunings.tune(
87+
base_model=job.tuned_model.model,
88+
training_dataset=genai_types.TuningDataset(
89+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl",
90+
),
91+
config=genai_types.CreateTuningJobConfig(
92+
tuned_model_display_name="continuous tuning job",
93+
epoch_count=1,
94+
)
95+
)
96+
97+
while not continuous_job.has_ended:
98+
# Skipping the sleep for when in replay mode.
99+
if client._api_client._mode not in ("replay", "auto"):
100+
time.sleep(300)
101+
continuous_job = client.tunings.get(name=continuous_job.name)
102+
103+
assert continuous_job.has_ended
104+
assert continuous_job.has_succeeded
105+
106+

google/genai/tests/tunings/test_get.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#
1515

1616

17+
"""Tests for tunings.get()."""
18+
1719
from ... import types as genai_types
1820
from .. import pytest_helper
1921

@@ -25,6 +27,13 @@
2527
),
2628
exception_if_mldev="Not Found",
2729
),
30+
pytest_helper.TestTableItem(
31+
name="test_vertexai_with_pretuned_model",
32+
parameters=genai_types._GetTuningJobParameters(
33+
name="projects/801452371447/locations/us-central1/tuningJobs/8520221517529743360",
34+
),
35+
exception_if_mldev="Not Found",
36+
),
2837
pytest_helper.TestTableItem(
2938
name="test_mldev",
3039
parameters=genai_types._GetTuningJobParameters(

google/genai/tests/tunings/test_tune.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
#
15-
# %%
1615

1716

18-
"""Tests for create_sft_job."""
17+
"""Tests for tunings.tune()."""
1918

2019
from ... import types as genai_types
2120
from .. import pytest_helper
@@ -32,6 +31,42 @@
3231
),
3332
exception_if_mldev="gcs_uri parameter is not supported in Gemini API.",
3433
),
34+
pytest_helper.TestTableItem(
35+
name="test_tune_pretuned_model",
36+
parameters=genai_types.CreateTuningJobParameters(
37+
base_model="projects/801452371447/locations/us-central1/models/9030969596621881344",
38+
training_dataset=genai_types.TuningDataset(
39+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl",
40+
),
41+
),
42+
exception_if_mldev="is not supported in Gemini API",
43+
),
44+
pytest_helper.TestTableItem(
45+
name="test_tune_pretuned_model_with_checkpoint_id",
46+
parameters=genai_types.CreateTuningJobParameters(
47+
base_model="projects/801452371447/locations/us-central1/models/9030969596621881344",
48+
training_dataset=genai_types.TuningDataset(
49+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_train_data.jsonl",
50+
),
51+
config=genai_types.CreateTuningJobConfig(
52+
pre_tuned_model_checkpoint_id="3",
53+
),
54+
),
55+
exception_if_mldev="is not supported in Gemini API",
56+
),
57+
pytest_helper.TestTableItem(
58+
name="test_non_pretuned_model_with_checkpoint_id",
59+
parameters=genai_types.CreateTuningJobParameters(
60+
base_model="gemini-2.5-flash",
61+
training_dataset=genai_types.TuningDataset(
62+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
63+
),
64+
config=genai_types.CreateTuningJobConfig(
65+
pre_tuned_model_checkpoint_id="3",
66+
),
67+
),
68+
exception_if_mldev="is not supported in Gemini API.",
69+
),
3570
pytest_helper.TestTableItem(
3671
name="test_dataset_gcs_uri_all_parameters",
3772
parameters=genai_types.CreateTuningJobParameters(

google/genai/tunings.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ def _CreateTuningJobConfig_to_mldev(
166166
'export_last_checkpoint_only parameter is not supported in Gemini API.'
167167
)
168168

169+
if getv(from_object, ['pre_tuned_model_checkpoint_id']) is not None:
170+
raise ValueError(
171+
'pre_tuned_model_checkpoint_id parameter is not supported in Gemini'
172+
' API.'
173+
)
174+
169175
if getv(from_object, ['adapter_size']) is not None:
170176
raise ValueError('adapter_size parameter is not supported in Gemini API.')
171177

@@ -194,6 +200,9 @@ def _CreateTuningJobParametersPrivate_to_mldev(
194200
if getv(from_object, ['base_model']) is not None:
195201
setv(to_object, ['baseModel'], getv(from_object, ['base_model']))
196202

203+
if getv(from_object, ['pre_tuned_model']) is not None:
204+
setv(to_object, ['preTunedModel'], getv(from_object, ['pre_tuned_model']))
205+
197206
if getv(from_object, ['training_dataset']) is not None:
198207
setv(
199208
to_object,
@@ -359,6 +368,13 @@ def _CreateTuningJobConfig_to_vertex(
359368
getv(from_object, ['export_last_checkpoint_only']),
360369
)
361370

371+
if getv(from_object, ['pre_tuned_model_checkpoint_id']) is not None:
372+
setv(
373+
to_object,
374+
['preTunedModel', 'checkpointId'],
375+
getv(from_object, ['pre_tuned_model_checkpoint_id']),
376+
)
377+
362378
if getv(from_object, ['adapter_size']) is not None:
363379
setv(
364380
parent_object,
@@ -383,6 +399,9 @@ def _CreateTuningJobParametersPrivate_to_vertex(
383399
if getv(from_object, ['base_model']) is not None:
384400
setv(to_object, ['baseModel'], getv(from_object, ['base_model']))
385401

402+
if getv(from_object, ['pre_tuned_model']) is not None:
403+
setv(to_object, ['preTunedModel'], getv(from_object, ['pre_tuned_model']))
404+
386405
if getv(from_object, ['training_dataset']) is not None:
387406
setv(
388407
to_object,
@@ -915,6 +934,7 @@ def _tune(
915934
self,
916935
*,
917936
base_model: Optional[str] = None,
937+
pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
918938
training_dataset: types.TuningDatasetOrDict,
919939
config: Optional[types.CreateTuningJobConfigOrDict] = None,
920940
) -> types.TuningJob:
@@ -931,6 +951,7 @@ def _tune(
931951

932952
parameter_model = types._CreateTuningJobParametersPrivate(
933953
base_model=base_model,
954+
pre_tuned_model=pre_tuned_model,
934955
training_dataset=training_dataset,
935956
config=config,
936957
)
@@ -986,6 +1007,7 @@ def _tune_mldev(
9861007
self,
9871008
*,
9881009
base_model: Optional[str] = None,
1010+
pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
9891011
training_dataset: types.TuningDatasetOrDict,
9901012
config: Optional[types.CreateTuningJobConfigOrDict] = None,
9911013
) -> types.TuningOperation:
@@ -1002,6 +1024,7 @@ def _tune_mldev(
10021024

10031025
parameter_model = types._CreateTuningJobParametersPrivate(
10041026
base_model=base_model,
1027+
pre_tuned_model=pre_tuned_model,
10051028
training_dataset=training_dataset,
10061029
config=config,
10071030
)
@@ -1093,11 +1116,19 @@ def tune(
10931116
config: Optional[types.CreateTuningJobConfigOrDict] = None,
10941117
) -> types.TuningJob:
10951118
if self._api_client.vertexai:
1096-
tuning_job = self._tune(
1097-
base_model=base_model,
1098-
training_dataset=training_dataset,
1099-
config=config,
1100-
)
1119+
if base_model.startswith('projects/'): # Pre-tuned model
1120+
pre_tuned_model = types.PreTunedModel(tuned_model_name=base_model)
1121+
tuning_job = self._tune(
1122+
pre_tuned_model=pre_tuned_model,
1123+
training_dataset=training_dataset,
1124+
config=config,
1125+
)
1126+
else:
1127+
tuning_job = self._tune(
1128+
base_model=base_model,
1129+
training_dataset=training_dataset,
1130+
config=config,
1131+
)
11011132
else:
11021133
operation = self._tune_mldev(
11031134
base_model=base_model,
@@ -1269,6 +1300,7 @@ async def _tune(
12691300
self,
12701301
*,
12711302
base_model: Optional[str] = None,
1303+
pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
12721304
training_dataset: types.TuningDatasetOrDict,
12731305
config: Optional[types.CreateTuningJobConfigOrDict] = None,
12741306
) -> types.TuningJob:
@@ -1285,6 +1317,7 @@ async def _tune(
12851317

12861318
parameter_model = types._CreateTuningJobParametersPrivate(
12871319
base_model=base_model,
1320+
pre_tuned_model=pre_tuned_model,
12881321
training_dataset=training_dataset,
12891322
config=config,
12901323
)
@@ -1340,6 +1373,7 @@ async def _tune_mldev(
13401373
self,
13411374
*,
13421375
base_model: Optional[str] = None,
1376+
pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
13431377
training_dataset: types.TuningDatasetOrDict,
13441378
config: Optional[types.CreateTuningJobConfigOrDict] = None,
13451379
) -> types.TuningOperation:
@@ -1356,6 +1390,7 @@ async def _tune_mldev(
13561390

13571391
parameter_model = types._CreateTuningJobParametersPrivate(
13581392
base_model=base_model,
1393+
pre_tuned_model=pre_tuned_model,
13591394
training_dataset=training_dataset,
13601395
config=config,
13611396
)
@@ -1447,11 +1482,20 @@ async def tune(
14471482
config: Optional[types.CreateTuningJobConfigOrDict] = None,
14481483
) -> types.TuningJob:
14491484
if self._api_client.vertexai:
1450-
tuning_job = await self._tune(
1451-
base_model=base_model,
1452-
training_dataset=training_dataset,
1453-
config=config,
1454-
)
1485+
if base_model.startswith('projects/'): # Pre-tuned model
1486+
pre_tuned_model = types.PreTunedModel(tuned_model_name=base_model)
1487+
1488+
tuning_job = await self._tune(
1489+
pre_tuned_model=pre_tuned_model,
1490+
training_dataset=training_dataset,
1491+
config=config,
1492+
)
1493+
else:
1494+
tuning_job = await self._tune(
1495+
base_model=base_model,
1496+
training_dataset=training_dataset,
1497+
config=config,
1498+
)
14551499
else:
14561500
operation = await self._tune_mldev(
14571501
base_model=base_model,

google/genai/types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10057,6 +10057,10 @@ class CreateTuningJobConfig(_common.BaseModel):
1005710057
default=None,
1005810058
description="""If set to true, disable intermediate checkpoints for SFT and only the last checkpoint will be exported. Otherwise, enable intermediate checkpoints for SFT.""",
1005910059
)
10060+
pre_tuned_model_checkpoint_id: Optional[str] = Field(
10061+
default=None,
10062+
description="""The optional checkpoint id of the pre-tuned model to use for tuning, if applicable.""",
10063+
)
1006010064
adapter_size: Optional[AdapterSize] = Field(
1006110065
default=None, description="""Adapter size for tuning."""
1006210066
)
@@ -10094,6 +10098,9 @@ class CreateTuningJobConfigDict(TypedDict, total=False):
1009410098
export_last_checkpoint_only: Optional[bool]
1009510099
"""If set to true, disable intermediate checkpoints for SFT and only the last checkpoint will be exported. Otherwise, enable intermediate checkpoints for SFT."""
1009610100

10101+
pre_tuned_model_checkpoint_id: Optional[str]
10102+
"""The optional checkpoint id of the pre-tuned model to use for tuning, if applicable."""
10103+
1009710104
adapter_size: Optional[AdapterSize]
1009810105
"""Adapter size for tuning."""
1009910106

@@ -10116,6 +10123,9 @@ class _CreateTuningJobParametersPrivate(_common.BaseModel):
1011610123
default=None,
1011710124
description="""The base model that is being tuned, e.g., "gemini-2.5-flash".""",
1011810125
)
10126+
pre_tuned_model: Optional[PreTunedModel] = Field(
10127+
default=None, description="""The PreTunedModel that is being tuned."""
10128+
)
1011910129
training_dataset: Optional[TuningDataset] = Field(
1012010130
default=None,
1012110131
description="""Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""",
@@ -10131,6 +10141,9 @@ class _CreateTuningJobParametersPrivateDict(TypedDict, total=False):
1013110141
base_model: Optional[str]
1013210142
"""The base model that is being tuned, e.g., "gemini-2.5-flash"."""
1013310143

10144+
pre_tuned_model: Optional[PreTunedModelDict]
10145+
"""The PreTunedModel that is being tuned."""
10146+
1013410147
training_dataset: Optional[TuningDatasetDict]
1013510148
"""Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file."""
1013610149

0 commit comments

Comments
 (0)