[rllib] Add the ability to run arbitrary Python scripts with ray.tune (ray-project#1132)

* fix yaml bug

* add ext agent

* gpus

* update

* tuning

* docs

* Sun Oct 15 21:09:25 PDT 2017

* lint

* update

* Sun Oct 15 22:39:55 PDT 2017

* Sun Oct 15 22:40:17 PDT 2017

* Sun Oct 15 22:43:06 PDT 2017

* Sun Oct 15 22:46:06 PDT 2017

* Sun Oct 15 22:46:21 PDT 2017

* Sun Oct 15 22:48:11 PDT 2017

* Sun Oct 15 22:48:44 PDT 2017

* Sun Oct 15 22:49:23 PDT 2017

* Sun Oct 15 22:50:21 PDT 2017

* Sun Oct 15 22:53:00 PDT 2017

* Sun Oct 15 22:53:34 PDT 2017

* Sun Oct 15 22:54:33 PDT 2017

* Sun Oct 15 22:54:50 PDT 2017

* Sun Oct 15 22:55:20 PDT 2017

* Sun Oct 15 22:56:56 PDT 2017

* Sun Oct 15 22:59:03 PDT 2017

* fix

* Update tune_mnist_ray.py

* remove script trial

* fix

* reorder

* fix ex

* py2 support

* upd

* comments

* comments

* cleanup readme

* fix trial

* annotate

* Update rllib.rst
ericl authored and richardliaw committed Oct 18, 2017
1 parent 4157bcb commit 5a50e0e
Showing 24 changed files with 745 additions and 166 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@
/python/ray/pyarrow_files/pyarrow/
/python/build
/python/dist
/python/flatbuffers-1.7.1/
/src/common/thirdparty/redis
/src/thirdparty/arrow
/flatbuffers-1.7.1/
2 changes: 1 addition & 1 deletion doc/source/rllib.rst
@@ -119,7 +119,7 @@ some number of iterations of the algorithm, save and load the state
of training and evaluate the current policy. All agents inherit from
a common base class:

.. autoclass:: ray.rllib.common.Agent
.. autoclass:: ray.rllib.agent.Agent
:members:

Models
3 changes: 2 additions & 1 deletion python/ray/rllib/a3c/a3c.py
@@ -9,11 +9,12 @@
import os

import ray
from ray.rllib.agent import Agent
from ray.rllib.a3c.runner import RunnerThread, process_rollout
from ray.rllib.a3c.envs import create_and_wrap
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.a3c.shared_model import SharedModel
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
from ray.tune.result import TrainingResult


DEFAULT_CONFIG = {
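The same import change recurs in the DQN, ES, and PPO diffs below: the Agent base class now lives in ray.rllib.agent, while TrainingResult moves to ray.tune.result so that non-RLlib trainables (such as tune's new script runner) can report the same result type. A minimal sketch of an agent written against the new layout follows; it is not part of the diff, MyAgent and its config values are hypothetical, and the structure simply mirrors the _MockAgent added later in this commit.

from ray.rllib.agent import Agent            # moved here from ray.rllib.common
from ray.tune.result import TrainingResult   # moved here from ray.rllib.common


class MyAgent(Agent):
    """Hypothetical agent illustrating the new import locations."""

    _agent_name = "MyAgent"
    _default_config = {"num_workers": 1}

    def _init(self):
        pass

    def _train(self):
        # timesteps_this_iter must be set; Agent.train() now asserts on it.
        return TrainingResult(
            episode_reward_mean=0.0, episode_len_mean=1.0,
            timesteps_this_iter=1, info={})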
134 changes: 73 additions & 61 deletions python/ray/rllib/common.py → python/ray/rllib/agent.py
@@ -2,7 +2,6 @@
from __future__ import division
from __future__ import print_function

from collections import namedtuple
from datetime import datetime

import json
@@ -16,6 +15,7 @@
import uuid

import tensorflow as tf
from ray.tune.result import TrainingResult

if sys.version_info[0] == 2:
import cStringIO as StringIO
@@ -26,39 +26,6 @@
logger.setLevel(logging.INFO)


TrainingResult = namedtuple("TrainingResult", [
# Unique string identifier for this experiment. This id is preserved
# across checkpoint / restore calls.
"experiment_id",

# The index of this training iteration, e.g. call to train().
"training_iteration",

# The mean episode reward reported during this iteration.
"episode_reward_mean",

# The mean episode length reported during this iteration.
"episode_len_mean",

# Agent-specific metadata to report for this iteration.
"info",

# Number of timesteps in the simulator in this iteration.
"timesteps_this_iter",

# Accumulated timesteps for this entire experiment.
"timesteps_total",

# Time in seconds this iteration took to run.
"time_this_iter_s",

# Accumulated time in seconds for this entire experiment.
"time_total_s",
])

TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)


class Agent(object):
"""All RLlib agents extend this base class.
@@ -71,6 +38,8 @@ class Agent(object):
logdir (str): Directory in which training outputs should be placed.
"""

_allow_unknown_configs = False

def __init__(
self, env_creator, config, local_dir='/tmp/ray',
upload_dir=None, agent_id=None):
@@ -97,11 +66,12 @@ def __init__(
self.env_creator = env_creator

self.config = self._default_config.copy()
for k in config.keys():
if k not in self.config:
raise Exception(
"Unknown agent config `{}`, "
"all agent configs: {}".format(k, self.config.keys()))
if not self._allow_unknown_configs:
for k in config.keys():
if k not in self.config:
raise Exception(
"Unknown agent config `{}`, "
"all agent configs: {}".format(k, self.config.keys()))
self.config.update(config)
self.config.update({
"agent_id": agent_id,
@@ -112,7 +82,7 @@

logdir_suffix = "{}_{}_{}".format(
env_name,
self.__class__.__name__,
self._agent_name,
agent_id or datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))

if not os.path.exists(local_dir):
@@ -128,12 +98,12 @@ def __init__(
# TODO(ekl) consider inlining config into the result jsons
config_out = os.path.join(self.logdir, "config.json")
with open(config_out, "w") as f:
json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
json.dump(self.config, f, sort_keys=True, cls=_Encoder)
logger.info(
"%s algorithm created with logdir '%s' and upload uri '%s'",
"%s agent created with logdir '%s' and upload uri '%s'",
self.__class__.__name__, self.logdir, log_upload_uri)

self._result_logger = RLLibLogger(
self._result_logger = _Logger(
os.path.join(self.logdir, "result.json"),
log_upload_uri and os.path.join(log_upload_uri, "result.json"))
self._file_writer = tf.summary.FileWriter(self.logdir)
@@ -162,6 +132,8 @@ def train(self):
self._iteration += 1
time_this_iter = time.time() - start

assert result.timesteps_this_iter is not None

self._time_total += time_this_iter
self._timesteps_total += result.timesteps_this_iter

@@ -170,10 +142,9 @@
training_iteration=self._iteration,
timesteps_total=self._timesteps_total,
time_this_iter_s=time_this_iter,
time_total_s=self._time_total)

for field in result:
assert field is not None, result
time_total_s=self._time_total,
pid=os.getpid(),
hostname=os.uname()[1])

self._log_result(result)
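The result now also records the worker's process id and hostname, and the blanket "every field must be non-None" check is replaced by the single timesteps_this_iter assertion in the hunk above, presumably because tune scripts and RL agents report different subsets of metrics. A tiny sketch of how the new bookkeeping values are obtained, using the same calls as the diff (the surrounding agent object is omitted, so this is illustration only):

import os
import time

start = time.time()
# ... one call to self._train() would happen here ...
time_this_iter = time.time() - start

bookkeeping = {
    "time_this_iter_s": time_this_iter,
    "pid": os.getpid(),          # new field in this commit
    "hostname": os.uname()[1],   # new field in this commit (POSIX only)
}
print(bookkeeping)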

@@ -184,18 +155,18 @@ def _log_result(self, result):

# We need to use a custom json serializer class so that NaNs get
# encoded as null as required by Athena.
json.dump(result._asdict(), self._result_logger, cls=RLLibEncoder)
json.dump(result._asdict(), self._result_logger, cls=_Encoder)
self._result_logger.write("\n")
train_stats = tf.Summary(value=[
tf.Summary.Value(
tag="rllib/time_this_iter_s",
simple_value=result.time_this_iter_s),
tf.Summary.Value(
tag="rllib/episode_reward_mean",
simple_value=result.episode_reward_mean),
tf.Summary.Value(
tag="rllib/episode_len_mean",
simple_value=result.episode_len_mean)])
attrs_to_log = [
"time_this_iter_s", "mean_loss", "mean_accuracy",
"episode_reward_mean", "episode_len_mean"]
values = []
for attr in attrs_to_log:
if getattr(result, attr) is not None:
values.append(tf.Summary.Value(
tag="ray/tune/{}".format(attr),
simple_value=getattr(result, attr)))
train_stats = tf.Summary(value=values)
self._file_writer.add_summary(train_stats, result.training_iteration)

def save(self):
@@ -269,10 +240,10 @@ def _restore(self):
raise NotImplementedError


class RLLibEncoder(json.JSONEncoder):
class _Encoder(json.JSONEncoder):

def __init__(self, nan_str="null", **kwargs):
super(RLLibEncoder, self).__init__(**kwargs)
super(_Encoder, self).__init__(**kwargs)
self.nan_str = nan_str

def iterencode(self, o, _one_shot=False):
@@ -299,7 +270,7 @@ def default(self, value):
return int(value)


class RLLibLogger(object):
class _Logger(object):
"""Writing small amounts of data to S3 with real-time updates.
"""

@@ -322,3 +293,44 @@ def write(self, b):
with self.smart_open(self.uri, "w") as f:
self.result_buffer.write(b)
f.write(self.result_buffer.getvalue())


class _MockAgent(Agent):
"""Mock agent for use in tests"""

_agent_name = "MockAgent"
_default_config = {}

def _init(self):
pass

def _train(self):
return TrainingResult(
episode_reward_mean=10, episode_len_mean=10,
timesteps_this_iter=10, info={})


def get_agent_class(alg):
"""Returns the class of an known agent given its name."""

if alg == "PPO":
from ray.rllib import ppo
return ppo.PPOAgent
elif alg == "ES":
from ray.rllib import es
return es.ESAgent
elif alg == "DQN":
from ray.rllib import dqn
return dqn.DQNAgent
elif alg == "A3C":
from ray.rllib import a3c
return a3c.A3CAgent
elif alg == "script":
from ray.tune import script_runner
return script_runner.ScriptRunner
elif alg == "__fake":
return _MockAgent
else:
raise Exception(
("Unknown algorithm {}, check --alg argument. Valid choices " +
"are PPO, ES, DQN, and A3C.").format(alg))
43 changes: 0 additions & 43 deletions python/ray/rllib/agents.py

This file was deleted.

3 changes: 2 additions & 1 deletion python/ray/rllib/dqn/dqn.py
@@ -10,11 +10,12 @@
import tensorflow as tf

import ray
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.agent import Agent
from ray.rllib.dqn import logger, models
from ray.rllib.dqn.common.wrappers import wrap_dqn
from ray.rllib.dqn.common.schedules import LinearSchedule
from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from ray.tune.result import TrainingResult


"""The default configuration dict for the DQN algorithm.
3 changes: 2 additions & 1 deletion python/ray/rllib/es/es.py
@@ -12,14 +12,15 @@
import time

import ray
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.agent import Agent
from ray.rllib.models import ModelCatalog

from ray.rllib.es import optimizers
from ray.rllib.es import policies
from ray.rllib.es import tabular_logger as tlogger
from ray.rllib.es import tf_util
from ray.rllib.es import utils
from ray.tune.result import TrainingResult


Result = namedtuple("Result", [
3 changes: 2 additions & 1 deletion python/ray/rllib/ppo/ppo.py
@@ -11,7 +11,8 @@
from tensorflow.python import debug as tf_debug

import ray
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.agent import Agent
from ray.tune.result import TrainingResult
from ray.rllib.ppo.runner import Runner, RemoteRunner
from ray.rllib.ppo.rollout import collect_samples
from ray.rllib.ppo.utils import shuffle
2 changes: 1 addition & 1 deletion python/ray/rllib/test/test_checkpoint_restore.py
@@ -8,7 +8,7 @@
import ray
import random

from ray.rllib.agents import get_agent_class
from ray.rllib.agent import get_agent_class


def get_mean_action(alg, obs):
19 changes: 16 additions & 3 deletions python/ray/rllib/train.py
@@ -34,14 +34,18 @@
# defined there.
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--num-cpus", default=None, type=int,
help="Number of CPUs to allocate to Ray.")
parser.add_argument("--num-gpus", default=None, type=int,
help="Number of GPUs to allocate to Ray.")
parser.add_argument("--restore", default=None, type=str,
help="If specified, restore from this checkpoint.")
parser.add_argument("-f", "--config-file", default=None, type=str,
help="If specified, use config options from this file.")


if __name__ == "__main__":
args = parser.parse_args()
def main(argv):
args = parser.parse_args(argv)
runner = TrialRunner()

if args.config_file:
@@ -56,12 +60,21 @@
args.resources, args.stop, args.checkpoint_freq,
args.restore, args.upload_dir))

ray.init(redis_address=args.redis_address)
ray.init(
redis_address=args.redis_address, num_cpus=args.num_cpus,
num_gpus=args.num_gpus)

while not runner.is_finished():
runner.step()
print(runner.debug_string())

for trial in runner.get_trials():
if trial.status != Trial.TERMINATED:
print("Exit 1")
sys.exit(1)

print("Exit 0")


if __name__ == "__main__":
main(sys.argv[1:])
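train.py is now importable: argument parsing moves into main(argv), Ray can be sized with --num-cpus and --num-gpus, and the process exits non-zero if any trial fails to terminate. A hedged invocation sketch, not part of the diff; the YAML filename is a placeholder, and the remaining experiment options (resources, stop conditions, checkpoint frequency, upload dir) come from flags defined earlier in train.py or from the config file, as referenced in the hunk above:

# Programmatic invocation, equivalent to running train.py from the shell.
# "tune_experiments.yaml" is a placeholder filename, not part of the commit.
from ray.rllib.train import main

main(["-f", "tune_experiments.yaml",
      "--num-cpus", "8", "--num-gpus", "1"])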
(Diffs for the remaining changed files are not rendered here.)
