Skip to content

Commit

Permalink
Refactor rest store and rest utils to support custom verification/pro…
Browse files Browse the repository at this point in the history
…cessing of the response (mlflow#1261)

* Refactor rest store and rest utils to support custom deserialization of response in RestStore

* Fix lint and error message bugs

* Resolve minor comments

* Keep message as string and test json se(de)rialization

* Add back error_code for rest_utils tests
  • Loading branch information
eddiestudies authored and tomasatdatabricks committed May 17, 2019
1 parent 565f3db commit 5329678
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 15 deletions.
6 changes: 3 additions & 3 deletions mlflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ class RestException(MlflowException):
"""Exception thrown on non 200-level responses from the REST API"""
def __init__(self, json):
error_code = json.get('error_code', INTERNAL_ERROR)
message = error_code
if 'message' in json:
message = "%s: %s" % (error_code, json['message'])
message = "%s: %s" % (error_code,
json['message'] if 'message' in json else "Response: " + str(json))

super(RestException, self).__init__(message, error_code=error_code)
self.json = json

Expand Down
11 changes: 8 additions & 3 deletions mlflow/store/rest_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mlflow.entities import Experiment, Run, RunInfo, Metric, ViewType

from mlflow.utils.proto_json_utils import message_to_json, parse_dict
from mlflow.utils.rest_utils import http_request_safe
from mlflow.utils.rest_utils import http_request, verify_rest_response

from mlflow.protos.service_pb2 import CreateExperiment, MlflowService, GetExperiment, \
GetRun, SearchRuns, ListExperiments, GetMetricHistory, LogMetric, LogParam, SetTag, \
Expand Down Expand Up @@ -48,6 +48,9 @@ def __init__(self, get_host_creds):
super(RestStore, self).__init__()
self.get_host_creds = get_host_creds

def _verify_rest_response(self, response, endpoint):
return verify_rest_response(response, endpoint)

def _call_endpoint(self, api, json_body):
endpoint, method = _METHOD_TO_INFO[api]
response_proto = api.Response()
Expand All @@ -57,12 +60,14 @@ def _call_endpoint(self, api, json_body):
host_creds = self.get_host_creds()

if method == 'GET':
response = http_request_safe(
response = http_request(
host_creds=host_creds, endpoint=endpoint, method=method, params=json_body)
else:
response = http_request_safe(
response = http_request(
host_creds=host_creds, endpoint=endpoint, method=method, json=json_body)

response = self._verify_rest_response(response, endpoint)

js_dict = json.loads(response.text)
parse_dict(js_dict=js_dict, message=response_proto)
return response_proto
Expand Down
5 changes: 5 additions & 0 deletions mlflow/utils/rest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def http_request_safe(host_creds, endpoint, **kwargs):
Wrapper around ``http_request`` that also verifies that the request succeeds with code 200.
"""
response = http_request(host_creds=host_creds, endpoint=endpoint, **kwargs)
return verify_rest_response(response, endpoint)


def verify_rest_response(response, endpoint):
"""Verify the return code and raise exception if the request was not successful."""
if response.status_code != 200:
base_msg = "API request to endpoint %s failed with error code " \
"%s != 200" % (endpoint, response.status_code)
Expand Down
39 changes: 30 additions & 9 deletions tests/store/test_rest_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@
from mlflow.utils.rest_utils import MlflowHostCreds, _DEFAULT_HEADERS


class MyCoolException(Exception):
pass


class CustomErrorHandlingRestStore(RestStore):
def _verify_rest_response(self, response, endpoint):
if response.status_code != 200:
raise MyCoolException()


class TestRestStore(unittest.TestCase):
@mock.patch('requests.request')
def test_successful_http_request(self, request):
Expand Down Expand Up @@ -51,6 +61,17 @@ def test_failed_http_request(self, request):
store.list_experiments()
self.assertIn("RESOURCE_DOES_NOT_EXIST: No experiment", str(cm.exception))

@mock.patch('requests.request')
def test_failed_http_request_custom_handler(self, request):
response = mock.MagicMock
response.status_code = 404
response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
request.return_value = response

store = CustomErrorHandlingRestStore(lambda: MlflowHostCreds('https://hello'))
with self.assertRaises(MyCoolException):
store.list_experiments()

@mock.patch('requests.request')
def test_response_with_unknown_fields(self, request):
experiment_json = {
Expand Down Expand Up @@ -100,7 +121,7 @@ def test_requestor(self, request):
source_type_patch = mock.patch(
"mlflow.tracking.context._get_source_type", return_value=SourceType.LOCAL
)
with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http, \
with mock.patch('mlflow.store.rest_store.http_request') as mock_http, \
mock.patch('mlflow.tracking.utils._get_store', return_value=store), \
mock.patch('mlflow.tracking.context._get_user', return_value=user_name), \
mock.patch('time.time', return_value=13579), \
Expand Down Expand Up @@ -129,28 +150,28 @@ def test_requestor(self, request):
)
assert expected_kwargs == actual_kwargs

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
store.log_param("some_uuid", Param("k1", "v1"))
body = message_to_json(LogParam(
run_uuid="some_uuid", run_id="some_uuid", key="k1", value="v1"))
self._verify_requests(mock_http, creds,
"runs/log-parameter", "POST", body)

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
store.set_tag("some_uuid", RunTag("t1", "abcd"*1000))
body = message_to_json(SetTag(
run_uuid="some_uuid", run_id="some_uuid", key="t1", value="abcd"*1000))
self._verify_requests(mock_http, creds,
"runs/set-tag", "POST", body)

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
body = message_to_json(LogMetric(
run_uuid="u2", run_id="u2", key="m1", value=0.87, timestamp=12345, step=3))
self._verify_requests(mock_http, creds,
"runs/log-metric", "POST", body)

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
metrics = [Metric("m1", 0.87, 12345, 0), Metric("m2", 0.49, 12345, -1),
Metric("m3", 0.58, 12345, 2)]
params = [Param("p1", "p1val"), Param("p2", "p2val")]
Expand All @@ -164,25 +185,25 @@ def test_requestor(self, request):
self._verify_requests(mock_http, creds,
"runs/log-batch", "POST", body)

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
store.delete_run("u25")
self._verify_requests(mock_http, creds,
"runs/delete", "POST",
message_to_json(DeleteRun(run_id="u25")))

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
store.restore_run("u76")
self._verify_requests(mock_http, creds,
"runs/restore", "POST",
message_to_json(RestoreRun(run_id="u76")))

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
store.delete_experiment("0")
self._verify_requests(mock_http, creds,
"experiments/delete", "POST",
message_to_json(DeleteExperiment(experiment_id="0")))

with mock.patch('mlflow.store.rest_store.http_request_safe') as mock_http:
with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
store.restore_experiment("0")
self._verify_requests(mock_http, creds,
"experiments/restore", "POST",
Expand Down
16 changes: 16 additions & 0 deletions tests/utils/test_exception.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
from mlflow.exceptions import ExecutionException, RestException


def test_execution_exception_string_repr():
exc = ExecutionException("Uh oh")
assert str(exc) == "Uh oh"
json.loads(exc.serialize_as_json())


def test_rest_exception_default_error_code():
Expand All @@ -16,3 +18,17 @@ def test_rest_exception_error_code_is_not_none():
exc = RestException({"message": error_string})
assert "None" not in error_string
assert "None" not in str(exc)
json.loads(exc.serialize_as_json())


def test_rest_exception_without_message():
exc = RestException({"my_property": "something important."})
assert "something important." in str(exc)
json.loads(exc.serialize_as_json())


def test_rest_exception_error_code_and_no_message():
exc = RestException({"error_code": 2, "messages": "something important."})
assert "something important." in str(exc)
assert "2" in str(exc)
json.loads(exc.serialize_as_json())

0 comments on commit 5329678

Please sign in to comment.