[tune] Support user-defined trainable functions / classes / envs with a shared object registry (ray-project#1226)
ericl authored and richardliaw committed Nov 21, 2017
1 parent 9233e49 commit 316f9e2
Showing 38 changed files with 739 additions and 299 deletions.
9 changes: 9 additions & 0 deletions doc/source/conf.py
@@ -19,10 +19,19 @@
# These lines added to enable Sphinx to work without installing Ray.
import mock
MOCK_MODULES = ["gym",
"gym.spaces",
"scipy",
"scipy.signal",
"tensorflow",
"tensorflow.contrib",
"tensorflow.contrib.layers",
"tensorflow.contrib.slim",
"tensorflow.contrib.rnn",
"tensorflow.core",
"tensorflow.core.util",
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
"pyarrow",
"pyarrow.plasma",
"smart_open",
2 changes: 1 addition & 1 deletion doc/source/example-a3c.rst
@@ -25,7 +25,7 @@ You can run the code with

.. code-block:: bash
python/ray/rllib/train.py --env=Pong-ram-v4 --alg=A3C --config='{"num_workers": N}'
python/ray/rllib/train.py --env=Pong-ram-v4 --run=A3C --config='{"num_workers": N}'
Reinforcement Learning
----------------------
4 changes: 2 additions & 2 deletions doc/source/example-evolution-strategies.rst
@@ -18,15 +18,15 @@ on the ``Humanoid-v1`` gym environment.

.. code-block:: bash
python/ray/rllib/train.py --env=Humanoid-v1 --alg=ES
python/ray/rllib/train.py --env=Humanoid-v1 --run=ES
To train a policy on a cluster (e.g., using 900 workers), run the following.

.. code-block:: bash
python ray/python/ray/rllib/train.py \
--env=Humanoid-v1 \
--alg=ES \
--run=ES \
--redis-address=<redis-address> \
--config='{"num_workers": 900, "episodes_per_batch": 10000, "timesteps_per_batch": 100000}'
2 changes: 1 addition & 1 deletion doc/source/example-policy-gradient.rst
@@ -16,7 +16,7 @@ Then you can run the example as follows.

.. code-block:: bash
python/ray/rllib/train.py --env=Pong-ram-v4 --alg=PPO
python/ray/rllib/train.py --env=Pong-ram-v4 --run=PPO
This will train an agent on the ``Pong-ram-v4`` Atari environment. You can also
try passing in the ``Pong-v0`` environment or the ``CartPole-v0`` environment.
4 changes: 2 additions & 2 deletions doc/source/rllib.rst
@@ -30,7 +30,7 @@ You can run training with

::

python ray/python/ray/rllib/train.py --env CartPole-v0 --alg PPO --config '{"timesteps_per_batch": 10000}'
python ray/python/ray/rllib/train.py --env CartPole-v0 --run PPO --config '{"timesteps_per_batch": 10000}'

By default, the results will be logged to a subdirectory of ``/tmp/ray``.
This subdirectory will contain a file ``config.json`` which contains the
@@ -51,7 +51,7 @@ The ``train.py`` script has a number of options you can show by running

The most important options are for choosing the environment
with ``--env`` (any OpenAI gym environment including ones registered by the user
can be used) and for choosing the algorithm with ``--alg``
can be used) and for choosing the algorithm with ``--run``
(available options are ``PPO``, ``A3C``, ``ES`` and ``DQN``). Each algorithm
has specific hyperparameters that can be set with ``--config``, see the
``DEFAULT_CONFIG`` variable in
2 changes: 1 addition & 1 deletion python/ray/rllib/README.rst
@@ -8,7 +8,7 @@ You can run training with

::

python train.py --env CartPole-v0 --alg PPO
python train.py --env CartPole-v0 --run PPO

The available algorithms are:

19 changes: 19 additions & 0 deletions python/ray/rllib/__init__.py
@@ -0,0 +1,19 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.tune.registry import register_trainable
from ray.rllib import ppo, es, dqn, a3c
from ray.rllib.agent import _MockAgent, _SigmoidFakeData


def _register_all():
register_trainable("PPO", ppo.PPOAgent)
register_trainable("ES", es.ESAgent)
register_trainable("DQN", dqn.DQNAgent)
register_trainable("A3C", a3c.A3CAgent)
register_trainable("__fake", _MockAgent)
register_trainable("__sigmoid_fake_data", _SigmoidFakeData)


_register_all()
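
For reference, the shared object registry used above is also the hook for user-defined trainables. A minimal sketch of registering a custom entry: the class registration mirrors _register_all(), while the function-trainable signature (config, reporter) is an assumption about the tune API and is not shown in this diff.

from ray.tune.registry import register_trainable
from ray.rllib import ppo

# Register a built-in agent class under a custom name, exactly as
# _register_all() does above.
register_trainable("my_ppo", ppo.PPOAgent)

# Hypothetical user-defined trainable function; the (config, reporter)
# signature is assumed, not confirmed by this diff.
def my_trainable(config, reporter):
    for step in range(10):
        reporter(timesteps_total=step,
                 episode_reward_mean=step * config.get("score_scale", 1.0))

register_trainable("my_function_trainable", my_trainable)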
48 changes: 28 additions & 20 deletions python/ray/rllib/agent.py
@@ -17,13 +17,15 @@

import tensorflow as tf
from ray.tune.logger import UnifiedLogger
from ray.tune.registry import ENV_CREATOR
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Agent(object):
class Agent(Trainable):
"""All RLlib agents extend this base class.
Agent objects retain internal model state between calls to train(), so
@@ -33,39 +35,40 @@ class Agent(object):
env_creator (func): Function that creates a new training env.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
registry (obj): Object registry.
"""

_allow_unknown_configs = False
_default_logdir = "/tmp/ray"

def __init__(
self, env_creator, config, logger_creator=None):
self, config={}, env=None, registry=None, logger_creator=None):
"""Initialize an RLLib agent.
Args:
env_creator (str|func): Name of the OpenAI gym environment to train
against, or a function that creates such an env.
config (dict): Algorithm-specific configuration data.
env (str): Name of the environment to use. Note that this can also
be specified as the `env` key in config.
registry (obj): Object registry for user-defined envs, models, etc.
If unspecified, it will be assumed empty.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
self._initialize_ok = False
self._experiment_id = uuid.uuid4().hex
if type(env_creator) is str:
import gym
env_name = env_creator
self.env_creator = lambda: gym.make(env_name)
env = env or config.get("env")
if env:
config["env"] = env
if registry and registry.contains(ENV_CREATOR, env):
self.env_creator = registry.get(ENV_CREATOR, env)
else:
if hasattr(env_creator, "env_name"):
env_name = env_creator.env_name
else:
env_name = "custom"
self.env_creator = env_creator

import gym
self.env_creator = lambda: gym.make(env)
self.config = self._default_config.copy()
self.registry = registry
if not self._allow_unknown_configs:
for k in config.keys():
if k not in self.config:
if k not in self.config and k != "env":
raise Exception(
"Unknown agent config `{}`, "
"all agent configs: {}".format(k, self.config.keys()))
@@ -76,8 +79,7 @@ def __init__(
self.logdir = self._result_logger.logdir
else:
logdir_suffix = "{}_{}_{}".format(
env_name,
self._agent_name,
env, self._agent_name,
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
if not os.path.exists(self._default_logdir):
os.makedirs(self._default_logdir)
@@ -214,7 +216,14 @@ def restore_from_object(self, obj):
def stop(self):
"""Releases all resources used by this agent."""

self._result_logger.close()
if self._initialize_ok:
self._result_logger.close()
self._stop()

def _stop(self):
"""Subclasses should override this for custom stopping."""

pass

def compute_action(self, observation):
"""Computes an action using the current trained policy."""
@@ -336,5 +345,4 @@ def get_agent_class(alg):
return _SigmoidFakeData
else:
raise Exception(
("Unknown algorithm {}, check --alg argument. Valid choices " +
"are PPO, ES, DQN, and A3C.").format(alg))
("Unknown algorithm {}.").format(alg))
2 changes: 1 addition & 1 deletion python/ray/rllib/dqn/dqn.py
@@ -355,7 +355,7 @@ class DQNAgent(Agent):
_agent_name = "DQN"
_default_config = DEFAULT_CONFIG

def stop(self):
def _stop(self):
for w in self.workers:
w.stop.remote()

4 changes: 2 additions & 2 deletions python/ray/rllib/test/test_checkpoint_restore.py
@@ -29,8 +29,8 @@ def get_mean_action(alg, obs):

def test(use_object_store, alg_name):
cls = get_agent_class(alg_name)
alg1 = cls("CartPole-v0", CONFIGS[name])
alg2 = cls("CartPole-v0", CONFIGS[name])
alg1 = cls(config=CONFIGS[name], env="CartPole-v0")
alg2 = cls(config=CONFIGS[name], env="CartPole-v0")

for _ in range(3):
res = alg1.train()
45 changes: 26 additions & 19 deletions python/ray/rllib/train.py
@@ -9,12 +9,12 @@
import yaml

from ray.tune.config_parser import make_parser, resources_to_json
from ray.tune.tune import make_scheduler, run_experiments
from ray.tune.tune import _make_scheduler, run_experiments


EXAMPLE_USAGE = """
Training example:
./train.py --alg DQN --env CartPole-v0
./train.py --run DQN --env CartPole-v0
Grid search example:
./train.py -f tuned_examples/cartpole-grid-search-example.yaml
@@ -29,16 +29,24 @@
epilog=EXAMPLE_USAGE)

# See also the base parser definition in ray/tune/config_parser.py
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--num-cpus", default=None, type=int,
help="Number of CPUs to allocate to Ray.")
parser.add_argument("--num-gpus", default=None, type=int,
help="Number of GPUs to allocate to Ray.")
parser.add_argument("--experiment-name", default="default", type=str,
help="Name of experiment dir.")
parser.add_argument("-f", "--config-file", default=None, type=str,
help="If specified, use config options from this file.")
parser.add_argument(
"--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument(
"--num-cpus", default=None, type=int,
help="Number of CPUs to allocate to Ray.")
parser.add_argument(
"--num-gpus", default=None, type=int,
help="Number of GPUs to allocate to Ray.")
parser.add_argument(
"--experiment-name", default="default", type=str,
help="Name of the subdirectory under `local_dir` to put results in.")
parser.add_argument(
"--env", default=None, type=str, help="The gym environment to use.")
parser.add_argument(
"-f", "--config-file", default=None, type=str,
help="If specified, use config options from this file. Note that this "
"overrides any trial-specific options set via flags above.")


if __name__ == "__main__":
@@ -50,26 +58,25 @@
# Note: keep this in sync with tune/config_parser.py
experiments = {
args.experiment_name: { # i.e. log to /tmp/ray/default
"alg": args.alg,
"run": args.run,
"checkpoint_freq": args.checkpoint_freq,
"local_dir": args.local_dir,
"env": args.env,
"resources": resources_to_json(args.resources),
"stop": args.stop,
"config": args.config,
"config": dict(args.config, env=args.env),
"restore": args.restore,
"repeat": args.repeat,
"upload_dir": args.upload_dir,
}
}

for exp in experiments.values():
if not exp.get("alg"):
parser.error("the following arguments are required: --alg")
if not exp.get("env"):
if not exp.get("run"):
parser.error("the following arguments are required: --run")
if not exp.get("env") and not exp.get("config", {}).get("env"):
parser.error("the following arguments are required: --env")

run_experiments(
experiments, scheduler=make_scheduler(args),
experiments, scheduler=_make_scheduler(args),
redis_address=args.redis_address,
num_cpus=args.num_cpus, num_gpus=args.num_gpus)
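
The same experiment spec can be submitted programmatically. The sketch below mirrors the dict built in train.py above; no explicit ray.init() is shown because run_experiments receives the Ray connection options (redis_address, num_cpus, num_gpus) and is assumed to initialize Ray itself, and the stopping criterion value is illustrative.

from ray.tune.tune import run_experiments

run_experiments({
    "default": {                          # experiment name, i.e. log to /tmp/ray/default
        "run": "PPO",                     # formerly the "alg" key
        "env": "CartPole-v0",
        "stop": {"time_total_s": 180},
        # train.py above also copies the env name into the config dict.
        "config": {"env": "CartPole-v0"},
    },
})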
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/cartpole-ppo.yaml
@@ -1,6 +1,6 @@
cartpole-ppo:
env: CartPole-v0
alg: PPO
run: PPO
stop:
episode_reward_mean: 200
time_total_s: 180
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/hopper-ppo.yaml
@@ -1,6 +1,6 @@
hopper-ppo:
env: Hopper-v1
alg: PPO
run: PPO
resources:
cpu: 64
gpu: 4
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/humanoid-es.yaml
@@ -1,6 +1,6 @@
humanoid-es:
env: Humanoid-v1
alg: ES
run: ES
resources:
cpu: 100
driver_cpu_limit: 4
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml
@@ -1,6 +1,6 @@
humanoid-ppo-gae:
env: Humanoid-v1
alg: PPO
run: PPO
stop:
episode_reward_mean: 6000
resources:
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/humanoid-ppo.yaml
@@ -1,6 +1,6 @@
humanoid-ppo:
env: Humanoid-v1
alg: PPO
run: PPO
stop:
episode_reward_mean: 6000
resources:
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/hyperband-cartpole.yaml
@@ -1,6 +1,6 @@
cartpole-ppo:
env: CartPole-v0
alg: PPO
run: PPO
repeat: 3
stop:
episode_reward_mean: 200
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/pong-a3c.yaml
@@ -1,6 +1,6 @@
pong-a3c:
env: PongDeterministic-v4
alg: A3C
run: A3C
resources:
cpu: 16
driver_cpu_limit: 1
4 changes: 2 additions & 2 deletions python/ray/rllib/tuned_examples/pong-dqn.yaml
@@ -1,6 +1,6 @@
pong-deterministic-dqn:
env: PongDeterministic-v4
alg: DQN
run: DQN
resources:
cpu: 1
gpu: 1
@@ -28,7 +28,7 @@ pong-deterministic-dqn:
]
pong-noframeskip-dqn:
env: PongNoFrameskip-v4
alg: DQN
run: DQN
resources:
cpu: 1
gpu: 1
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/walker2d-ppo.yaml
@@ -1,6 +1,6 @@
walker2d-v1-ppo:
env: Walker2d-v1
alg: PPO
run: PPO
resources:
cpu: 64
gpu: 4