[rllib] Refactor Multi-GPU for PPO #1646

Merged: 18 commits (Jun 19, 2018)
Changes from 1 commit
refactoring the rest of PPO
richardliaw committed Jun 18, 2018
commit 698ce219c863b957f89effbc44aeda25e7a570ac
python/ray/rllib/optimizers/multi_gpu.py (25 changes: 19 additions & 6 deletions)
@@ -26,11 +26,13 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
     the TFMultiGPUSupport API.
     """
 
-    def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10):
+    def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10,
+              timesteps_per_batch=1024):
         assert isinstance(self.local_evaluator, TFMultiGPUSupport)
         self.batch_size = sgd_batch_size
         self.sgd_stepsize = sgd_stepsize
         self.num_sgd_iter = num_sgd_iter
+        self.timesteps_per_batch = timesteps_per_batch
         gpu_ids = ray.get_gpu_ids()
         if not gpu_ids:
             # self.devices = ["/cpu:0"]
@@ -80,12 +82,16 @@ def step(self, postprocess_fn=None):
 
         with self.sample_timer:
             if self.remote_evaluators:
-                samples = SampleBatch.concat_samples(
-                    ray.get(
-                        [e.sample.remote() for e in self.remote_evaluators]))
+                # samples = SampleBatch.concat_samples(
+                #     ray.get(
+                #         [e.sample.remote() for e in self.remote_evaluators]))
+                from ray.rllib.ppo.rollout import collect_samples
+                samples = collect_samples(self.remote_evaluators,
+                                          self.timesteps_per_batch)
Review comment (Contributor): there's probably a better treatment for this?

             else:
                 samples = self.local_evaluator.sample()
             assert isinstance(samples, SampleBatch)
 
         if postprocess_fn:
             postprocess_fn(samples)

@@ -95,19 +101,26 @@
                 samples.columns([key for key, _ in self.loss_inputs]))
 
         with self.grad_timer:
+            all_extra_fetches = []
             for i in range(self.num_sgd_iter):
+                iter_extra_fetches = []
                 num_batches = (
                     int(tuples_per_device) // int(self.per_device_batch_size))
                 permutation = np.random.permutation(num_batches)
                 for batch_index in range(num_batches):
                     # TODO(ekl) support ppo's debugging features, e.g.
                     # printing the current loss and tracing
-                    self.par_opt.optimize(
+                    batch_fetches = self.par_opt.optimize(
                         self.sess,
-                        permutation[batch_index] * self.per_device_batch_size)
+                        permutation[batch_index] * self.per_device_batch_size,
+                        extra_ops=list(self.local_evaluator.extra_ops.values()),
+                        extra_feed_dict=self.local_evaluator.extra_feed_dict())
+                    iter_extra_fetches += [batch_fetches]
+                all_extra_fetches += [iter_extra_fetches]
 
         self.num_steps_sampled += samples.count
         self.num_steps_trained += samples.count
+        return all_extra_fetches
 
     def stats(self):
         return dict(PolicyOptimizer.stats(), **{
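The review comment above asks whether there is a better treatment for the inline import of collect_samples inside step(). One option, sketched below purely as an illustration (InjectedSamplerOptimizer, sample_collector, and dummy_collector are hypothetical names, not part of this PR or of RLlib), would be to pass the sample-collection function in when the optimizer is constructed, so multi_gpu.py carries no PPO-specific import.

# Hypothetical sketch: inject the sample-collection function at construction
# time instead of importing ray.rllib.ppo.rollout inside step().

class InjectedSamplerOptimizer:
    def __init__(self, remote_evaluators, timesteps_per_batch, sample_collector):
        # sample_collector(evaluators, timesteps) -> a SampleBatch-like object
        self.remote_evaluators = remote_evaluators
        self.timesteps_per_batch = timesteps_per_batch
        self.sample_collector = sample_collector

    def collect(self):
        # No PPO-specific import needed here; the caller decides how
        # samples are gathered.
        return self.sample_collector(self.remote_evaluators,
                                     self.timesteps_per_batch)


def dummy_collector(evaluators, timesteps):
    # Stand-in for a real collector such as collect_samples.
    return {"count": timesteps, "num_evaluators": len(evaluators)}


opt = InjectedSamplerOptimizer([], 1024, dummy_collector)
assert opt.collect()["count"] == 1024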
python/ray/rllib/ppo/ppo.py (67 changes: 9 additions & 58 deletions)
@@ -123,7 +123,8 @@ def _init(self):
         self.optimizer = LocalMultiGPUOptimizer(
             {"sgd_batch_size": self.config["sgd_batchsize"],
              "sgd_stepsize": self.config["sgd_stepsize"],
-             "num_sgd_iter": self.config["num_sgd_iter"]},
+             "num_sgd_iter": self.config["num_sgd_iter"],
+             "timesteps_per_batch": self.config["timesteps_per_batch"]},
             self.local_evaluator, self.remote_evaluators,)
 
         self.saver = tf.train.Saver(max_to_keep=None)
@@ -139,67 +140,17 @@ def postprocess_samples(batch):
             if not self.config["use_gae"]:
                 batch.data["value_targets"] = np.zeros_like(batch["advantages"])
                 batch.data["vf_preds"] = np.zeros_like(batch["advantages"])

-        agents = self.remote_evaluators
-        config = self.config
-        model = self.local_evaluator
-
-        if (config["num_workers"] * config["min_steps_per_task"] >
-                config["timesteps_per_batch"]):
-            print(
-                "WARNING: num_workers * min_steps_per_task > "
-                "timesteps_per_batch. This means that the output of some "
-                "tasks will be wasted. Consider decreasing "
-                "min_steps_per_task or increasing timesteps_per_batch.")
-
-        weights = ray.put(model.get_weights())
-        [a.set_weights.remote(weights) for a in agents]
-        samples = collect_samples(agents, config)
-
-        postprocess_samples(samples)
-        print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
-              ", stepsize=" + str(config["sgd_stepsize"]) + "):")
-        names = [
-            "iter", "total loss", "policy loss", "vf loss", "kl", "entropy"]
-        print(("{:>15}" * len(names)).format(*names))
-
-        tuples_per_device = self.optimizer.par_opt.load_data(
-            self.local_evaluator.sess,
-            samples.columns([key for key, _ in self.local_evaluator.tf_loss_inputs()]))
-
-        # tuples_per_device = model.load_data(
-        #     samples, self.iteration == 0 and config["full_trace_data_load"])
-        for i in range(config["num_sgd_iter"]):
-            num_batches = (
-                int(tuples_per_device) // int(model.per_device_batch_size))
-            loss, policy_graph, vf_loss, kl, entropy = [], [], [], [], []
-            permutation = np.random.permutation(num_batches)
-            # Prepare to drop into the debugger
-            for batch_index in range(num_batches):
-                batch_loss, batch_policy_graph, batch_vf_loss, batch_kl, \
-                    batch_entropy = self.optimizer.par_opt.optimize(
-                        self.local_evaluator.sess,
-                        permutation[batch_index] * model.per_device_batch_size,
-                        extra_ops=list(self.local_evaluator.extra_ops.values()),
-                        extra_feed_dict=self.local_evaluator.extra_feed_dict())
-                loss.append(batch_loss)
-                policy_graph.append(batch_policy_graph)
-                vf_loss.append(batch_vf_loss)
-                kl.append(batch_kl)
-                entropy.append(batch_entropy)
-            loss = np.mean(loss)
-            policy_graph = np.mean(policy_graph)
-            vf_loss = np.mean(vf_loss)
-            kl = np.mean(kl)
-            entropy = np.mean(entropy)
-            print(
-                "{:>15}{:15.5e}{:15.5e}{:15.5e}{:15.5e}{:15.5e}".format(
-                    i, loss, policy_graph, vf_loss, kl, entropy))
-
+        extra_fetches = self.optimizer.step(postprocess_fn=postprocess_samples)
+        final_metrics = np.mean(np.array(extra_fetches), axis=1)[-1, :].tolist()
+        total_loss, policy_loss, vf_loss, kl, entropy = final_metrics
+        self.local_evaluator.update_kl(kl)
 
         info = {
             "total_loss": total_loss,
             "policy_loss": policy_loss,
             "vf_loss": vf_loss,
             "kl_divergence": kl,
             "entropy": entropy,
             "kl_coefficient": self.local_evaluator.kl_coeff_val,
         }
 
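For reference, a small self-contained sketch of the aggregation used in the new _train path above, assuming extra_fetches is a list with one entry per SGD iteration, each containing one entry per minibatch of the five scalar metrics (as produced by the optimizer change in multi_gpu.py): averaging over axis 1 gives per-iteration means, and [-1, :] keeps the metrics of the last SGD iteration. The numbers below are made up.

import numpy as np

# Dummy fetches: 2 SGD iterations x 3 minibatches x 5 metrics
# (total loss, policy loss, vf loss, kl, entropy).
extra_fetches = [
    [[10.0, 1.0, 9.0, 0.02, 1.5],
     [12.0, 1.2, 10.8, 0.02, 1.5],
     [11.0, 1.1, 9.9, 0.02, 1.5]],
    [[8.0, 0.8, 7.2, 0.03, 1.4],
     [9.0, 0.9, 8.1, 0.03, 1.4],
     [10.0, 1.0, 9.0, 0.03, 1.4]],
]

# Mean over the minibatch axis, then take the last SGD iteration.
final_metrics = np.mean(np.array(extra_fetches), axis=1)[-1, :].tolist()
total_loss, policy_loss, vf_loss, kl, entropy = final_metrics
print(total_loss, policy_loss, vf_loss, kl, entropy)  # approx. 9.0 0.9 8.1 0.03 1.4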
python/ray/rllib/ppo/rollout.py (4 changes: 2 additions & 2 deletions)
@@ -6,7 +6,7 @@
 from ray.rllib.optimizers import SampleBatch
 
 
-def collect_samples(agents, config):
+def collect_samples(agents, timesteps_per_batch):
     num_timesteps_so_far = 0
     trajectories = []
     # This variable maps the object IDs of trajectories that are currently
@@ -19,7 +19,7 @@ def collect_samples(agents, config):
         fut_sample = agent.sample.remote()
         agent_dict[fut_sample] = agent
 
-    while num_timesteps_so_far < config["timesteps_per_batch"]:
+    while num_timesteps_so_far < timesteps_per_batch:
         # TODO(pcm): Make wait support arbitrary iterators and remove the
         # conversion to list here.
         [fut_sample], _ = ray.wait(list(agent_dict))
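For context, the loop below is a plain-Python stand-in (hypothetical names, no Ray) for the control flow that collect_samples implements with ray.wait: keep pulling sample batches from the evaluators until at least timesteps_per_batch timesteps have been gathered.

# Synchronous stand-in for the collect_samples control flow.

class FakeEvaluator:
    def sample(self):
        # Each call returns a batch worth 100 timesteps.
        return {"count": 100}


def collect_until(evaluators, timesteps_per_batch):
    num_timesteps_so_far = 0
    trajectories = []
    i = 0
    while num_timesteps_so_far < timesteps_per_batch:
        # Round-robin over evaluators instead of ray.wait on futures.
        batch = evaluators[i % len(evaluators)].sample()
        num_timesteps_so_far += batch["count"]
        trajectories.append(batch)
        i += 1
    return trajectories


batches = collect_until([FakeEvaluator(), FakeEvaluator()], 1024)
assert sum(b["count"] for b in batches) >= 1024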