[rllib] Add the ability to run arbitrary Python scripts with ray.tune #1132

Merged Oct 18, 2017 · 39 commits

Commits
07e9263
fix yaml bug
ericl Oct 13, 2017
ed9d2d1
add ext agent
ericl Oct 13, 2017
4f4d6de
gpus
ericl Oct 14, 2017
2aaf661
update
ericl Oct 14, 2017
bc86b46
tuning
ericl Oct 16, 2017
96775f1
docs
ericl Oct 16, 2017
56ad4ea
Sun Oct 15 21:09:25 PDT 2017
ericl Oct 16, 2017
55b2386
lint
ericl Oct 16, 2017
a9e8345
update
ericl Oct 16, 2017
140b612
Sun Oct 15 22:39:55 PDT 2017
ericl Oct 16, 2017
924741d
Sun Oct 15 22:40:17 PDT 2017
ericl Oct 16, 2017
9524e03
Sun Oct 15 22:43:06 PDT 2017
ericl Oct 16, 2017
d74fafc
Sun Oct 15 22:46:06 PDT 2017
ericl Oct 16, 2017
c0a98b1
Sun Oct 15 22:46:21 PDT 2017
ericl Oct 16, 2017
2f5a93f
Sun Oct 15 22:48:11 PDT 2017
ericl Oct 16, 2017
de3f5f9
Sun Oct 15 22:48:44 PDT 2017
ericl Oct 16, 2017
9cea84e
Sun Oct 15 22:49:23 PDT 2017
ericl Oct 16, 2017
a5eb3bb
Sun Oct 15 22:50:21 PDT 2017
ericl Oct 16, 2017
be304f8
Sun Oct 15 22:53:00 PDT 2017
ericl Oct 16, 2017
da8fff1
Sun Oct 15 22:53:34 PDT 2017
ericl Oct 16, 2017
f3b75b7
Sun Oct 15 22:54:33 PDT 2017
ericl Oct 16, 2017
f7ba39f
Sun Oct 15 22:54:50 PDT 2017
ericl Oct 16, 2017
b3ae696
Sun Oct 15 22:55:20 PDT 2017
ericl Oct 16, 2017
2eedda8
Sun Oct 15 22:56:56 PDT 2017
ericl Oct 16, 2017
7477228
Sun Oct 15 22:59:03 PDT 2017
ericl Oct 16, 2017
2c6ccd0
fix
ericl Oct 16, 2017
1b13de1
Update tune_mnist_ray.py
ericl Oct 16, 2017
cbb168e
remove script trial
ericl Oct 16, 2017
196aa6c
fix
ericl Oct 16, 2017
9265b83
reorder
ericl Oct 16, 2017
ef8f5c0
fix ex
ericl Oct 16, 2017
88c3a96
py2 support
ericl Oct 16, 2017
d8babc2
upd
ericl Oct 16, 2017
1060ca9
comments
ericl Oct 16, 2017
2b64a06
comments
ericl Oct 16, 2017
0d4d6c6
cleanup readme
ericl Oct 16, 2017
9cea4cf
fix trial
ericl Oct 17, 2017
0323c24
annotate
ericl Oct 17, 2017
d5ba6da
Update rllib.rst
ericl Oct 17, 2017
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@
/python/ray/pyarrow_files/pyarrow/
/python/build
/python/dist
/python/flatbuffers-1.7.1/
/src/common/thirdparty/redis
/src/thirdparty/arrow
/flatbuffers-1.7.1/
3 changes: 2 additions & 1 deletion python/ray/rllib/a3c/a3c.py
@@ -9,11 +9,12 @@
import os

import ray
from ray.rllib.agent import Agent
from ray.rllib.a3c.runner import RunnerThread, process_rollout
from ray.rllib.a3c.envs import create_and_wrap
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.a3c.shared_model import SharedModel
from ray.rllib.a3c.shared_model_lstm import SharedModelLSTM
from ray.tune.result import TrainingResult


DEFAULT_CONFIG = {
134 changes: 73 additions & 61 deletions python/ray/rllib/common.py → python/ray/rllib/agent.py
@@ -2,7 +2,6 @@
from __future__ import division
from __future__ import print_function

from collections import namedtuple
from datetime import datetime

import json
@@ -16,6 +15,7 @@
import uuid

import tensorflow as tf
from ray.tune.result import TrainingResult

if sys.version_info[0] == 2:
import cStringIO as StringIO
@@ -26,39 +26,6 @@
logger.setLevel(logging.INFO)


TrainingResult = namedtuple("TrainingResult", [
# Unique string identifier for this experiment. This id is preserved
# across checkpoint / restore calls.
"experiment_id",

# The index of this training iteration, e.g. call to train().
"training_iteration",

# The mean episode reward reported during this iteration.
"episode_reward_mean",

# The mean episode length reported during this iteration.
"episode_len_mean",

# Agent-specific metadata to report for this iteration.
"info",

# Number of timesteps in the simulator in this iteration.
"timesteps_this_iter",

# Accumulated timesteps for this entire experiment.
"timesteps_total",

# Time in seconds this iteration took to run.
"time_this_iter_s",

# Accumulated time in seconds for this entire experiment.
"time_total_s",
])

TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)


class Agent(object):
"""All RLlib agents extend this base class.

@@ -71,6 +38,8 @@ class Agent(object):
logdir (str): Directory in which training outputs should be placed.
"""

_allow_unknown_configs = False

def __init__(
self, env_creator, config, local_dir='/tmp/ray',
upload_dir=None, agent_id=None):
@@ -97,11 +66,12 @@ def __init__(
self.env_creator = env_creator

self.config = self._default_config.copy()
for k in config.keys():
if k not in self.config:
raise Exception(
"Unknown agent config `{}`, "
"all agent configs: {}".format(k, self.config.keys()))
if not self._allow_unknown_configs:
for k in config.keys():
if k not in self.config:
raise Exception(
"Unknown agent config `{}`, "
"all agent configs: {}".format(k, self.config.keys()))
self.config.update(config)
self.config.update({
"agent_id": agent_id,
@@ -112,7 +82,7 @@ def __init__(

logdir_suffix = "{}_{}_{}".format(
env_name,
self.__class__.__name__,
self._agent_name,
agent_id or datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))

if not os.path.exists(local_dir):
@@ -128,12 +98,12 @@ def __init__(
# TODO(ekl) consider inlining config into the result jsons
config_out = os.path.join(self.logdir, "config.json")
with open(config_out, "w") as f:
json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
json.dump(self.config, f, sort_keys=True, cls=_Encoder)
logger.info(
"%s algorithm created with logdir '%s' and upload uri '%s'",
"%s agent created with logdir '%s' and upload uri '%s'",
self.__class__.__name__, self.logdir, log_upload_uri)

self._result_logger = RLLibLogger(
self._result_logger = _Logger(
os.path.join(self.logdir, "result.json"),
log_upload_uri and os.path.join(log_upload_uri, "result.json"))
self._file_writer = tf.summary.FileWriter(self.logdir)
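The `_allow_unknown_configs` flag added above gates the unknown-key check in `__init__`. A self-contained sketch of just that rule, using a hypothetical `_default_config` rather than the real `Agent` class:

```python
# Sketch of the validation rule only; not the real Agent class, and the
# config keys here are hypothetical.
class _ConfigValidatingAgent(object):
    _allow_unknown_configs = False
    _default_config = {"lr": 1e-3, "gamma": 0.99}

    def __init__(self, config):
        self.config = self._default_config.copy()
        if not self._allow_unknown_configs:
            for k in config.keys():
                if k not in self.config:
                    raise Exception(
                        "Unknown agent config `{}`, "
                        "all agent configs: {}".format(k, self.config.keys()))
        self.config.update(config)


_ConfigValidatingAgent({"lr": 1e-4})            # accepted: overrides a known key
# _ConfigValidatingAgent({"learn_rate": 1e-4})  # raises: unknown key rejected
```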
@@ -162,6 +132,8 @@ def train(self):
self._iteration += 1
time_this_iter = time.time() - start

assert result.timesteps_this_iter is not None

self._time_total += time_this_iter
self._timesteps_total += result.timesteps_this_iter

@@ -170,10 +142,9 @@
training_iteration=self._iteration,
timesteps_total=self._timesteps_total,
time_this_iter_s=time_this_iter,
time_total_s=self._time_total)

for field in result:
assert field is not None, result
time_total_s=self._time_total,
pid=os.getpid(),
hostname=os.uname()[1])

self._log_result(result)

@@ -184,18 +155,18 @@ def _log_result(self, result):

# We need to use a custom json serializer class so that NaNs get
# encoded as null as required by Athena.
json.dump(result._asdict(), self._result_logger, cls=RLLibEncoder)
json.dump(result._asdict(), self._result_logger, cls=_Encoder)
self._result_logger.write("\n")
train_stats = tf.Summary(value=[
tf.Summary.Value(
tag="rllib/time_this_iter_s",
simple_value=result.time_this_iter_s),
tf.Summary.Value(
tag="rllib/episode_reward_mean",
simple_value=result.episode_reward_mean),
tf.Summary.Value(
tag="rllib/episode_len_mean",
simple_value=result.episode_len_mean)])
attrs_to_log = [
"time_this_iter_s", "mean_loss", "mean_accuracy",
"episode_reward_mean", "episode_len_mean"]
values = []
for attr in attrs_to_log:
if getattr(result, attr) is not None:
values.append(tf.Summary.Value(
tag="ray/tune/{}".format(attr),
simple_value=getattr(result, attr)))
train_stats = tf.Summary(value=values)
self._file_writer.add_summary(train_stats, result.training_iteration)

def save(self):
@@ -269,10 +240,10 @@ def _restore(self):
raise NotImplementedError


class RLLibEncoder(json.JSONEncoder):
class _Encoder(json.JSONEncoder):

def __init__(self, nan_str="null", **kwargs):
super(RLLibEncoder, self).__init__(**kwargs)
super(_Encoder, self).__init__(**kwargs)
self.nan_str = nan_str

def iterencode(self, o, _one_shot=False):
@@ -299,7 +270,7 @@ def default(self, value):
return int(value)
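As the comment in `_log_result` notes, NaN values must be emitted as `null` for Athena. A standalone illustration of that substitution — a simplified stand-in for `_Encoder`, not the class itself:

```python
import json

print(json.dumps({"episode_reward_mean": float("nan")}))
# -> {"episode_reward_mean": NaN}   (bare NaN is not valid JSON)


class NanNullEncoder(json.JSONEncoder):
    """Simplified sketch: rewrite NaN tokens to null while streaming chunks."""

    def iterencode(self, o, _one_shot=False):
        for chunk in super(NanNullEncoder, self).iterencode(o, _one_shot):
            yield chunk.replace("NaN", "null")


print(json.dumps({"episode_reward_mean": float("nan")}, cls=NanNullEncoder))
# -> {"episode_reward_mean": null}
```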


class RLLibLogger(object):
class _Logger(object):
"""Writing small amounts of data to S3 with real-time updates.
"""

@@ -322,3 +293,44 @@ def write(self, b):
with self.smart_open(self.uri, "w") as f:
self.result_buffer.write(b)
f.write(self.result_buffer.getvalue())


class _MockAgent(Agent):
"""Mock agent for use in tests"""

_agent_name = "MockAgent"
_default_config = {}

def _init(self):
pass

def _train(self):
return TrainingResult(
episode_reward_mean=10, episode_len_mean=10,
timesteps_this_iter=10, info={})


def get_agent_class(alg):
"""Returns the class of an known agent given its name."""

if alg == "PPO":
from ray.rllib import ppo
return ppo.PPOAgent
elif alg == "ES":
from ray.rllib import es
return es.ESAgent
elif alg == "DQN":
from ray.rllib import dqn
return dqn.DQNAgent
elif alg == "A3C":
from ray.rllib import a3c
return a3c.A3CAgent
elif alg == "script":
from ray.tune import script_runner
return script_runner.ScriptRunner
elif alg == "__fake":
return _MockAgent
else:
raise Exception(
("Unknown algorithm {}, check --alg argument. Valid choices " +
"are PPO, ES, DQN, and A3C.").format(alg))
43 changes: 0 additions & 43 deletions python/ray/rllib/agents.py

This file was deleted.

3 changes: 2 additions & 1 deletion python/ray/rllib/dqn/dqn.py
@@ -10,12 +10,13 @@
import tensorflow as tf

import ray
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.agent import Agent
from ray.rllib.dqn import logger, models
from ray.rllib.dqn.common.atari_wrappers_deprecated \
import wrap_dqn, ScaledFloatFrame
from ray.rllib.dqn.common.schedules import LinearSchedule
from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from ray.tune.result import TrainingResult


"""The default configuration dict for the DQN algorithm.
3 changes: 2 additions & 1 deletion python/ray/rllib/es/es.py
@@ -12,14 +12,15 @@
import time

import ray
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.agent import Agent
from ray.rllib.models import ModelCatalog

from ray.rllib.es import optimizers
from ray.rllib.es import policies
from ray.rllib.es import tabular_logger as tlogger
from ray.rllib.es import tf_util
from ray.rllib.es import utils
from ray.tune.result import TrainingResult


Result = namedtuple("Result", [
3 changes: 2 additions & 1 deletion python/ray/rllib/ppo/ppo.py
@@ -11,7 +11,8 @@
from tensorflow.python import debug as tf_debug

import ray
from ray.rllib.common import Agent, TrainingResult
from ray.rllib.agent import Agent
from ray.tune.result import TrainingResult
from ray.rllib.ppo.runner import Runner, RemoteRunner
from ray.rllib.ppo.rollout import collect_samples
from ray.rllib.ppo.utils import shuffle
2 changes: 1 addition & 1 deletion python/ray/rllib/test/test_checkpoint_restore.py
@@ -8,7 +8,7 @@
import ray
import random

from ray.rllib.agents import get_agent_class
from ray.rllib.agent import get_agent_class


def get_mean_action(alg, obs):
19 changes: 16 additions & 3 deletions python/ray/rllib/train.py
@@ -34,14 +34,18 @@
# defined there.
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--num-cpus", default=None, type=int,
help="Number of CPUs to allocate to Ray.")
parser.add_argument("--num-gpus", default=None, type=int,
help="Number of GPUs to allocate to Ray.")
parser.add_argument("--restore", default=None, type=str,
help="If specified, restore from this checkpoint.")
parser.add_argument("-f", "--config-file", default=None, type=str,
help="If specified, use config options from this file.")


if __name__ == "__main__":
args = parser.parse_args()
def main(argv):
args = parser.parse_args(argv)
runner = TrialRunner()

if args.config_file:
@@ -56,12 +60,21 @@
args.resources, args.stop, args.checkpoint_freq,
args.restore, args.upload_dir))

ray.init(redis_address=args.redis_address)
ray.init(
redis_address=args.redis_address, num_cpus=args.num_cpus,
num_gpus=args.num_gpus)

while not runner.is_finished():
runner.step()
print(runner.debug_string())

for trial in runner.get_trials():
if trial.status != Trial.TERMINATED:
print("Exit 1")
sys.exit(1)

print("Exit 0")


if __name__ == "__main__":
main(sys.argv[1:])
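A sketch of driving the new `main()` entry point programmatically, assuming the module imports as `ray.rllib.train` and the YAML path resolves relative to the working directory; `--num-cpus`/`--num-gpus` are the flags added in this diff, and `main()` calls `sys.exit(1)` if any trial fails to reach `TERMINATED`:

```python
from ray.rllib import train

# Equivalent to:
#   python train.py -f tuned_examples/hopper-ppo.yaml --num-cpus 64 --num-gpus 4
train.main([
    "-f", "tuned_examples/hopper-ppo.yaml",
    "--num-cpus", "64",
    "--num-gpus", "4",
])
```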
2 changes: 1 addition & 1 deletion python/ray/rllib/tuned_examples/hopper-ppo.yaml
@@ -5,4 +5,4 @@ hopper-ppo:
resources:
cpu: 64
gpu: 4
config: {"gamma": 0.995, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": 1e-4, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 160000, "num_workers": 64}
config: {"gamma": 0.995, "kl_coeff": 1.0, "num_sgd_iter": 20, "sgd_stepsize": .0001, "sgd_batchsize": 32768, "devices": ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"], "tf_session_args": {"device_count": {"GPU": 4}, "log_device_placement": false, "allow_soft_placement": true}, "timesteps_per_batch": 160000, "num_workers": 64}
Comment from the PR author:
Apparently, PyYAML has a bug where it can't parse scientific notation like `1e-4` correctly, hence the switch to `.0001` here.
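A quick repro of that PyYAML behavior (its YAML 1.1 resolver only treats an exponent form as a float when the mantissa contains a dot), which is why the stepsize is written as `.0001`:

```python
import yaml

print(type(yaml.safe_load("sgd_stepsize: 1e-4")["sgd_stepsize"]))    # <class 'str'>
print(type(yaml.safe_load("sgd_stepsize: 1.0e-4")["sgd_stepsize"]))  # <class 'float'>
print(type(yaml.safe_load("sgd_stepsize: .0001")["sgd_stepsize"]))   # <class 'float'>
```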
