Skip to content

Commit

Permalink
feat: LLM - include error code into blocked response from TextGenerat…
Browse files Browse the repository at this point in the history
…ionModel, ChatModel, CodeChatModel, and CodeGenerationModel.

PiperOrigin-RevId: 582832899
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 15, 2023
1 parent 469c595 commit 1f81cf2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
15 changes: 14 additions & 1 deletion tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@
_TEST_TEXT_GENERATION_PREDICTION = {
"safetyAttributes": {
"categories": ["Violent"],
"blocked": False,
"blocked": True,
"errors": [100],
"scores": [0.10000000149011612],
},
"content": """
Expand Down Expand Up @@ -254,6 +255,7 @@
},
"safetyAttributes": {
"blocked": True,
"errors": [100],
"categories": ["Finance"],
"scores": [0.1],
},
Expand Down Expand Up @@ -301,6 +303,7 @@
"scores": [0.1],
"categories": ["Finance"],
"blocked": True,
"errors": [100],
},
],
"candidates": [
Expand All @@ -326,6 +329,7 @@
"scores": [0.1],
"categories": ["Finance"],
"blocked": True,
"errors": [100],
},
],
"groundingMetadata": [
Expand Down Expand Up @@ -373,6 +377,7 @@
"scores": [0.1],
"categories": ["Finance"],
"blocked": True,
"errors": [100],
},
],
"groundingMetadata": [
Expand Down Expand Up @@ -430,6 +435,7 @@
"safetyAttributes": [
{
"blocked": True,
"errors": [100],
"categories": ["Finance"],
"scores": [0.1],
}
Expand All @@ -440,6 +446,7 @@
_TEST_CODE_GENERATION_PREDICTION = {
"safetyAttributes": {
"blocked": True,
"errors": [100],
"categories": ["Finance"],
"scores": [0.1],
},
Expand Down Expand Up @@ -1478,13 +1485,15 @@ def test_text_generation_ga(self):
stop_sequences=["\n"],
)

expected_errors = (100,)
prediction_parameters = mock_predict.call_args[1]["parameters"]
assert prediction_parameters["maxDecodeSteps"] == 128
assert prediction_parameters["temperature"] == 0.0
assert prediction_parameters["topP"] == 1.0
assert prediction_parameters["topK"] == 5
assert prediction_parameters["stopSequences"] == ["\n"]
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
assert response.errors == expected_errors

# Validating that unspecified parameters are not passed to the model
# (except `max_output_tokens`).
Expand Down Expand Up @@ -2893,12 +2902,16 @@ def test_chat_model_send_message_with_multiple_candidates(self):
)
expected_candidate_0 = expected_response_candidates[0]["content"]
expected_candidate_1 = expected_response_candidates[1]["content"]
expected_errors_0 = ()
expected_errors_1 = (100,)

response = chat.send_message(message_text1, candidate_count=2)
assert response.text == expected_candidate_0
assert len(response.candidates) == 2
assert response.candidates[0].text == expected_candidate_0
assert response.candidates[1].text == expected_candidate_1
assert response.candidates[0].errors == expected_errors_0
assert response.candidates[1].errors == expected_errors_1

assert len(chat.message_history) == 2
assert chat.message_history[0].author == chat.USER_AUTHOR
Expand Down
15 changes: 15 additions & 0 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import warnings
Expand Down Expand Up @@ -859,6 +860,9 @@ class TextGenerationResponse:
Attributes:
text: The generated text
is_blocked: Whether the request was blocked.
errors: The error codes that indicate why the response was blocked.
Learn more about safety errors here:
https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_errors
safety_attributes: Scores for safety attributes.
Learn more about the safety attributes here:
https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
Expand All @@ -870,6 +874,7 @@ class TextGenerationResponse:
text: str
_prediction_response: Any
is_blocked: bool = False
errors: Tuple[int] = tuple()
safety_attributes: Dict[str, float] = dataclasses.field(default_factory=dict)
grounding_metadata: Optional[GroundingMetadata] = None

Expand All @@ -882,6 +887,7 @@ def __repr__(self):
"TextGenerationResponse("
f"text={self.text!r}"
f", is_blocked={self.is_blocked!r}"
f", errors={self.errors!r}"
f", safety_attributes={self.safety_attributes!r}"
f", grounding_metadata={self.grounding_metadata!r}"
")"
Expand All @@ -891,6 +897,7 @@ def __repr__(self):
"TextGenerationResponse("
f"text={self.text!r}"
f", is_blocked={self.is_blocked!r}"
f", errors={self.errors!r}"
f", safety_attributes={self.safety_attributes!r}"
")"
)
Expand Down Expand Up @@ -1216,10 +1223,13 @@ def _parse_text_generation_model_response(
prediction = prediction_response.predictions[prediction_idx]
safety_attributes_dict = prediction.get("safetyAttributes", {})
grounding_metadata_dict = prediction.get("groundingMetadata", {})
errors_list = safety_attributes_dict.get("errors", [])
errors = tuple(map(int, errors_list))
return TextGenerationResponse(
text=prediction["content"],
_prediction_response=prediction_response,
is_blocked=safety_attributes_dict.get("blocked", False),
errors=errors,
safety_attributes=dict(
zip(
safety_attributes_dict.get("categories") or [],
Expand Down Expand Up @@ -1251,6 +1261,7 @@ def _parse_text_generation_model_multi_candidate_response(
text=candidates[0].text,
_prediction_response=prediction_response,
is_blocked=candidates[0].is_blocked,
errors=candidates[0].errors,
safety_attributes=candidates[0].safety_attributes,
grounding_metadata=candidates[0].grounding_metadata,
candidates=candidates,
Expand Down Expand Up @@ -2090,13 +2101,16 @@ def _parse_chat_prediction_response(
grounding_metadata_list = prediction.get("groundingMetadata")
for candidate_idx in range(candidate_count):
safety_attributes = prediction["safetyAttributes"][candidate_idx]
errors_list = safety_attributes.get("errors", [])
errors = tuple(map(int, errors_list))
grounding_metadata_dict = {}
if grounding_metadata_list and grounding_metadata_list[candidate_idx]:
grounding_metadata_dict = grounding_metadata_list[candidate_idx]
candidate_response = TextGenerationResponse(
text=prediction["candidates"][candidate_idx]["content"],
_prediction_response=prediction_response,
is_blocked=safety_attributes.get("blocked", False),
errors=errors,
safety_attributes=dict(
zip(
# Unlike with normal prediction, in streaming prediction
Expand All @@ -2112,6 +2126,7 @@ def _parse_chat_prediction_response(
text=candidates[0].text,
_prediction_response=prediction_response,
is_blocked=candidates[0].is_blocked,
errors=candidates[0].errors,
safety_attributes=candidates[0].safety_attributes,
grounding_metadata=candidates[0].grounding_metadata,
candidates=candidates,
Expand Down

0 comments on commit 1f81cf2

Please sign in to comment.