|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +from unittest import mock |
| 16 | + |
| 17 | +import pandas as pd |
15 | 18 | import pytest
|
16 | 19 |
|
17 | 20 | from bigframes import exceptions
|
18 |
| -from bigframes.ml import llm |
| 21 | +from bigframes.ml import core, llm |
19 | 22 | import bigframes.pandas as bpd
|
20 | 23 | from tests.system import utils
|
21 | 24 |
|
@@ -372,6 +375,222 @@ def test_gemini_text_generator_multi_cols_predict_success(
|
372 | 375 | )
|
373 | 376 |
|
374 | 377 |
|
# Overrides __eq__ function for comparing as mock.call parameter
class EqCmpAllDataFrame(bpd.DataFrame):
    """DataFrame subclass whose ``==`` compares the whole frame at once.

    ``mock.call(...)`` equality compares recorded call arguments with ``==``.
    A DataFrame's normal ``==`` is presumably element-wise (pandas-style) and
    does not yield a single bool, so delegate to ``equals`` to make instances
    usable as expected arguments in ``assert_has_calls``.

    NOTE(review): defining ``__eq__`` implicitly sets ``__hash__`` to None,
    so these instances are unhashable — acceptable here since the tests
    never hash them.
    """

    def __eq__(self, other):
        # Whole-frame structural equality, suitable for mock.call matching.
        return self.equals(other)
| 382 | + |
| 383 | + |
def test_gemini_text_generator_retry_success(session, bq_connection):
    """predict() retries only the failed rows and merges all successes.

    The mocked BQML model fails rows 1 and 2 on the first call, fails row 1
    on the second, and succeeds on the third. With ``max_retries=3``, predict
    must issue exactly three ``generate_text`` calls — the unused third retry
    never fires — and the final frame reports success for every row.
    """
    # Requests. Each retry request carries only the still-failing rows.
    df0 = EqCmpAllDataFrame(
        {
            "prompt": [
                "What is BigQuery?",
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ]
        },
        index=[0, 1, 2],
        session=session,
    )
    df1 = EqCmpAllDataFrame(
        {
            "ml_generate_text_status": ["error", "error"],
            "prompt": [
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ],
        },
        index=[1, 2],
        session=session,
    )
    df2 = EqCmpAllDataFrame(
        {
            "ml_generate_text_status": ["error"],
            "prompt": [
                "What is BQML?",
            ],
        },
        index=[1],
        session=session,
    )

    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)

    # Responses. Retry twice then all succeeded.
    mock_bqml_model.generate_text.side_effect = [
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["", "error", "error"],
                "prompt": [
                    "What is BigQuery?",
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[0, 1, 2],
            session=session,
        ),
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["error", ""],
                "prompt": [
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[1, 2],
            session=session,
        ),
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": [""],
                "prompt": [
                    "What is BQML?",
                ],
            },
            index=[1],
            session=session,
        ),
    ]
    options = {
        "temperature": 0.9,
        "max_output_tokens": 8192,
        "top_k": 40,
        "top_p": 1.0,
        "flatten_json_output": True,
        "ground_with_google_search": False,
    }

    gemini_text_generator_model = llm.GeminiTextGenerator(
        connection_name=bq_connection, session=session
    )
    gemini_text_generator_model._bqml_model = mock_bqml_model

    # 3rd retry isn't triggered
    result = gemini_text_generator_model.predict(df0, max_retries=3)

    mock_bqml_model.generate_text.assert_has_calls(
        [
            mock.call(df0, options),
            mock.call(df1, options),
            mock.call(df2, options),
        ]
    )
    # assert_has_calls passes even if EXTRA calls were made; pin the exact
    # count to prove the unused third retry really never fires.
    assert mock_bqml_model.generate_text.call_count == 3
    pd.testing.assert_frame_equal(
        result.to_pandas(),
        pd.DataFrame(
            {
                "ml_generate_text_status": ["", "", ""],
                "prompt": [
                    "What is BigQuery?",
                    "What is BigQuery DataFrame?",
                    "What is BQML?",
                ],
            },
            # Successes are concatenated in the order produced:
            # row 0 (attempt 1), row 2 (attempt 2), row 1 (attempt 3).
            index=[0, 2, 1],
        ),
        check_dtype=False,
        check_index_type=False,
    )
| 498 | + |
| 499 | + |
def test_gemini_text_generator_retry_no_progress(session, bq_connection):
    """predict() stops retrying when a retry makes no progress.

    The mocked BQML model fails rows 1 and 2 on both the initial call and the
    first retry (identical failure set = no progress). Even with
    ``max_retries=3``, predict must stop after two ``generate_text`` calls and
    return the frame with the remaining error rows intact.
    """
    # Requests. The single retry request carries only the failing rows.
    df0 = EqCmpAllDataFrame(
        {
            "prompt": [
                "What is BigQuery?",
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ]
        },
        index=[0, 1, 2],
        session=session,
    )
    df1 = EqCmpAllDataFrame(
        {
            "ml_generate_text_status": ["error", "error"],
            "prompt": [
                "What is BQML?",
                "What is BigQuery DataFrame?",
            ],
        },
        index=[1, 2],
        session=session,
    )

    mock_bqml_model = mock.create_autospec(spec=core.BqmlModel)
    type(mock_bqml_model).session = mock.PropertyMock(return_value=session)
    # Responses. Retry once, no progress, just stop.
    mock_bqml_model.generate_text.side_effect = [
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["", "error", "error"],
                "prompt": [
                    "What is BigQuery?",
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[0, 1, 2],
            session=session,
        ),
        EqCmpAllDataFrame(
            {
                "ml_generate_text_status": ["error", "error"],
                "prompt": [
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[1, 2],
            session=session,
        ),
    ]
    options = {
        "temperature": 0.9,
        "max_output_tokens": 8192,
        "top_k": 40,
        "top_p": 1.0,
        "flatten_json_output": True,
        "ground_with_google_search": False,
    }

    gemini_text_generator_model = llm.GeminiTextGenerator(
        connection_name=bq_connection, session=session
    )
    gemini_text_generator_model._bqml_model = mock_bqml_model

    # No progress, only conduct retry once
    result = gemini_text_generator_model.predict(df0, max_retries=3)

    mock_bqml_model.generate_text.assert_has_calls(
        [
            mock.call(df0, options),
            mock.call(df1, options),
        ]
    )
    # Until now this was only enforced implicitly (a third call would exhaust
    # side_effect and raise StopIteration); assert the count explicitly so
    # the early-stop contract is tested, not incidental.
    assert mock_bqml_model.generate_text.call_count == 2
    pd.testing.assert_frame_equal(
        result.to_pandas(),
        pd.DataFrame(
            {
                "ml_generate_text_status": ["", "error", "error"],
                "prompt": [
                    "What is BigQuery?",
                    "What is BQML?",
                    "What is BigQuery DataFrame?",
                ],
            },
            index=[0, 1, 2],
        ),
        check_dtype=False,
        check_index_type=False,
    )
| 592 | + |
| 593 | + |
375 | 594 | @pytest.mark.flaky(retries=2)
|
376 | 595 | def test_llm_palm_score(llm_fine_tune_df_default_index):
|
377 | 596 | model = llm.PaLM2TextGenerator(model_name="text-bison")
|
|
0 commit comments