Skip to content

Commit 4fb54df

Browse files
authored
feat: add gemini-2.0-flash-001 and gemini-2.0-flash-lite-001 to fine tune score endpoints and multimodal endpoints (#1650)
* add test, no code support yet * add gemini-2.0-xx to fine tune score endpoints and multimodal endpoints * wait for bqml to support gemini-2.0-flash-lite-001 * remove unsupported GA feature * remove unsupported fine-tune endpoints * fix a failed test * remove features are not ready * fix failed test * revert a typo
1 parent c958dbe commit 4fb54df

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

bigframes/ml/llm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,16 @@
7272
_GEMINI_FINE_TUNE_SCORE_ENDPOINTS = (
7373
_GEMINI_1P5_PRO_002_ENDPOINT,
7474
_GEMINI_1P5_FLASH_002_ENDPOINT,
75+
_GEMINI_2_FLASH_001_ENDPOINT,
76+
_GEMINI_2_FLASH_LITE_001_ENDPOINT,
7577
)
7678
_GEMINI_MULTIMODAL_ENDPOINTS = (
7779
_GEMINI_1P5_PRO_001_ENDPOINT,
7880
_GEMINI_1P5_PRO_002_ENDPOINT,
7981
_GEMINI_1P5_FLASH_001_ENDPOINT,
8082
_GEMINI_1P5_FLASH_002_ENDPOINT,
8183
_GEMINI_2_FLASH_EXP_ENDPOINT,
84+
_GEMINI_2_FLASH_001_ENDPOINT,
8285
)
8386

8487
_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"
@@ -712,7 +715,8 @@ def score(
712715
] = "text_generation",
713716
) -> bigframes.dataframe.DataFrame:
714717
"""Calculate evaluation metrics of the model. Only support
715-
"gemini-1.5-pro-002", and "gemini-1.5-flash-002".
718+
"gemini-1.5-pro-002", "gemini-1.5-flash-002",
719+
"gemini-2.0-flash-lite-001", and "gemini-2.0-flash-001".
716720
717721
.. note::
718722
@@ -746,7 +750,7 @@ def score(
746750

747751
if self.model_name not in _GEMINI_FINE_TUNE_SCORE_ENDPOINTS:
748752
raise NotImplementedError(
749-
"score() only supports gemini-1.5-pro-002, and gemini-1.5-flash-2 model."
753+
"score() only supports gemini-1.5-pro-002, gemini-1.5-flash-002, gemini-2.0-flash-001, and gemini-2.0-flash-lite-001 model."
750754
)
751755

752756
X, y = utils.batch_convert_to_dataframe(X, y, session=self._bqml_model.session)

tests/system/small/ml/test_llm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def test_create_load_gemini_text_generator_model(
152152
"gemini-1.5-flash-001",
153153
"gemini-1.5-flash-002",
154154
"gemini-2.0-flash-exp",
155+
"gemini-2.0-flash-001",
156+
"gemini-2.0-flash-lite-001",
155157
),
156158
)
157159
@pytest.mark.flaky(retries=2)
@@ -177,6 +179,8 @@ def test_gemini_text_generator_predict_default_params_success(
177179
"gemini-1.5-flash-001",
178180
"gemini-1.5-flash-002",
179181
"gemini-2.0-flash-exp",
182+
"gemini-2.0-flash-001",
183+
"gemini-2.0-flash-lite-001",
180184
),
181185
)
182186
@pytest.mark.flaky(retries=2)
@@ -204,6 +208,8 @@ def test_gemini_text_generator_predict_with_params_success(
204208
"gemini-1.5-flash-001",
205209
"gemini-1.5-flash-002",
206210
"gemini-2.0-flash-exp",
211+
"gemini-2.0-flash-001",
212+
"gemini-2.0-flash-lite-001",
207213
),
208214
)
209215
@pytest.mark.flaky(retries=2)
@@ -764,6 +770,8 @@ def test_text_embedding_generator_retry_no_progress(session, bq_connection):
764770
(
765771
"gemini-1.5-pro-002",
766772
"gemini-1.5-flash-002",
773+
"gemini-2.0-flash-001",
774+
"gemini-2.0-flash-lite-001",
767775
),
768776
)
769777
def test_llm_gemini_score(llm_fine_tune_df_default_index, model_name):
@@ -792,6 +800,8 @@ def test_llm_gemini_score(llm_fine_tune_df_default_index, model_name):
792800
(
793801
"gemini-1.5-pro-002",
794802
"gemini-1.5-flash-002",
803+
"gemini-2.0-flash-001",
804+
"gemini-2.0-flash-lite-001",
795805
),
796806
)
797807
def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index, model_name):

tests/system/small/ml/test_multimodal_llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def test_multimodal_embedding_generator_predict_default_params_success(
4747
"gemini-1.5-flash-001",
4848
"gemini-1.5-flash-002",
4949
"gemini-2.0-flash-exp",
50+
"gemini-2.0-flash-001",
5051
),
5152
)
5253
@pytest.mark.flaky(retries=2)

0 commit comments

Comments (0)