Skip to content

Commit

Permalink
chore: Add headers for CPR model server errors. (#1701)
Browse files Browse the repository at this point in the history
* chore: Add headers for CPR model server errors.

* chore: Fix comments.
  • Loading branch information
abcdefgs0324 authored Sep 28, 2022
1 parent 926d0b6 commit 05efff7
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 15 deletions.
1 change: 1 addition & 0 deletions google/cloud/aiplatform/constants/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@
DEFAULT_LOCAL_RUN_GPU_COUNT = -1

CUSTOM_PREDICTION_ROUTINES = "custom-prediction-routines"
CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY = "X-AIP-CPR-SYSTEM-ERROR"

# Header-related constants used by the handler.
CONTENT_TYPE_HEADER_REGEX = re.compile("^[Cc]ontent-?[Tt]ype$")
Expand Down
27 changes: 24 additions & 3 deletions google/cloud/aiplatform/prediction/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
#

from abc import ABC, abstractmethod
import logging
from typing import Optional, Type
import traceback

try:
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
except ImportError:
Expand Down Expand Up @@ -103,14 +106,32 @@ async def handle(self, request: Request) -> Response:
Returns:
The response of the prediction request.
Raises:
HTTPException: If any exception is thrown from predictor object.
"""
request_body = await request.body()
content_type = handler_utils.get_content_type_from_headers(request.headers)
prediction_input = DefaultSerializer.deserialize(request_body, content_type)

prediction_results = self._predictor.postprocess(
self._predictor.predict(self._predictor.preprocess(prediction_input))
)
try:
prediction_results = self._predictor.postprocess(
self._predictor.predict(self._predictor.preprocess(prediction_input))
)
except HTTPException:
raise
except Exception as exception:
error_message = (
"The following exception has occurred: {}. Arguments: {}.".format(
type(exception).__name__, exception.args
)
)
logging.info(
"{}\\nTraceback: {}".format(error_message, traceback.format_exc())
)

# Converts all other exceptions to HTTPException.
raise HTTPException(status_code=500, detail=error_message)

accept = handler_utils.get_accept_from_headers(request.headers)
data = DefaultSerializer.serialize(prediction_results, accept)
Expand Down
14 changes: 14 additions & 0 deletions google/cloud/aiplatform/prediction/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
)

from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform import version


class CprModelServer:
"""Model server to do custom prediction routines."""
Expand All @@ -61,6 +64,9 @@ def __init__(self):
)
handler_module = importlib.import_module(os.environ.get("HANDLER_MODULE"))
handler_class = getattr(handler_module, os.environ.get("HANDLER_CLASS"))
self.is_default_handler = (
handler_module == "google.cloud.aiplatform.prediction.handler"
)

predictor_class = None
if "PREDICTOR_MODULE" in os.environ:
Expand Down Expand Up @@ -145,6 +151,14 @@ async def predict(self, request: Request) -> Response:
)

# Converts all other exceptions to HTTPException.
if self.is_default_handler:
raise HTTPException(
status_code=500,
detail=error_message,
headers={
prediction.CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY: version.__version__
},
)
raise HTTPException(status_code=500, detail=error_message)


Expand Down
91 changes: 79 additions & 12 deletions tests/unit/aiplatform/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,10 +821,10 @@ async def test_handle_deserialize_raises_exception(
with pytest.raises(HTTPException):
await handler.handle(get_test_request())

get_content_type_from_headers_mock.assert_called_once_with(get_test_headers())
deserialize_exception_mock.assert_called_once_with(
_TEST_INPUT, _APPLICATION_JSON
)
get_content_type_from_headers_mock.assert_called_once_with(get_test_headers())
assert not predictor_mock().preprocess.called
assert not predictor_mock().predict.called
assert not predictor_mock().postprocess.called
Expand All @@ -845,25 +845,66 @@ async def test_handle_predictor_raises_exception(
handler = PredictionHandler(
_TEST_GCS_ARTIFACTS_URI, predictor=get_test_predictor()
)
expected_message = (
"The following exception has occurred: Exception. Arguments: ()."
)

with mock.patch.multiple(
handler._predictor,
preprocess=preprocess_mock,
predict=predict_mock,
postprocess=postprocess_mock,
):
with pytest.raises(Exception):
with pytest.raises(HTTPException) as exception:
await handler.handle(get_test_request())

deserialize_mock.assert_called_once_with(_TEST_INPUT, _APPLICATION_JSON)
get_content_type_from_headers_mock.assert_called_once_with(
get_test_headers()
)
preprocess_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
predict_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
assert not postprocess_mock.called
assert not get_accept_from_headers_mock.called
assert not serialize_mock.called
assert exception.value.status_code == 500
assert exception.value.detail == expected_message
get_content_type_from_headers_mock.assert_called_once_with(get_test_headers())
deserialize_mock.assert_called_once_with(_TEST_INPUT, _APPLICATION_JSON)
preprocess_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
predict_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
assert not postprocess_mock.called
assert not get_accept_from_headers_mock.called
assert not serialize_mock.called

# Checks that an HTTPException raised by the predictor's predict step leaves
# PredictionHandler.handle with its original status code and detail, i.e. it
# is not rewrapped as a 500 like other exception types.
@pytest.mark.asyncio
async def test_handle_predictor_raises_http_exception(
self,
get_content_type_from_headers_mock,
deserialize_mock,
get_accept_from_headers_mock,
serialize_mock,
):
status_code = 400
expected_message = "This is an user error."
preprocess_mock = mock.MagicMock(return_value=_TEST_DESERIALIZED_INPUT)
# predict is the step that fails; preprocess succeeds before it.
predict_mock = mock.MagicMock(
side_effect=HTTPException(status_code=status_code, detail=expected_message)
)
postprocess_mock = mock.MagicMock(return_value=_TEST_SERIALIZED_OUTPUT)
handler = PredictionHandler(
_TEST_GCS_ARTIFACTS_URI, predictor=get_test_predictor()
)

# Swap the three predictor stages for the mocks only for the duration of the call.
with mock.patch.multiple(
handler._predictor,
preprocess=preprocess_mock,
predict=predict_mock,
postprocess=postprocess_mock,
):
with pytest.raises(HTTPException) as exception:
await handler.handle(get_test_request())

# The caller-supplied status code and detail must survive unchanged.
assert exception.value.status_code == status_code
assert exception.value.detail == expected_message
# Everything up to and including predict ran exactly once ...
get_content_type_from_headers_mock.assert_called_once_with(get_test_headers())
deserialize_mock.assert_called_once_with(_TEST_INPUT, _APPLICATION_JSON)
preprocess_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
predict_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
# ... and nothing after the failing predict step was reached.
assert not postprocess_mock.called
assert not get_accept_from_headers_mock.called
assert not serialize_mock.called

@pytest.mark.asyncio
async def test_handle_serialize_raises_exception(
Expand Down Expand Up @@ -3173,20 +3214,46 @@ def test_predict_thorws_http_exception(
assert response.status_code == 400
assert json.loads(response.content)["detail"] == expected_message

def test_predict_thorws_exceptions_not_http_exception(
# Checks that when the handler raises a non-HTTP exception and the server is
# flagged as using the default handler, the 500 response carries the CPR
# server-error header.
def test_predict_thorws_exceptions_not_http_exception_default_handler(
self, model_server_env_mock, importlib_import_module_mock_twice
):
expected_message = (
"An exception ValueError occurred. Arguments: ('Not a correct value.',)."
)
model_server = CprModelServer()
# The flag is forced on here instead of relying on the value set in __init__.
model_server.is_default_handler = True
client = TestClient(model_server.app)

with mock.patch.object(model_server.handler, "handle") as handle_mock:
handle_mock.side_effect = ValueError("Not a correct value.")

response = client.post(_TEST_AIP_PREDICT_ROUTE, json={"x": [1]})

# Only the presence of the header is checked, not its value.
assert (
prediction.CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY
in response.headers
)
assert response.status_code == 500
assert json.loads(response.content)["detail"] == expected_message

# Checks that when the handler raises a non-HTTP exception and the server is
# not flagged as using the default handler, the 500 response does not carry
# the CPR server-error header.
def test_predict_thorws_exceptions_not_http_exception_not_default_handler(
self, model_server_env_mock, importlib_import_module_mock_twice
):
expected_message = (
"An exception ValueError occurred. Arguments: ('Not a correct value.',)."
)
# is_default_handler is left at whatever CprModelServer() computed.
# NOTE(review): the constructor in this commit sets it via
# `handler_module == "google.cloud.aiplatform.prediction.handler"`, which
# compares an imported module object with a string and looks like it is
# always False -- confirm this test is not passing for that reason alone.
model_server = CprModelServer()
client = TestClient(model_server.app)

with mock.patch.object(model_server.handler, "handle") as handle_mock:
handle_mock.side_effect = ValueError("Not a correct value.")

response = client.post(_TEST_AIP_PREDICT_ROUTE, json={"x": [1]})

assert (
prediction.CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY
not in response.headers
)
# Status code and message are the same as in the default-handler case.
assert response.status_code == 500
assert json.loads(response.content)["detail"] == expected_message

Expand Down

0 comments on commit 05efff7

Please sign in to comment.