[rllib] Make sure to always record stats like time elapsed, timesteps (ray-project#965)

* always record training stats

* fix

* comments

* revert assert

* nan

* fix
ericl authored and pcmoritz committed Sep 12, 2017
1 parent 74ac806 commit 9f42ef6
Showing 6 changed files with 137 additions and 67 deletions.
32 changes: 17 additions & 15 deletions python/ray/rllib/a3c/a3c.py
@@ -100,9 +100,8 @@ def __init__(self, env_name, config,
config["batch_size"], self.logdir)
for i in range(config["num_workers"])]
self.parameters = self.policy.get_weights()
self.iteration = 0

def train(self):
def _train(self):
gradient_list = [
agent.compute_gradient.remote(self.parameters)
for agent in self.agents]
@@ -119,7 +118,6 @@ def train(self):
[self.agents[info["id"]].compute_gradient.remote(
self.parameters)])
res = self._fetch_metrics_from_workers()
self.iteration += 1
return res

def _fetch_metrics_from_workers(self):
@@ -131,27 +129,31 @@ def _fetch_metrics_from_workers(self):
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
avg_reward = np.mean(episode_rewards) if episode_rewards else None
avg_length = np.mean(episode_lengths) if episode_lengths else None
res = TrainingResult(
self.experiment_id.hex, self.iteration,
avg_reward, avg_length, dict())
return res
avg_reward = (
np.mean(episode_rewards) if episode_rewards else float('nan'))
avg_length = (
np.mean(episode_lengths) if episode_lengths else float('nan'))
timesteps = np.sum(episode_lengths) if episode_lengths else 0

result = TrainingResult(
episode_reward_mean=avg_reward,
episode_len_mean=avg_length,
timesteps_this_iter=timesteps,
info={})

return result

def save(self):
def _save(self):
checkpoint_path = os.path.join(
self.logdir, "checkpoint-{}".format(self.iteration))
objects = [
self.parameters,
self.iteration]
objects = [self.parameters]
pickle.dump(objects, open(checkpoint_path, "wb"))
return checkpoint_path

def restore(self, checkpoint_path):
def _restore(self, checkpoint_path):
objects = pickle.load(open(checkpoint_path, "rb"))
self.parameters = objects[0]
self.policy.set_weights(self.parameters)
self.iteration = objects[1]

def compute_action(self, observation):
actions = self.policy.compute_actions(observation)[0]
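One of the commit-message bullets ("nan") shows up in the hunk above: when no episodes complete during an iteration, the mean reward and length are now reported as NaN rather than None, and the timestep count as 0, so a TrainingResult can always be emitted. A minimal standalone sketch of that guard (the helper name is illustrative, not part of the commit):

    import numpy as np

    def summarize_episodes(episode_rewards, episode_lengths):
        # Report NaN (not None) when no episodes finished this iteration, so
        # the metric is always float-typed; report 0 timesteps in that case.
        avg_reward = np.mean(episode_rewards) if episode_rewards else float("nan")
        avg_length = np.mean(episode_lengths) if episode_lengths else float("nan")
        timesteps = int(np.sum(episode_lengths)) if episode_lengths else 0
        return avg_reward, avg_length, timesteps

    print(summarize_episodes([], []))                # nan, nan, 0
    print(summarize_episodes([1.0, 3.0], [10, 20]))  # 2.0, 15.0, 30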
90 changes: 82 additions & 8 deletions python/ray/rllib/common.py
@@ -4,8 +4,10 @@
import logging
import numpy as np
import os
import pickle
import sys
import tempfile
import time
import uuid
import smart_open

@@ -46,13 +48,37 @@ def write(self, b):


TrainingResult = namedtuple("TrainingResult", [
# Unique string identifier for this experiment. This id is preserved
# across checkpoint / restore calls.
"experiment_id",

# The index of this training iteration, e.g. call to train().
"training_iteration",

# The mean episode reward reported during this iteration.
"episode_reward_mean",

# The mean episode length reported during this iteration.
"episode_len_mean",
"info"

# Agent-specific metadata to report for this iteration.
"info",

# Number of timesteps in the simulator in this iteration.
"timesteps_this_iter",

# Accumulated timesteps for this entire experiment.
"timesteps_total",

# Time in seconds this iteration took to run.
"time_this_iter_s",

# Accumulated time in seconds for this entire experiment.
"time_total_s",
])

TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)


class Agent(object):
"""All RLlib agents extend this base class.
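The None defaults added above let a subclass construct a TrainingResult with only the fields it measured this iteration; the base class fills in the rest later via _replace (see the train() hunk further down). A self-contained sketch of that two-step pattern, using a trimmed-down namedtuple rather than the real one:

    from collections import namedtuple

    # Trimmed-down stand-in for TrainingResult (illustrative field subset).
    Result = namedtuple("Result", [
        "episode_reward_mean", "timesteps_this_iter", "timesteps_total"])
    Result.__new__.__defaults__ = (None,) * len(Result._fields)

    # A subclass reports only what it measured during this iteration...
    partial = Result(episode_reward_mean=10.5, timesteps_this_iter=200)
    assert partial.timesteps_total is None  # left at the default

    # ...and the framework completes the record from its accumulated counters.
    complete = partial._replace(timesteps_total=1200)
    print(complete)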
@@ -64,8 +90,6 @@ class Agent(object):
env_name (str): Name of the OpenAI gym environment to train against.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
TODO(ekl): support checkpoint / restore of training state.
"""

def __init__(self, env_name, config, upload_dir=None):
@@ -79,10 +103,11 @@ def __init__(self, env_name, config, upload_dir=None):
like s3://bucketname/.
"""
upload_dir = "file:///tmp/ray" if upload_dir is None else upload_dir
self.experiment_id = uuid.uuid4()
self.experiment_id = uuid.uuid4().hex
self.env_name = env_name

self.config = config
self.config.update({"experiment_id": self.experiment_id.hex})
self.config.update({"experiment_id": self.experiment_id})
self.config.update({"env_name": env_name})
prefix = "{}_{}_{}".format(
env_name,
@@ -92,21 +117,45 @@ def __init__(self, env_name, config, upload_dir=None):
self.logdir = tempfile.mkdtemp(prefix=prefix, dir="/tmp/ray")
else:
self.logdir = os.path.join(upload_dir, prefix)

# TODO(ekl) consider inlining config into the result jsons
log_path = os.path.join(self.logdir, "config.json")
with smart_open.smart_open(log_path, "w") as f:
json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
logger.info(
"%s algorithm created with logdir '%s'",
self.__class__.__name__, self.logdir)

self.iteration = 0
self.time_total = 0.0
self.timesteps_total = 0

def train(self):
"""Runs one logical iteration of training.
Returns:
A TrainingResult that describes training progress.
"""

raise NotImplementedError
start = time.time()
self.iteration += 1
result = self._train()
time_this_iter = time.time() - start

self.time_total += time_this_iter
self.timesteps_total += result.timesteps_this_iter

result = result._replace(
experiment_id=self.experiment_id,
training_iteration=self.iteration,
timesteps_total=self.timesteps_total,
time_this_iter_s=time_this_iter,
time_total_s=self.time_total)

for field in result:
assert field is not None, result

return result

def save(self):
"""Saves the current model state to a checkpoint.
@@ -115,17 +164,42 @@ def save(self):
Checkpoint path that may be passed to restore().
"""

raise NotImplementedError
checkpoint_path = self._save()
pickle.dump(
[self.experiment_id, self.iteration, self.timesteps_total,
self.time_total_s],
open(checkpoint_path + ".rllib_metadata", "wb"))
return checkpoint_path

def restore(self, checkpoint_path):
"""Restores training state from a given model checkpoint.
These checkpoints are returned from calls to save().
"""

raise NotImplementedError
self._restore(checkpoint_path)
metadata = pickle.load(open(checkpoint_path + ".rllib_metadata", "rb"))
self.experiment_id = metadata[0]
self.iteration = metadata[1]
self.timesteps_total = metadata[2]
self.time_total_s = metadata[3]

def compute_action(self, observation):
"""Computes an action using the current trained policy."""

raise NotImplementedError

def _train(self):
"""Subclasses should override this to implement train()."""

raise NotImplementedError

def _save(self):
"""Subclasses should override this to implement save()."""

raise NotImplementedError

def _restore(self):
"""Subclasses should override this to implement restore()."""

raise NotImplementedError
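Taken together, common.py now follows a template-method pattern: the public train() wraps a subclass-provided _train(), adding the iteration index, wall-clock timing, and cumulative timestep counts so these stats are always recorded regardless of the algorithm. A self-contained sketch of the same pattern, independent of Ray, with simplified names and a plain dict in place of TrainingResult:

    import time


    class BaseTrainer(object):
        """Toy stand-in for rllib's Agent base class (illustrative only)."""

        def __init__(self):
            self.iteration = 0
            self.time_total_s = 0.0
            self.timesteps_total = 0

        def train(self):
            start = time.time()
            self.iteration += 1
            reward, timesteps = self._train()  # subclass hook
            elapsed = time.time() - start
            self.time_total_s += elapsed
            self.timesteps_total += timesteps
            # Elapsed time and timestep counts are recorded here, so every
            # subclass reports them without extra work.
            return {
                "training_iteration": self.iteration,
                "episode_reward_mean": reward,
                "timesteps_this_iter": timesteps,
                "timesteps_total": self.timesteps_total,
                "time_this_iter_s": elapsed,
                "time_total_s": self.time_total_s,
            }

        def _train(self):
            raise NotImplementedError


    class ConstantTrainer(BaseTrainer):
        def _train(self):
            # Pretend we ran 100 environment steps with mean reward 1.0.
            return 1.0, 100


    trainer = ConstantTrainer()
    for _ in range(3):
        print(trainer.train())

The same idea extends to _save()/_restore(): subclasses handle only algorithm state, while the base class persists the shared counters alongside the checkpoint.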
19 changes: 11 additions & 8 deletions python/ray/rllib/dqn/dqn.py
@@ -164,9 +164,10 @@ def _init(self):
self.file_writer = tf.summary.FileWriter(self.logdir, self.sess.graph)
self.saver = tf.train.Saver(max_to_keep=None)

def train(self):
def _train(self):
config = self.config
sample_time, learn_time = 0, 0
iter_init_timesteps = self.num_timesteps

for _ in range(config["timesteps_per_iteration"]):
self.num_timesteps += 1
@@ -243,13 +244,15 @@ def train(self):
int(100 * self.exploration.value(self.num_timesteps)))
logger.dump_tabular()

res = TrainingResult(
self.experiment_id.hex, self.num_iterations, mean_100ep_reward,
mean_100ep_length, info)
self.num_iterations += 1
return res
result = TrainingResult(
episode_reward_mean=mean_100ep_reward,
episode_len_mean=mean_100ep_length,
timesteps_this_iter=self.num_timesteps - iter_init_timesteps,
info=info)

def save(self):
return result

def _save(self):
checkpoint_path = self.saver.save(
self.sess,
os.path.join(self.logdir, "checkpoint"),
@@ -267,7 +270,7 @@ def save(self):
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
return checkpoint_path

def restore(self, checkpoint_path):
def _restore(self, checkpoint_path):
self.saver.restore(self.sess, checkpoint_path)
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
self.replay_buffer = extra_data[0]
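dqn.py keeps a lifetime num_timesteps counter, so the new timesteps_this_iter value is computed as a delta against a snapshot taken at the top of _train(). In isolation the pattern looks like this (class and method names are made up for illustration):

    class StepCounter(object):
        def __init__(self):
            self.num_timesteps = 0  # lifetime counter, never reset

        def run_iteration(self, steps):
            iter_init_timesteps = self.num_timesteps  # snapshot at iteration start
            for _ in range(steps):
                self.num_timesteps += 1  # one simulator step
            # Per-iteration count is the delta against the snapshot.
            return self.num_timesteps - iter_init_timesteps


    counter = StepCounter()
    print(counter.run_iteration(100))  # 100
    print(counter.run_iteration(250))  # 250
    print(counter.num_timesteps)       # 350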
29 changes: 10 additions & 19 deletions python/ray/rllib/es/es.py
@@ -203,7 +203,6 @@ def _init(self):
self.episodes_so_far = 0
self.timesteps_so_far = 0
self.tstart = time.time()
self.iteration = 0

def _collect_results(self, theta_id, min_eps, min_timesteps):
num_eps, num_timesteps = 0, 0
@@ -224,7 +223,7 @@ def _collect_results(self, theta_id, min_eps, min_timesteps):
num_timesteps += result.lengths_n2.sum()
return results

def train(self):
def _train(self):
config = self.config

step_tstart = time.time()
@@ -314,14 +313,6 @@ def train(self):
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
tlogger.dump_tabular()

if (config["snapshot_freq"] != 0 and
self.iteration % config["snapshot_freq"] == 0):
filename = os.path.join(
self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration))
assert not os.path.exists(filename)
self.policy.save(filename)
tlogger.log("Saved snapshot {}".format(filename))

info = {
"weights_norm": np.square(self.policy.get_trainable_flat()).sum(),
"grad_norm": np.square(g).sum(),
@@ -334,33 +325,33 @@ def train(self):
"time_elapsed_this_iter": step_tend - step_tstart,
"time_elapsed": step_tend - self.tstart
}
res = TrainingResult(self.experiment_id.hex, self.iteration,
returns_n2.mean(), lengths_n2.mean(), info)

self.iteration += 1
result = TrainingResult(
episode_reward_mean=returns_n2.mean(),
episode_len_mean=lengths_n2.mean(),
timesteps_this_iter=lengths_n2.sum(),
info=info)

return res
return result

def save(self):
def _save(self):
checkpoint_path = os.path.join(
self.logdir, "checkpoint-{}".format(self.iteration))
weights = self.policy.get_trainable_flat()
objects = [
weights,
self.ob_stat,
self.episodes_so_far,
self.timesteps_so_far,
self.iteration]
self.timesteps_so_far]
pickle.dump(objects, open(checkpoint_path, "wb"))
return checkpoint_path

def restore(self, checkpoint_path):
def _restore(self, checkpoint_path):
objects = pickle.load(open(checkpoint_path, "rb"))
self.policy.set_trainable_flat(objects[0])
self.ob_stat = objects[1]
self.episodes_so_far = objects[2]
self.timesteps_so_far = objects[3]
self.iteration = objects[4]

def compute_action(self, observation):
return self.policy.act([observation])[0]
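The ES checkpoint above no longer pickles the iteration counter, because shared bookkeeping (experiment id, iteration, cumulative timesteps and time) is now written by Agent.save() to a companion ".rllib_metadata" file next to whatever path _save() returns; the removed snapshot_freq logic presumably moves to callers of the generic save() as well. A self-contained sketch of that sidecar-file pattern (function names are illustrative):

    import os
    import pickle
    import tempfile


    def save_with_metadata(checkpoint_path, agent_state, metadata):
        # Agent-specific state goes into the checkpoint file itself...
        with open(checkpoint_path, "wb") as f:
            pickle.dump(agent_state, f)
        # ...while shared bookkeeping lives in a companion ".rllib_metadata" file.
        with open(checkpoint_path + ".rllib_metadata", "wb") as f:
            pickle.dump(metadata, f)


    def restore_with_metadata(checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            agent_state = pickle.load(f)
        with open(checkpoint_path + ".rllib_metadata", "rb") as f:
            metadata = pickle.load(f)
        return agent_state, metadata


    path = os.path.join(tempfile.mkdtemp(), "checkpoint-3")
    save_with_metadata(path, {"weights": [0.1, 0.2]},
                       {"iteration": 3, "timesteps_total": 300})
    print(restore_with_metadata(path))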
