Skip to content

Commit 05efff7

Browse files
authored
chore: Add headers for CPR model server errors. (#1701)
* chore: Add headers for CPR model server errors. * chore: Fix comments.
1 parent 926d0b6 commit 05efff7

File tree

4 files changed

+118
-15
lines changed

4 files changed

+118
-15
lines changed

google/cloud/aiplatform/constants/prediction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@
176176
DEFAULT_LOCAL_RUN_GPU_COUNT = -1
177177

178178
CUSTOM_PREDICTION_ROUTINES = "custom-prediction-routines"
179+
CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY = "X-AIP-CPR-SYSTEM-ERROR"
179180

180181
# Headers' related constants for the handler usage.
181182
CONTENT_TYPE_HEADER_REGEX = re.compile("^[Cc]ontent-?[Tt]ype$")

google/cloud/aiplatform/prediction/handler.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
#
1717

1818
from abc import ABC, abstractmethod
19+
import logging
1920
from typing import Optional, Type
21+
import traceback
2022

2123
try:
24+
from fastapi import HTTPException
2225
from fastapi import Request
2326
from fastapi import Response
2427
except ImportError:
@@ -103,14 +106,32 @@ async def handle(self, request: Request) -> Response:
103106
104107
Returns:
105108
The response of the prediction request.
109+
110+
Raises:
111+
HTTPException: If any exception is thrown from predictor object.
106112
"""
107113
request_body = await request.body()
108114
content_type = handler_utils.get_content_type_from_headers(request.headers)
109115
prediction_input = DefaultSerializer.deserialize(request_body, content_type)
110116

111-
prediction_results = self._predictor.postprocess(
112-
self._predictor.predict(self._predictor.preprocess(prediction_input))
113-
)
117+
try:
118+
prediction_results = self._predictor.postprocess(
119+
self._predictor.predict(self._predictor.preprocess(prediction_input))
120+
)
121+
except HTTPException:
122+
raise
123+
except Exception as exception:
124+
error_message = (
125+
"The following exception has occurred: {}. Arguments: {}.".format(
126+
type(exception).__name__, exception.args
127+
)
128+
)
129+
logging.info(
130+
"{}\\nTraceback: {}".format(error_message, traceback.format_exc())
131+
)
132+
133+
# Converts all other exceptions to HTTPException.
134+
raise HTTPException(status_code=500, detail=error_message)
114135

115136
accept = handler_utils.get_accept_from_headers(request.headers)
116137
data = DefaultSerializer.serialize(prediction_results, accept)

google/cloud/aiplatform/prediction/model_server.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
4141
)
4242

43+
from google.cloud.aiplatform.constants import prediction
44+
from google.cloud.aiplatform import version
45+
4346

4447
class CprModelServer:
4548
"""Model server to do custom prediction routines."""
@@ -61,6 +64,9 @@ def __init__(self):
6164
)
6265
handler_module = importlib.import_module(os.environ.get("HANDLER_MODULE"))
6366
handler_class = getattr(handler_module, os.environ.get("HANDLER_CLASS"))
67+
self.is_default_handler = (
68+
handler_module == "google.cloud.aiplatform.prediction.handler"
69+
)
6470

6571
predictor_class = None
6672
if "PREDICTOR_MODULE" in os.environ:
@@ -145,6 +151,14 @@ async def predict(self, request: Request) -> Response:
145151
)
146152

147153
# Converts all other exceptions to HTTPException.
154+
if self.is_default_handler:
155+
raise HTTPException(
156+
status_code=500,
157+
detail=error_message,
158+
headers={
159+
prediction.CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY: version.__version__
160+
},
161+
)
148162
raise HTTPException(status_code=500, detail=error_message)
149163

150164

tests/unit/aiplatform/test_prediction.py

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -821,10 +821,10 @@ async def test_handle_deserialize_raises_exception(
821821
with pytest.raises(HTTPException):
822822
await handler.handle(get_test_request())
823823

824+
get_content_type_from_headers_mock.assert_called_once_with(get_test_headers())
824825
deserialize_exception_mock.assert_called_once_with(
825826
_TEST_INPUT, _APPLICATION_JSON
826827
)
827-
get_content_type_from_headers_mock.assert_called_once_with(get_test_headers())
828828
assert not predictor_mock().preprocess.called
829829
assert not predictor_mock().predict.called
830830
assert not predictor_mock().postprocess.called
@@ -845,25 +845,66 @@ async def test_handle_predictor_raises_exception(
845845
handler = PredictionHandler(
846846
_TEST_GCS_ARTIFACTS_URI, predictor=get_test_predictor()
847847
)
848+
expected_message = (
849+
"The following exception has occurred: Exception. Arguments: ()."
850+
)
848851

849852
with mock.patch.multiple(
850853
handler._predictor,
851854
preprocess=preprocess_mock,
852855
predict=predict_mock,
853856
postprocess=postprocess_mock,
854857
):
855-
with pytest.raises(Exception):
858+
with pytest.raises(HTTPException) as exception:
856859
await handler.handle(get_test_request())
857860

858-
deserialize_mock.assert_called_once_with(_TEST_INPUT, _APPLICATION_JSON)
859-
get_content_type_from_headers_mock.assert_called_once_with(
860-
get_test_headers()
861-
)
862-
preprocess_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
863-
predict_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
864-
assert not postprocess_mock.called
865-
assert not get_accept_from_headers_mock.called
866-
assert not serialize_mock.called
861+
assert exception.value.status_code == 500
862+
assert exception.value.detail == expected_message
863+
get_content_type_from_headers_mock.assert_called_once_with(get_test_headers())
864+
deserialize_mock.assert_called_once_with(_TEST_INPUT, _APPLICATION_JSON)
865+
preprocess_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
866+
predict_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
867+
assert not postprocess_mock.called
868+
assert not get_accept_from_headers_mock.called
869+
assert not serialize_mock.called
870+
871+
@pytest.mark.asyncio
872+
async def test_handle_predictor_raises_http_exception(
873+
self,
874+
get_content_type_from_headers_mock,
875+
deserialize_mock,
876+
get_accept_from_headers_mock,
877+
serialize_mock,
878+
):
879+
status_code = 400
880+
expected_message = "This is an user error."
881+
preprocess_mock = mock.MagicMock(return_value=_TEST_DESERIALIZED_INPUT)
882+
predict_mock = mock.MagicMock(
883+
side_effect=HTTPException(status_code=status_code, detail=expected_message)
884+
)
885+
postprocess_mock = mock.MagicMock(return_value=_TEST_SERIALIZED_OUTPUT)
886+
handler = PredictionHandler(
887+
_TEST_GCS_ARTIFACTS_URI, predictor=get_test_predictor()
888+
)
889+
890+
with mock.patch.multiple(
891+
handler._predictor,
892+
preprocess=preprocess_mock,
893+
predict=predict_mock,
894+
postprocess=postprocess_mock,
895+
):
896+
with pytest.raises(HTTPException) as exception:
897+
await handler.handle(get_test_request())
898+
899+
assert exception.value.status_code == status_code
900+
assert exception.value.detail == expected_message
901+
get_content_type_from_headers_mock.assert_called_once_with(get_test_headers())
902+
deserialize_mock.assert_called_once_with(_TEST_INPUT, _APPLICATION_JSON)
903+
preprocess_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
904+
predict_mock.assert_called_once_with(_TEST_DESERIALIZED_INPUT)
905+
assert not postprocess_mock.called
906+
assert not get_accept_from_headers_mock.called
907+
assert not serialize_mock.called
867908

868909
@pytest.mark.asyncio
869910
async def test_handle_serialize_raises_exception(
@@ -3173,20 +3214,46 @@ def test_predict_thorws_http_exception(
31733214
assert response.status_code == 400
31743215
assert json.loads(response.content)["detail"] == expected_message
31753216

3176-
def test_predict_thorws_exceptions_not_http_exception(
3217+
def test_predict_thorws_exceptions_not_http_exception_default_handler(
31773218
self, model_server_env_mock, importlib_import_module_mock_twice
31783219
):
31793220
expected_message = (
31803221
"An exception ValueError occurred. Arguments: ('Not a correct value.',)."
31813222
)
31823223
model_server = CprModelServer()
3224+
model_server.is_default_handler = True
31833225
client = TestClient(model_server.app)
31843226

31853227
with mock.patch.object(model_server.handler, "handle") as handle_mock:
31863228
handle_mock.side_effect = ValueError("Not a correct value.")
31873229

31883230
response = client.post(_TEST_AIP_PREDICT_ROUTE, json={"x": [1]})
31893231

3232+
assert (
3233+
prediction.CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY
3234+
in response.headers
3235+
)
3236+
assert response.status_code == 500
3237+
assert json.loads(response.content)["detail"] == expected_message
3238+
3239+
def test_predict_thorws_exceptions_not_http_exception_not_default_handler(
3240+
self, model_server_env_mock, importlib_import_module_mock_twice
3241+
):
3242+
expected_message = (
3243+
"An exception ValueError occurred. Arguments: ('Not a correct value.',)."
3244+
)
3245+
model_server = CprModelServer()
3246+
client = TestClient(model_server.app)
3247+
3248+
with mock.patch.object(model_server.handler, "handle") as handle_mock:
3249+
handle_mock.side_effect = ValueError("Not a correct value.")
3250+
3251+
response = client.post(_TEST_AIP_PREDICT_ROUTE, json={"x": [1]})
3252+
3253+
assert (
3254+
prediction.CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY
3255+
not in response.headers
3256+
)
31903257
assert response.status_code == 500
31913258
assert json.loads(response.content)["detail"] == expected_message
31923259

0 commit comments

Comments
 (0)