diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index b73298a6f5ac1..f889fb09285d5 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -457,29 +457,43 @@ be wrapped with ``tune.function``): ) -Client API ----------- +Tune Client API +--------------- -You can modify an ongoing experiment by adding or deleting trials using the Tune Client API. To do this, verify that you have the ``requests`` library installed: +You can interact with an ongoing experiment with the Tune Client API. The Tune Client API is organized around REST, which includes resource-oriented URLs, accepts form-encoded requests, returns JSON-encoded responses, and uses standard HTTP protocol. -.. code-block:: bash - - $ pip install requests - -To use the Client API, you can start your experiment with ``with_server=True``: +To allow Tune to receive and respond to your API calls, you have to start your experiment with ``with_server=True``: .. code-block:: python run_experiments({...}, with_server=True, server_port=4321) -Then, on the client side, you can use the following class. The server address defaults to ``localhost:4321``. If on a cluster, you may want to forward this port (e.g. ``ssh -L :localhost:
``) so that you can use the Client on your local machine. +The easiest way to use the Tune Client API is with the built-in TuneClient. To use TuneClient, verify that you have the ``requests`` library installed: + +.. code-block:: bash + + $ pip install requests + +Then, on the client side, you can use the following class. If on a cluster, you may want to forward this port (e.g. ``ssh -L :localhost:
``) so that you can use the Client on your local machine. .. autoclass:: ray.tune.web_server.TuneClient :members: - For an example notebook for using the Client API, see the `Client API Example `__. +The API also supports curl. Here are the examples for getting trials (``GET /trials/[:id]``): + +.. code-block:: bash + + curl http://
:/trials + curl http://
:/trials/ + +And stopping a trial (``PUT /trials/[:id]``): + +.. code-block:: bash + + curl -X PUT http://
:/trials/ + Further Questions or Issues? ---------------------------- diff --git a/python/ray/tune/TuneClient.ipynb b/python/ray/tune/TuneClient.ipynb index dfd5dcb75475f..174b26f1e136b 100644 --- a/python/ray/tune/TuneClient.ipynb +++ b/python/ray/tune/TuneClient.ipynb @@ -8,7 +8,7 @@ "source": [ "from ray.tune.web_server import TuneClient\n", "\n", - "manager = TuneClient(tune_address=\"localhost:4321\")\n", + "manager = TuneClient(tune_address=\"localhost\", port_forward=4321)\n", "\n", "x = manager.get_all_trials()\n", "\n", @@ -19,7 +19,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "collapsed": true, "scrolled": false }, "outputs": [], @@ -31,9 +30,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "import yaml\n", @@ -45,9 +42,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "name, spec = [x for x in d.items()][0]" @@ -79,7 +74,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.2" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/python/ray/tune/error.py b/python/ray/tune/error.py index b23d62a081852..badf60a08fdc5 100644 --- a/python/ray/tune/error.py +++ b/python/ray/tune/error.py @@ -6,8 +6,3 @@ class TuneError(Exception): """General error class raised by ray.tune.""" pass - - -class TuneManagerError(TuneError): - """Error raised in operating the Tune Manager.""" - pass diff --git a/python/ray/tune/test/tune_server_test.py b/python/ray/tune/test/tune_server_test.py index e93c7d976d867..7d9143544e228 100644 --- a/python/ray/tune/test/tune_server_test.py +++ b/python/ray/tune/test/tune_server_test.py @@ -4,6 +4,8 @@ import unittest import socket +import subprocess +import json import ray from ray import tune @@ -44,7 +46,7 @@ def basicSetup(self): trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] for t in trials: runner.add_trial(t) - client = TuneClient("localhost:{}".format(port)) + client = TuneClient("localhost", port) return runner, client def tearDown(self): @@ -126,6 +128,25 @@ def testStopTrial(self): self.assertEqual( len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0) + def testCurlCommand(self): + """Check if Stop Trial works.""" + runner, client = self.basicSetup() + for i in range(2): + runner.step() + stdout = subprocess.check_output( + 'curl "http://{}:{}/trials"'.format(client.server_address, + client.server_port), + shell=True) + self.assertNotEqual(stdout, None) + curl_trials = json.loads(stdout.decode())["trials"] + client_trials = client.get_all_trials()["trials"] + for curl_trial, client_trial in zip(curl_trials, client_trials): + self.assertEqual(curl_trial.keys(), client_trial.keys()) + self.assertEqual(curl_trial["id"], client_trial["id"]) + self.assertEqual(curl_trial["trainable_name"], + client_trial["trainable_name"]) + self.assertEqual(curl_trial["status"], client_trial["status"]) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/web_server.py b/python/ray/tune/web_server.py index 81ca332278e2a..4c27f92fdcf8e 100644 --- a/python/ray/tune/web_server.py +++ b/python/ray/tune/web_server.py @@ -8,14 +8,16 @@ import threading import ray.cloudpickle as cloudpickle -from ray.tune.error import TuneError, TuneManagerError +from ray.tune import TuneError from ray.tune.suggest import BasicVariantGenerator from ray.utils import binary_to_hex, hex_to_binary if sys.version_info[0] == 2: + from urlparse import urljoin, urlparse from SimpleHTTPServer import SimpleHTTPRequestHandler from SocketServer import TCPServer as HTTPServer elif sys.version_info[0] == 3: + from urllib.parse import urljoin, urlparse from http.server import SimpleHTTPRequestHandler, HTTPServer logger = logging.getLogger(__name__) @@ -28,83 +30,159 @@ "Be sure to install it on the client side.") -def load_trial_info(trial_info): - trial_info["config"] = cloudpickle.loads( - hex_to_binary(trial_info["config"])) - trial_info["result"] = cloudpickle.loads( - hex_to_binary(trial_info["result"])) - - class TuneClient(object): - """Client to interact with ongoing Tune experiment. + """Client to interact with an ongoing Tune experiment. - Requires server to have started running.""" - STOP = "STOP" - ADD = "ADD" - GET_LIST = "GET_LIST" - GET_TRIAL = "GET_TRIAL" + Requires a TuneServer to have started running. - def __init__(self, tune_address): - # TODO(rliaw): Better to specify address and port forward + Attributes: + tune_address (str): Address of running TuneServer + port_forward (int): Port number of running TuneServer + """ + + def __init__(self, tune_address, port_forward): self._tune_address = tune_address - self._path = "http://{}".format(tune_address) + self._port_forward = port_forward + self._path = "http://{}:{}".format(tune_address, port_forward) def get_all_trials(self): - """Returns a list of all trials (trial_id, config, status).""" - return self._get_response({"command": TuneClient.GET_LIST}) + """Returns a list of all trials' information.""" + response = requests.get(urljoin(self._path, "trials")) + return self._deserialize(response) def get_trial(self, trial_id): - """Returns the last result for queried trial.""" - return self._get_response({ - "command": TuneClient.GET_TRIAL, - "trial_id": trial_id - }) - - def add_trial(self, name, trial_spec): - """Adds a trial of `name` with configurations.""" - # TODO(rliaw): have better way of specifying a new trial - return self._get_response({ - "command": TuneClient.ADD, - "name": name, - "spec": trial_spec - }) + """Returns trial information by trial_id.""" + response = requests.get( + urljoin(self._path, "trials/{}".format(trial_id))) + return self._deserialize(response) + + def add_trial(self, name, specification): + """Adds a trial by name and specification (dict).""" + payload = {"name": name, "spec": specification} + response = requests.post(urljoin(self._path, "trials"), json=payload) + return self._deserialize(response) def stop_trial(self, trial_id): - """Requests to stop trial.""" - return self._get_response({ - "command": TuneClient.STOP, - "trial_id": trial_id - }) - - def _get_response(self, data): - payload = json.dumps(data).encode() - response = requests.get(self._path, data=payload) + """Requests to stop trial by trial_id.""" + response = requests.put( + urljoin(self._path, "trials/{}".format(trial_id))) + return self._deserialize(response) + + @property + def server_address(self): + return self._tune_address + + @property + def server_port(self): + return self._port_forward + + def _load_trial_info(self, trial_info): + trial_info["config"] = cloudpickle.loads( + hex_to_binary(trial_info["config"])) + trial_info["result"] = cloudpickle.loads( + hex_to_binary(trial_info["result"])) + + def _deserialize(self, response): parsed = response.json() - if "trial_info" in parsed: - load_trial_info(parsed["trial_info"]) + if "trial" in parsed: + self._load_trial_info(parsed["trial"]) elif "trials" in parsed: for trial_info in parsed["trials"]: - load_trial_info(trial_info) + self._load_trial_info(trial_info) return parsed def RunnerHandler(runner): class Handler(SimpleHTTPRequestHandler): + """A Handler is a custom handler for TuneServer. + + Handles all requests and responses coming into and from + the TuneServer. + """ + + def _do_header(self, response_code=200, headers=None): + """Sends the header portion of the HTTP response. + + Parameters: + response_code (int): Standard HTTP response code + headers (list[tuples]): Standard HTTP response headers + """ + if headers is None: + headers = [('Content-type', 'application/json')] + + self.send_response(response_code) + for key, value in headers: + self.send_header(key, value) + self.end_headers() + + def do_HEAD(self): + """HTTP HEAD handler method.""" + self._do_header() + def do_GET(self): + """HTTP GET handler method.""" + response_code = 200 + message = "" + try: + result = self._get_trial_by_url(self.path) + resource = {} + if result: + if isinstance(result, list): + infos = [self._trial_info(t) for t in result] + resource["trials"] = infos + else: + resource["trial"] = self._trial_info(result) + message = json.dumps(resource) + except TuneError as e: + response_code = 404 + message = str(e) + + self._do_header(response_code=response_code) + self.wfile.write(message.encode()) + + def do_PUT(self): + """HTTP PUT handler method.""" + response_code = 200 + message = "" + try: + result = self._get_trial_by_url(self.path) + resource = {} + if result: + if isinstance(result, list): + infos = [self._trial_info(t) for t in result] + resource["trials"] = infos + for t in result: + runner.request_stop_trial(t) + else: + resource["trial"] = self._trial_info(result) + runner.request_stop_trial(result) + message = json.dumps(resource) + except TuneError as e: + response_code = 404 + message = str(e) + + self._do_header(response_code=response_code) + self.wfile.write(message.encode()) + + def do_POST(self): + """HTTP POST handler method.""" + response_code = 201 + content_len = int(self.headers.get('Content-Length'), 0) raw_body = self.rfile.read(content_len) parsed_input = json.loads(raw_body.decode()) - status, response = self.execute_command(parsed_input) - if status: - self.send_response(200) - else: - self.send_response(400) - self.end_headers() - self.wfile.write(json.dumps(response).encode()) + resource = self._add_trials(parsed_input["name"], + parsed_input["spec"]) - def trial_info(self, trial): + headers = [('Content-type', 'application/json'), ('Location', + '/trials/')] + self._do_header(response_code=response_code, headers=headers) + self.wfile.write(json.dumps(resource).encode()) + + def _trial_info(self, trial): + """Returns trial information as JSON.""" if trial.last_result: result = trial.last_result.copy() else: @@ -118,62 +196,56 @@ def trial_info(self, trial): } return info_dict - def execute_command(self, args): - def get_trial(): - trial = runner.get_trial(args["trial_id"]) - if trial is None: - error = "Trial ({}) not found.".format(args["trial_id"]) - raise TuneManagerError(error) - else: - return trial - - command = args["command"] - response = {} - try: - if command == TuneClient.GET_LIST: - response["trials"] = [ - self.trial_info(t) for t in runner.get_trials() - ] - elif command == TuneClient.GET_TRIAL: - trial = get_trial() - response["trial_info"] = self.trial_info(trial) - elif command == TuneClient.STOP: - trial = get_trial() - runner.request_stop_trial(trial) - elif command == TuneClient.ADD: - name = args["name"] - spec = args["spec"] - trial_generator = BasicVariantGenerator() - trial_generator.add_configurations({name: spec}) - for trial in trial_generator.next_trials(): - runner.add_trial(trial) - else: - raise TuneManagerError("Unknown command.") - status = True - except TuneError as e: - status = False - response["message"] = str(e) + def _get_trial_by_url(self, url): + """Parses url to get either all trials or trial by trial_id.""" + parts = urlparse(url) + path = parts.path - return status, response + if path == "/trials": + return [t for t in runner.get_trials()] + else: + trial_id = path.split("/")[-1] + return runner.get_trial(trial_id) + + def _add_trials(self, name, spec): + """Add trial by invoking TrialRunner.""" + resource = {} + resource["trials"] = [] + trial_generator = BasicVariantGenerator() + trial_generator.add_configurations({name: spec}) + for trial in trial_generator.next_trials(): + runner.add_trial(trial) + resource["trials"].append(self._trial_info(trial)) + return resource return Handler class TuneServer(threading.Thread): + """A TuneServer is a thread that initializes and runs a HTTPServer. + + The server handles requests from a TuneClient. + + Attributes: + runner (TrialRunner): Runner that modifies and accesses trials. + port_forward (int): Port number of TuneServer. + """ DEFAULT_PORT = 4321 def __init__(self, runner, port=None): - + """Initialize HTTPServer and serve forever by invoking self.run()""" threading.Thread.__init__(self) self._port = port if port else self.DEFAULT_PORT address = ('localhost', self._port) logger.info("Starting Tune Server...") self._server = HTTPServer(address, RunnerHandler(runner)) + self.daemon = True self.start() def run(self): self._server.serve_forever() def shutdown(self): + """Shutdown the underlying server.""" self._server.shutdown()