Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] Ray Tune API cleanup #1454

Merged
merged 22 commits into from
Jan 25, 2018
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/rllib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ in the ``config`` section of the experiments.

.. code-block:: python

import ray
from ray.tune.tune import run_experiments
from ray.tune.variant_generator import grid_search

Expand All @@ -286,6 +287,7 @@ in the ``config`` section of the experiments.
# put additional experiments to run concurrently here
}

ray.init()
run_experiments(experiment)

Contributing to RLlib
Expand Down
7 changes: 2 additions & 5 deletions doc/source/tune.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Getting Started

::

import ray
from ray.tune import register_trainable, grid_search, run_experiments

def my_func(config, reporter):
Expand All @@ -30,6 +31,7 @@ Getting Started

register_trainable("my_func", my_func)

ray.init()
run_experiments({
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you need to modify rllib.rst too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

"my_experiment": {
"run": "my_func",
Expand Down Expand Up @@ -154,8 +156,3 @@ The JSON config passed to ``run_experiments`` can also be put in a JSON or YAML


For more examples of experiments described by YAML files, see `RLlib tuned examples <https://github.com/ray-project/ray/tree/master/python/ray/rllib/tuned_examples>`__.

Running in a large cluster
--------------------------

The ``run_experiments`` also takes any arguments that ``ray.init()`` does. This can be used to pass in the redis address of a multi-node Ray cluster. For more details, check out the `tune.py script <https://github.com/ray-project/ray/blob/master/python/ray/tune/tune.py>`__.
2 changes: 2 additions & 0 deletions examples/carla/a3c_lane_keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import ray
from ray.tune import register_env, run_experiments

from env import CarlaEnv, ENV_CONFIG
Expand All @@ -25,6 +26,7 @@
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()

ray.init()
run_experiments({
"carla-a3c": {
"run": "A3C",
Expand Down
2 changes: 2 additions & 0 deletions examples/carla/dqn_lane_keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import ray
from ray.tune import register_env, run_experiments

from env import CarlaEnv, ENV_CONFIG
Expand All @@ -25,6 +26,7 @@
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()

ray.init()
run_experiments({
"carla-dqn": {
"run": "DQN",
Expand Down
2 changes: 2 additions & 0 deletions examples/carla/ppo_lane_keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import ray
from ray.tune import register_env, run_experiments

from env import CarlaEnv, ENV_CONFIG
Expand All @@ -25,6 +26,7 @@
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()

ray.init()
run_experiments({
"carla-ppo": {
"run": "PPO",
Expand Down
3 changes: 2 additions & 1 deletion examples/carla/train_a3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
register_carla_model()
redis_address = ray.services.get_node_ip_address() + ":6379"

ray.init(redis_address=redis_address)
run_experiments({
"carla-a3c": {
"run": "A3C",
Expand All @@ -50,4 +51,4 @@
"num_workers": 2,
},
},
}, redis_address=redis_address)
}
2 changes: 2 additions & 0 deletions examples/carla/train_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import ray
from ray.tune import register_env, run_experiments

from env import CarlaEnv, ENV_CONFIG
Expand All @@ -23,6 +24,7 @@
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()

ray.init()
run_experiments({
"carla-dqn": {
"run": "DQN",
Expand Down
4 changes: 3 additions & 1 deletion examples/carla/train_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import ray
from ray.tune import register_env, run_experiments

from env import CarlaEnv, ENV_CONFIG
Expand All @@ -22,6 +23,7 @@
register_env(env_name, lambda env_config: CarlaEnv(env_config))
register_carla_model()

ray.init(redirect_output=True)
run_experiments({
"carla": {
"run": "PPO",
Expand Down Expand Up @@ -55,4 +57,4 @@
}
},
},
}, redirect_output=True)
})
5 changes: 3 additions & 2 deletions python/ray/rllib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import yaml

import ray
from ray.tune.config_parser import make_parser, resources_to_json
from ray.tune.tune import _make_scheduler, run_experiments

Expand Down Expand Up @@ -76,7 +77,7 @@
if not exp.get("env") and not exp.get("config", {}).get("env"):
parser.error("the following arguments are required: --env")

run_experiments(
experiments, scheduler=_make_scheduler(args),
ray.init(
redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
run_experiments(experiments, scheduler=_make_scheduler(args))
2 changes: 2 additions & 0 deletions python/ray/tune/examples/tune_mnist_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import tempfile
import time

import ray
from ray.tune import grid_search, run_experiments, register_trainable

from tensorflow.examples.tutorials.mnist import input_data
Expand Down Expand Up @@ -222,4 +223,5 @@ def train(config={'activation': 'relu'}, reporter=None):
if args.fast:
mnist_spec['stop']['training_iteration'] = 2

ray.init()
run_experiments({'tune_mnist_test': mnist_spec})
9 changes: 4 additions & 5 deletions python/ray/tune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _make_scheduler(args):
args.scheduler, _SCHEDULERS.keys()))


def run_experiments(experiments, scheduler=None, **ray_args):
def run_experiments(experiments, scheduler=None):
if scheduler is None:
scheduler = FIFOScheduler()
runner = TrialRunner(scheduler)
Expand All @@ -71,8 +71,6 @@ def run_experiments(experiments, scheduler=None, **ray_args):
runner.add_trial(trial)
print(runner.debug_string())

ray.init(**ray_args)

while not runner.is_finished():
runner.step()
print(runner.debug_string())
Expand All @@ -89,6 +87,7 @@ def run_experiments(experiments, scheduler=None, **ray_args):
args = parser.parse_args(sys.argv[1:])
with open(args.config_file) as f:
experiments = yaml.load(f)
run_experiments(
experiments, _make_scheduler(args), redis_address=args.redis_address,
ray.init(
redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
run_experiments(experiments, _make_scheduler(args))
3 changes: 3 additions & 0 deletions test/trial_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@


class TrainableFunctionApiTest(unittest.TestCase):
def setUp(self):
ray.init()

def tearDown(self):
ray.worker.cleanup()
_register_all() # re-register the evicted objects
Expand Down