From 1f81cf200c9394b50a43c3830ab8343ead1dc0d3 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 15 Nov 2023 15:53:51 -0800 Subject: [PATCH] feat: LLM - include error code into blocked response from TextGenerationModel, ChatModel, CodeChatModel, and CodeGenerationModel. PiperOrigin-RevId: 582832899 --- tests/unit/aiplatform/test_language_models.py | 15 ++++++++++++++- vertexai/language_models/_language_models.py | 15 +++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index ca3c033e9a..7d6134e67d 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -221,7 +221,8 @@ _TEST_TEXT_GENERATION_PREDICTION = { "safetyAttributes": { "categories": ["Violent"], - "blocked": False, + "blocked": True, + "errors": [100], "scores": [0.10000000149011612], }, "content": """ @@ -254,6 +255,7 @@ }, "safetyAttributes": { "blocked": True, + "errors": [100], "categories": ["Finance"], "scores": [0.1], }, @@ -301,6 +303,7 @@ "scores": [0.1], "categories": ["Finance"], "blocked": True, + "errors": [100], }, ], "candidates": [ @@ -326,6 +329,7 @@ "scores": [0.1], "categories": ["Finance"], "blocked": True, + "errors": [100], }, ], "groundingMetadata": [ @@ -373,6 +377,7 @@ "scores": [0.1], "categories": ["Finance"], "blocked": True, + "errors": [100], }, ], "groundingMetadata": [ @@ -430,6 +435,7 @@ "safetyAttributes": [ { "blocked": True, + "errors": [100], "categories": ["Finance"], "scores": [0.1], } @@ -440,6 +446,7 @@ _TEST_CODE_GENERATION_PREDICTION = { "safetyAttributes": { "blocked": True, + "errors": [100], "categories": ["Finance"], "scores": [0.1], }, @@ -1478,6 +1485,7 @@ def test_text_generation_ga(self): stop_sequences=["\n"], ) + expected_errors = (100,) prediction_parameters = mock_predict.call_args[1]["parameters"] assert prediction_parameters["maxDecodeSteps"] == 128 assert 
prediction_parameters["temperature"] == 0.0 @@ -1485,6 +1493,7 @@ def test_text_generation_ga(self): assert prediction_parameters["topK"] == 5 assert prediction_parameters["stopSequences"] == ["\n"] assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"] + assert response.errors == expected_errors # Validating that unspecified parameters are not passed to the model # (except `max_output_tokens`). @@ -2893,12 +2902,16 @@ def test_chat_model_send_message_with_multiple_candidates(self): ) expected_candidate_0 = expected_response_candidates[0]["content"] expected_candidate_1 = expected_response_candidates[1]["content"] + expected_errors_0 = () + expected_errors_1 = (100,) response = chat.send_message(message_text1, candidate_count=2) assert response.text == expected_candidate_0 assert len(response.candidates) == 2 assert response.candidates[0].text == expected_candidate_0 assert response.candidates[1].text == expected_candidate_1 + assert response.candidates[0].errors == expected_errors_0 + assert response.candidates[1].errors == expected_errors_1 assert len(chat.message_history) == 2 assert chat.message_history[0].author == chat.USER_AUTHOR diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index cb975fce8b..f5ed597aba 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -25,6 +25,7 @@ Literal, Optional, Sequence, + Tuple, Union, ) import warnings @@ -859,6 +860,9 @@ class TextGenerationResponse: Attributes: text: The generated text is_blocked: Whether the the request was blocked. + errors: The error codes indicate why the response was blocked. + Learn more about the safety errors here: + https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_errors safety_attributes: Scores for safety attributes. 
Learn more about the safety attributes here: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions @@ -870,6 +874,7 @@ class TextGenerationResponse: text: str _prediction_response: Any is_blocked: bool = False + errors: Tuple[int] = tuple() safety_attributes: Dict[str, float] = dataclasses.field(default_factory=dict) grounding_metadata: Optional[GroundingMetadata] = None @@ -882,6 +887,7 @@ def __repr__(self): "TextGenerationResponse(" f"text={self.text!r}" f", is_blocked={self.is_blocked!r}" + f", errors={self.errors!r}" f", safety_attributes={self.safety_attributes!r}" f", grounding_metadata={self.grounding_metadata!r}" ")" @@ -891,6 +897,7 @@ def __repr__(self): "TextGenerationResponse(" f"text={self.text!r}" f", is_blocked={self.is_blocked!r}" + f", errors={self.errors!r}" f", safety_attributes={self.safety_attributes!r}" ")" ) @@ -1216,10 +1223,13 @@ def _parse_text_generation_model_response( prediction = prediction_response.predictions[prediction_idx] safety_attributes_dict = prediction.get("safetyAttributes", {}) grounding_metadata_dict = prediction.get("groundingMetadata", {}) + errors_list = safety_attributes_dict.get("errors", []) + errors = tuple(map(int, errors_list)) return TextGenerationResponse( text=prediction["content"], _prediction_response=prediction_response, is_blocked=safety_attributes_dict.get("blocked", False), + errors=errors, safety_attributes=dict( zip( safety_attributes_dict.get("categories") or [], @@ -1251,6 +1261,7 @@ def _parse_text_generation_model_multi_candidate_response( text=candidates[0].text, _prediction_response=prediction_response, is_blocked=candidates[0].is_blocked, + errors=candidates[0].errors, safety_attributes=candidates[0].safety_attributes, grounding_metadata=candidates[0].grounding_metadata, candidates=candidates, @@ -2090,6 +2101,8 @@ def _parse_chat_prediction_response( grounding_metadata_list = prediction.get("groundingMetadata") for candidate_idx in 
range(candidate_count): safety_attributes = prediction["safetyAttributes"][candidate_idx] + errors_list = safety_attributes.get("errors", []) + errors = tuple(map(int, errors_list)) grounding_metadata_dict = {} if grounding_metadata_list and grounding_metadata_list[candidate_idx]: grounding_metadata_dict = grounding_metadata_list[candidate_idx] @@ -2097,6 +2110,7 @@ def _parse_chat_prediction_response( text=prediction["candidates"][candidate_idx]["content"], _prediction_response=prediction_response, is_blocked=safety_attributes.get("blocked", False), + errors=errors, safety_attributes=dict( zip( # Unlike with normal prediction, in streaming prediction @@ -2112,6 +2126,7 @@ def _parse_chat_prediction_response( text=candidates[0].text, _prediction_response=prediction_response, is_blocked=candidates[0].is_blocked, + errors=candidates[0].errors, safety_attributes=candidates[0].safety_attributes, grounding_metadata=candidates[0].grounding_metadata, candidates=candidates,