From 598d57d24ea613130a74bf7db86c757a668626b8 Mon Sep 17 00:00:00 2001
From: Alexey Volkov
Date: Fri, 13 Oct 2023 19:16:35 -0700
Subject: [PATCH] feat: LLM - Added support for multiple response candidates in code chat models

PiperOrigin-RevId: 573371030
---
 tests/unit/aiplatform/test_language_models.py | 51 +++++++++++++++++++
 vertexai/language_models/_language_models.py  | 16 ++++--
 2 files changed, 63 insertions(+), 4 deletions(-)

diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index f0ca546ed1..5e231f25d5 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -2419,6 +2419,57 @@ def test_code_chat(self):
         assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
         assert prediction_parameters["stopSequences"] == message_stop_sequences
 
+    def test_code_chat_model_send_message_with_multiple_candidates(self):
+        """Tests the code chat model with multiple candidates."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODECHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+            autospec=True,
+        ):
+            model = language_models.CodeChatModel.from_pretrained(
+                "google/codechat-bison@001"
+            )
+
+        chat = model.start_chat()
+
+        gca_predict_response1 = gca_prediction_service.PredictResponse()
+        gca_predict_response1.predictions.append(
+            _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response1,
+            autospec=True,
+        ):
+            message_text1 = "Are my favorite movies based on a book series?"
+            expected_response_candidates = (
+                _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION["candidates"]
+            )
+            expected_candidate_0 = expected_response_candidates[0]["content"]
+            expected_candidate_1 = expected_response_candidates[1]["content"]
+
+            response = chat.send_message(
+                message=message_text1,
+                # candidate_count acts as a maximum number, not exact number.
+                candidate_count=7,
+            )
+            # The service can return a different number of candidates.
+            assert response.text == expected_candidate_0
+            assert len(response.candidates) == 2
+            assert response.candidates[0].text == expected_candidate_0
+            assert response.candidates[1].text == expected_candidate_1
+
+            assert len(chat.message_history) == 2
+            assert chat.message_history[0].author == chat.USER_AUTHOR
+            assert chat.message_history[0].content == message_text1
+            assert chat.message_history[1].author == chat.MODEL_AUTHOR
+            assert chat.message_history[1].content == expected_candidate_0
+
     def test_code_chat_model_send_message_streaming(self):
         """Tests the chat generation model."""
         aiplatform.init(
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 8dd422a336..618e7bbbaa 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -2112,7 +2112,8 @@ def send_message(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Sends message to the code chat model and gets a response.
 
         Args:
             temperature: Controls the randomness of predictions. Range: [0, 1].
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
         """
         return super().send_message(
             message=message,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )
 
     async def send_message_async(
@@ -2139,7 +2143,8 @@ async def send_message_async(
         *,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Asynchronously sends message to the code chat model and gets a response.
 
         Args:
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1].
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
+            candidate_count: Number of candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
         """
         return super().send_message_async(
             message=message,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
+            candidate_count=candidate_count,
         )
 
     def send_message_streaming(
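
A minimal usage sketch of the new `candidate_count` parameter, based on the diff and the unit test above. The model name and prompt are illustrative, and `candidate_count` is an upper bound: the service may return fewer candidates than requested.

```python
from vertexai.language_models import CodeChatModel

# Illustrative model name; any code chat model that accepts candidate_count works the same way.
chat_model = CodeChatModel.from_pretrained("codechat-bison@001")
chat = chat_model.start_chat()

response = chat.send_message(
    "Write a function that checks if a year is a leap year.",
    # Upper bound on the number of candidates; the service may return fewer.
    candidate_count=3,
)

# send_message now returns a MultiCandidateTextGenerationResponse:
# response.text is the first candidate, response.candidates holds all of them.
print(response.text)
for candidate in response.candidates:
    print(candidate.text)
```

As the test asserts, only the first candidate is appended to `chat.message_history`, so the conversation context stays single-threaded even when several candidates are returned.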