Skip to content

Commit 8193abe

Browse files
authored
feat: add client side retry to GeminiTextGenerator (#1242)
* feat: add client side retry to GeminiTextGenerator * test * test * test * test * fix * max_retries * fix * fix
1 parent 20f3190 commit 8193abe

File tree

2 files changed

+256
-4
lines changed

2 files changed

+256
-4
lines changed

bigframes/ml/llm.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import bigframes
2727
from bigframes import clients, exceptions
2828
from bigframes.core import blocks, log_adapter
29+
import bigframes.dataframe
2930
from bigframes.ml import base, core, globals, utils
3031
import bigframes.pandas as bpd
3132

@@ -945,6 +946,7 @@ def predict(
945946
top_k: int = 40,
946947
top_p: float = 1.0,
947948
ground_with_google_search: bool = False,
949+
max_retries: int = 0,
948950
) -> bpd.DataFrame:
949951
"""Predict the result from input DataFrame.
950952
@@ -983,6 +985,10 @@ def predict(
983985
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
984986
The default is `False`.
985987
988+
max_retries (int, default 0):
989+
Max number of retry rounds if any rows fail in the prediction. Each round needs to make progress (i.e., produce newly succeeded rows) to continue to the next retry round.
990+
Each round will append the newly succeeded rows to the result. When the maximum number of retry rounds is reached, the remaining failed rows are appended to the end of the result.
991+
986992
Returns:
987993
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
988994
"""
@@ -1002,6 +1008,11 @@ def predict(
10021008
if top_p < 0.0 or top_p > 1.0:
10031009
raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")
10041010

1011+
if max_retries < 0:
1012+
raise ValueError(
1013+
f"max_retries must be larger than or equal to 0, but is {max_retries}."
1014+
)
1015+
10051016
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
10061017

10071018
if len(X.columns) == 1:
@@ -1018,15 +1029,37 @@ def predict(
10181029
"ground_with_google_search": ground_with_google_search,
10191030
}
10201031

1021-
df = self._bqml_model.generate_text(X, options)
1032+
df_result = bpd.DataFrame(session=self._bqml_model.session) # placeholder
1033+
df_fail = X
1034+
for _ in range(max_retries + 1):
1035+
df = self._bqml_model.generate_text(df_fail, options)
10221036

1023-
if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
1037+
df_succ = df[df[_ML_GENERATE_TEXT_STATUS].str.len() == 0]
1038+
df_fail = df[df[_ML_GENERATE_TEXT_STATUS].str.len() > 0]
1039+
1040+
if df_succ.empty:
1041+
warnings.warn("Can't make any progress, stop retrying.", RuntimeWarning)
1042+
break
1043+
1044+
df_result = (
1045+
bpd.concat([df_result, df_succ]) if not df_result.empty else df_succ
1046+
)
1047+
1048+
if df_fail.empty:
1049+
break
1050+
1051+
if not df_fail.empty:
10241052
warnings.warn(
10251053
f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
10261054
RuntimeWarning,
10271055
)
10281056

1029-
return df
1057+
df_result = cast(
1058+
bpd.DataFrame,
1059+
bpd.concat([df_result, df_fail]) if not df_result.empty else df_fail,
1060+
)
1061+
1062+
return df_result
10301063

10311064
def score(
10321065
self,

tests/system/small/ml/test_llm.py

Lines changed: 220 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest import mock
16+
17+
import pandas as pd
1518
import pytest
1619

1720
from bigframes import exceptions
18-
from bigframes.ml import llm
21+
from bigframes.ml import core, llm
1922
import bigframes.pandas as bpd
2023
from tests.system import utils
2124

@@ -372,6 +375,222 @@ def test_gemini_text_generator_multi_cols_predict_success(
372375
)
373376

374377

378+
# Overrides __eq__ function for comparing as mock.call parameter
class EqCmpAllDataFrame(bpd.DataFrame):
    """DataFrame subclass with content-based equality.

    ``mock.assert_has_calls`` compares recorded call arguments with ``==``;
    the default DataFrame ``__eq__`` is elementwise, which would break that
    comparison. This subclass makes ``==`` return a single bool based on
    ``equals`` so instances can be used directly in expected ``mock.call``s.
    """

    def __eq__(self, other):
        # Full-content comparison (values and index) instead of the
        # default elementwise semantics.
        return self.equals(other)
383+
384+
def test_gemini_text_generator_retry_success(session, bq_connection):
    """Predict with max_retries=3 retries failed rows until all succeed.

    Mocks ``BqmlModel.generate_text`` so that each round leaves fewer
    failed rows; verifies the model is called with exactly the remaining
    failed rows each round and that the final result contains every row
    with an empty (success) status.
    """
    # Requests: df0 is the initial input; df1/df2 are the failed subsets
    # expected to be re-sent on the first and second retry rounds.
    df0 = EqCmpAllDataFrame(
        {
            "prompt": [
                "What is BigQuery?",
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ]
        },
        index=[0, 1, 2],
        session=session,
    )
    df1 = EqCmpAllDataFrame(
        {
            "ml_generate_text_status": ["error", "error"],
            "prompt": [
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ],
        },
        index=[1, 2],
        session=session,
    )
    df2 = EqCmpAllDataFrame(
        {
            "ml_generate_text_status": ["error"],
            "prompt": [
                "What is BQML?",
            ],
        },
        index=[1],
        session=session,
    )

    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
    # predict() reads the session off the BQML model, so stub it out.
    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)

    # Responses. Retry twice then all succeeded: round 0 succeeds on row 0,
    # round 1 succeeds on row 2, round 2 succeeds on the last row (row 1).
    mock_bqml_model.generate_text.side_effect = [
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["", "error", "error"],
                "prompt": [
                    "What is BigQuery?",
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[0, 1, 2],
            session=session,
        ),
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["error", ""],
                "prompt": [
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[1, 2],
            session=session,
        ),
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": [""],
                "prompt": [
                    "What is BQML?",
                ],
            },
            index=[1],
            session=session,
        ),
    ]
    # Options predict() is expected to forward on every generate_text call.
    options = {
        "temperature": 0.9,
        "max_output_tokens": 8192,
        "top_k": 40,
        "top_p": 1.0,
        "flatten_json_output": True,
        "ground_with_google_search": False,
    }

    gemini_text_generator_model = llm.GeminiTextGenerator(
        connection_name=bq_connection, session=session
    )
    gemini_text_generator_model._bqml_model = mock_bqml_model

    # 3rd retry isn't triggered: everything succeeds by the 2nd retry round.
    result = gemini_text_generator_model.predict(df0, max_retries=3)

    # Each round must be called with exactly the still-failed rows.
    mock_bqml_model.generate_text.assert_has_calls(
        [
            mock.call(df0, options),
            mock.call(df1, options),
            mock.call(df2, options),
        ]
    )
    # Rows are appended in the order they succeeded (0, then 2, then 1).
    pd.testing.assert_frame_equal(
        result.to_pandas(),
        pd.DataFrame(
            {
                "ml_generate_text_status": ["", "", ""],
                "prompt": [
                    "What is BigQuery?",
                    "What is BigQuery DataFrame?",
                    "What is BQML?",
                ],
            },
            index=[0, 2, 1],
        ),
        check_dtype=False,
        check_index_type=False,
    )
498+
499+
500+
def test_gemini_text_generator_retry_no_progress(session, bq_connection):
    """Predict stops retrying early when a round makes no progress.

    Mocks ``BqmlModel.generate_text`` so the first retry round returns the
    same failed rows; verifies predict() gives up after that round (no 2nd
    or 3rd retry despite max_retries=3) and appends the failed rows to the
    end of the result.
    """
    # Requests: df0 is the initial input; df1 is the failed subset expected
    # to be re-sent on the (single) retry round.
    df0 = EqCmpAllDataFrame(
        {
            "prompt": [
                "What is BigQuery?",
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ]
        },
        index=[0, 1, 2],
        session=session,
    )
    df1 = EqCmpAllDataFrame(
        {
            "ml_generate_text_status": ["error", "error"],
            "prompt": [
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ],
        },
        index=[1, 2],
        session=session,
    )

    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
    # predict() reads the session off the BQML model, so stub it out.
    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)
    # Responses. Retry once, no progress, just stop: the retry round returns
    # the identical failure set, so no new rows succeed.
    mock_bqml_model.generate_text.side_effect = [
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["", "error", "error"],
                "prompt": [
                    "What is BigQuery?",
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[0, 1, 2],
            session=session,
        ),
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["error", "error"],
                "prompt": [
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[1, 2],
            session=session,
        ),
    ]
    # Options predict() is expected to forward on every generate_text call.
    options = {
        "temperature": 0.9,
        "max_output_tokens": 8192,
        "top_k": 40,
        "top_p": 1.0,
        "flatten_json_output": True,
        "ground_with_google_search": False,
    }

    gemini_text_generator_model = llm.GeminiTextGenerator(
        connection_name=bq_connection, session=session
    )
    gemini_text_generator_model._bqml_model = mock_bqml_model

    # No progress, only conduct retry once even though max_retries=3.
    result = gemini_text_generator_model.predict(df0, max_retries=3)

    mock_bqml_model.generate_text.assert_has_calls(
        [
            mock.call(df0, options),
            mock.call(df1, options),
        ]
    )
    # Succeeded row first, then the still-failed rows appended at the end.
    pd.testing.assert_frame_equal(
        result.to_pandas(),
        pd.DataFrame(
            {
                "ml_generate_text_status": ["", "error", "error"],
                "prompt": [
                    "What is BigQuery?",
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[0, 1, 2],
        ),
        check_dtype=False,
        check_index_type=False,
    )
592+
593+
375594
@pytest.mark.flaky(retries=2)
376595
def test_llm_palm_score(llm_fine_tune_df_default_index):
377596
model = llm.PaLM2TextGenerator(model_name="text-bison")

0 commit comments

Comments
 (0)