Skip to content

Commit 3e1b764

Browse files
authored
Tweak python exception handling (#405)
1 parent f15f73c commit 3e1b764

File tree

5 files changed

28 additions and 42 deletions (lines changed)

5 files changed

28 additions and 42 deletions (lines changed)

pkg/workloads/cortex/lib/api_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ def post_request_metrics(ctx, api, response, prediction_payload, class_set):
138138

139139
metrics_list += prediction_metrics(api_dimensions, api, prediction)
140140
except Exception as e:
141-
logger.warn(str(e), exc_info=True)
141+
logger.warn("unable to record prediction metric", exc_info=True)
142142

143143
try:
144144
ctx.publish_metrics(metrics_list)
145145
except Exception as e:
146-
logger.warn(str(e), exc_info=True)
146+
logger.warn("failure encountered while publishing metrics", exc_info=True)

pkg/workloads/cortex/lib/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def load_module(self, module_prefix, module_name, impl_key):
114114
try:
115115
impl = imp.load_source(full_module_name, impl_path)
116116
except Exception as e:
117-
raise UserException("unable to load python file") from e
117+
raise UserException("unable to load python file", str(e)) from e
118118

119119
return impl, impl_path
120120

@@ -200,7 +200,7 @@ def publish_metrics(self, metrics):
200200

201201
if int(response["ResponseMetadata"]["HTTPStatusCode"] / 100) != 2:
202202
logger.warn(response)
203-
raise Exception("failed to publish metrics")
203+
raise Exception("cloudwatch returned a non-200 status")
204204

205205

206206
REQUEST_HANDLER_IMPL_VALIDATION = {

pkg/workloads/cortex/lib/exceptions.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,20 @@
1717

1818
class CortexException(Exception):
1919
def __init__(self, *messages):
20-
super().__init__(self.stringify(messages))
21-
20+
super().__init__(": ".join(messages))
2221
self.errors = deque(messages)
2322

2423
def wrap(self, *messages):
2524
self.errors.extendleft(reversed(messages))
2625

2726
def __str__(self):
28-
return self.stringify(self.errors)
27+
return self.stringify()
28+
29+
def __repr__(self):
30+
return self.stringify()
2931

30-
@staticmethod
31-
def stringify(str_list):
32-
return ": ".join(str_list)
32+
def stringify(self):
33+
return "error: " + ": ".join(self.errors)
3334

3435

3536
class UserException(CortexException):

pkg/workloads/cortex/onnx_serve/api.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def transform_to_numpy(input_pyobj, input_metadata):
124124
np_arr = np_arr.reshape(target_shape)
125125
return np_arr
126126
except Exception as e:
127-
raise UserException(str(e)) from e
127+
raise UserException("failed to convert to numpy array", str(e)) from e
128128

129129

130130
def convert_to_onnx_input(sample, input_metadata_list):
@@ -170,7 +170,7 @@ def predict(app_name, api_name):
170170
try:
171171
sample = request.get_json()
172172
except Exception as e:
173-
return "Malformed JSON", status.HTTP_400_BAD_REQUEST
173+
return "malformed json", status.HTTP_400_BAD_REQUEST
174174

175175
sess = local_cache["sess"]
176176
api = local_cache["api"]
@@ -189,7 +189,7 @@ def predict(app_name, api_name):
189189
debug_obj("pre_inference", prepared_sample, debug)
190190
except Exception as e:
191191
raise UserRuntimeException(
192-
api["request_handler"], "pre_inference request handler"
192+
api["request_handler"], "pre_inference request handler", str(e)
193193
) from e
194194

195195
inference_input = convert_to_onnx_input(prepared_sample, input_metadata)
@@ -208,16 +208,12 @@ def predict(app_name, api_name):
208208
result = request_handler.post_inference(result, output_metadata)
209209
except Exception as e:
210210
raise UserRuntimeException(
211-
api["request_handler"], "post_inference request handler"
211+
api["request_handler"], "post_inference request handler", str(e)
212212
) from e
213213

214214
debug_obj("post_inference", result, debug)
215-
except CortexException as e:
216-
e.wrap("error")
217-
logger.exception(str(e))
218-
return prediction_failed(str(e))
219215
except Exception as e:
220-
logger.exception(str(e))
216+
logger.exception("prediction failed")
221217
return prediction_failed(str(e))
222218

223219
g.prediction = result
@@ -270,9 +266,9 @@ def start(args):
270266
truncate(extract_signature(local_cache["output_metadata"]))
271267
)
272268
)
273-
except CortexException as e:
274-
e.wrap("error")
275-
logger.error(str(e))
269+
270+
except Exception as e:
271+
logger.exception("failed to start api")
276272
sys.exit(1)
277273

278274
if api.get("tracker") is not None and api["tracker"].get("model_type") == "classification":

pkg/workloads/cortex/tf_api/api.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from cortex.lib import util, package, Context, api_utils
3232
from cortex.lib.storage import S3
3333
from cortex.lib.log import get_logger, debug_obj
34-
from cortex.lib.exceptions import CortexException, UserRuntimeException, UserException
34+
from cortex.lib.exceptions import UserRuntimeException, UserException
3535
from cortex.lib.stringify import truncate
3636

3737

@@ -134,7 +134,7 @@ def create_prediction_request(sample):
134134
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
135135
except Exception as e:
136136
raise UserException(
137-
'key "{}"'.format(column_name), "expected shape {}".format(shape)
137+
'key "{}"'.format(column_name), "expected shape {}".format(shape), str(e)
138138
) from e
139139

140140
return prediction_request
@@ -185,7 +185,7 @@ def run_predict(sample, debug=False):
185185
debug_obj("pre_inference", prepared_sample, debug)
186186
except Exception as e:
187187
raise UserRuntimeException(
188-
api["request_handler"], "pre_inference request handler"
188+
api["request_handler"], "pre_inference request handler", str(e)
189189
) from e
190190

191191
validate_sample(prepared_sample)
@@ -201,7 +201,7 @@ def run_predict(sample, debug=False):
201201
debug_obj("post_inference", result, debug)
202202
except Exception as e:
203203
raise UserRuntimeException(
204-
api["request_handler"], "post_inference request handler"
204+
api["request_handler"], "post_inference request handler", str(e)
205205
) from e
206206

207207
return result
@@ -243,12 +243,8 @@ def predict(deployment_name, api_name):
243243

244244
try:
245245
result = run_predict(sample, debug)
246-
except CortexException as e:
247-
e.wrap("error")
248-
logger.exception(str(e))
249-
return prediction_failed(str(e))
250246
except Exception as e:
251-
logger.exception(str(e))
247+
logger.exception("prediction failed")
252248
return prediction_failed(str(e))
253249

254250
g.prediction = result
@@ -284,12 +280,9 @@ def get_signature(app_name, api_name):
284280
local_cache["metadata"]["signatureDef"],
285281
local_cache["api"]["tf_serving"]["signature_key"],
286282
)
287-
except CortexException as e:
288-
logger.exception(str(e))
289-
return str(e), HTTP_404_NOT_FOUND
290283
except Exception as e:
291-
logger.exception(str(e))
292-
return str(e), HTTP_404_NOT_FOUND
284+
logger.exception("failed to get signature")
285+
return jsonify(error=str(e)), 404
293286

294287
response = {"signature": metadata}
295288
return jsonify(response)
@@ -336,12 +329,8 @@ def start(args):
336329
if api.get("request_handler") is not None:
337330
package.install_packages(ctx.python_packages, ctx.storage)
338331
local_cache["request_handler"] = ctx.get_request_handler_impl(api["name"])
339-
except CortexException as e:
340-
e.wrap("error")
341-
logger.exception(str(e))
342-
sys.exit(1)
343332
except Exception as e:
344-
logger.exception(str(e))
333+
logger.exception("failed to start api")
345334
sys.exit(1)
346335

347336
try:
@@ -354,7 +343,7 @@ def start(args):
354343
try:
355344
local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
356345
except Exception as e:
357-
logger.warn("An error occurred while attempting to load classes", exc_info=True)
346+
logger.warn("an error occurred while attempting to load classes", exc_info=True)
358347

359348
channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
360349
local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel)

0 commit comments

Comments
 (0)