
Commit

[rllib] Move a3c implementation from examples/ to python/ray/rllib/ (#698)

* rllib v0

* fix imports

* lint

* comments

* update docs

* a3c wip

* a3c wip

* report stats

* update doc

* name is too long

* fix small bug

* propagate exception on error

* fetch metrics

* fix lint
ericl authored and pcmoritz committed Jun 29, 2017
1 parent efce49c commit 2d81edf
Showing 10 changed files with 199 additions and 133 deletions.
6 changes: 3 additions & 3 deletions doc/source/example-a3c.rst
@@ -25,7 +25,7 @@ You can run the code with

.. code-block:: bash
-    python ray/examples/a3c/driver.py [num_workers]
+    python/ray/rllib/a3c/example.py --num-workers=N
Reinforcement Learning
----------------------
@@ -153,6 +153,6 @@ workers, we can train the agent in around 25 minutes.

You can visualize performance by running
:code:`tensorboard --logdir [directory]` in a separate screen, where
-:code:`[directory]` is defaulted to :code:`./results/`. If you are running
+:code:`[directory]` is defaulted to :code:`/tmp/ray/a3c/`. If you are running
multiple experiments, be sure to vary the directory to which Tensorflow saves
-its progress (found in :code:`driver.py`).
+its progress (found in :code:`a3c.py`).
83 changes: 0 additions & 83 deletions examples/a3c/driver.py

This file was deleted.

33 changes: 0 additions & 33 deletions examples/a3c/misc.py

This file was deleted.

6 changes: 4 additions & 2 deletions examples/a3c/LSTM.py → python/ray/rllib/a3c/LSTM.py
@@ -6,8 +6,10 @@
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import distutils.version
-from policy import (categorical_sample, conv2d, linear, flatten,
-                    normalized_columns_initializer, Policy)
+
+from ray.rllib.a3c.policy import (
+    categorical_sample, conv2d, linear, flatten,
+    normalized_columns_initializer, Policy)

use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.0.0"))
3 changes: 3 additions & 0 deletions python/ray/rllib/a3c/__init__.py
@@ -0,0 +1,3 @@
from ray.rllib.a3c.a3c import A3C, DEFAULT_CONFIG

__all__ = ["A3C", "DEFAULT_CONFIG"]
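
For orientation, here is a minimal usage sketch of the new package entry point. It simply mirrors example.py and a3c.py below; the environment name and worker count are the defaults taken from example.py, not additional API.

    import ray
    from ray.rllib.a3c import A3C, DEFAULT_CONFIG

    ray.init(num_cpus=4)

    # Copy the default config and set the number of worker actors.
    config = DEFAULT_CONFIG.copy()
    config["num_workers"] = 4

    # Each train() call performs roughly config["num_batches_per_iteration"]
    # asynchronous gradient updates and returns a TrainingResult with the
    # mean episode reward and length reported by the workers.
    a3c = A3C("PongDeterministic-v3", config)
    result = a3c.train()
    print(result)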
126 changes: 126 additions & 0 deletions python/ray/rllib/a3c/a3c.py
@@ -0,0 +1,126 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import six.moves.queue as queue
import os

import ray
from ray.rllib.a3c.LSTM import LSTMPolicy
from ray.rllib.a3c.runner import RunnerThread, process_rollout
from ray.rllib.a3c.envs import create_env
from ray.rllib.common import Algorithm, TrainingResult


DEFAULT_CONFIG = {
"num_workers": 4,
"num_batches_per_iteration": 100,
}


@ray.remote
class Runner(object):
    """Actor object to start running simulation on workers.
    The gradient computation is also executed from this object.
    """
    def __init__(self, env_name, actor_id, logdir="/tmp/ray/a3c/", start=True):
        env = create_env(env_name)
        self.id = actor_id
        num_actions = env.action_space.n
        self.policy = LSTMPolicy(env.observation_space.shape, num_actions,
                                 actor_id)
        self.runner = RunnerThread(env, self.policy, 20)
        self.env = env
        self.logdir = logdir
        if start:
            self.start()

    def pull_batch_from_queue(self):
        """Take a rollout from the queue of the thread runner."""
        rollout = self.runner.queue.get(timeout=600.0)
        # The runner thread puts exceptions on the queue when it fails;
        # re-raise them here so errors propagate back to the driver.
        if isinstance(rollout, BaseException):
            raise rollout
        while not rollout.terminal:
            try:
                part = self.runner.queue.get_nowait()
                if isinstance(part, BaseException):
                    raise part
                rollout.extend(part)
            except queue.Empty:
                break
        return rollout

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.
        Calling this clears the queue of completed rollout metrics.
        """
        completed = []
        while True:
            try:
                completed.append(self.runner.metrics_queue.get_nowait())
            except queue.Empty:
                break
        return completed

    def start(self):
        summary_writer = tf.summary.FileWriter(
            os.path.join(self.logdir, "agent_%d" % self.id))
        self.summary_writer = summary_writer
        self.runner.start_runner(self.policy.sess, summary_writer)

    def compute_gradient(self, params):
        self.policy.set_weights(params)
        rollout = self.pull_batch_from_queue()
        batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
        gradient = self.policy.get_gradients(batch)
        info = {"id": self.id,
                "size": len(batch.a)}
        return gradient, info


class A3C(Algorithm):
    def __init__(self, env_name, config):
        Algorithm.__init__(self, env_name, config)
        self.env = create_env(env_name)
        self.policy = LSTMPolicy(
            self.env.observation_space.shape, self.env.action_space.n, 0)
        self.agents = [
            Runner.remote(env_name, i) for i in range(config["num_workers"])]
        self.parameters = self.policy.get_weights()
        self.iteration = 0

    def train(self):
        gradient_list = [
            agent.compute_gradient.remote(self.parameters)
            for agent in self.agents]
        max_batches = self.config["num_batches_per_iteration"]
        batches_so_far = len(gradient_list)
        while gradient_list:
            done_id, gradient_list = ray.wait(gradient_list)
            gradient, info = ray.get(done_id)[0]
            self.policy.model_update(gradient)
            self.parameters = self.policy.get_weights()
            if batches_so_far < max_batches:
                batches_so_far += 1
                gradient_list.extend(
                    [self.agents[info["id"]].compute_gradient.remote(
                        self.parameters)])
        res = self.fetch_metrics_from_workers()
        self.iteration += 1
        return res

    def fetch_metrics_from_workers(self):
        episode_rewards = []
        episode_lengths = []
        metric_lists = [
            a.get_completed_rollout_metrics.remote() for a in self.agents]
        for metrics in metric_lists:
            for episode in ray.get(metrics):
                episode_lengths.append(episode.episode_length)
                episode_rewards.append(episode.episode_reward)
        res = TrainingResult(
            self.iteration, np.mean(episode_rewards), np.mean(episode_lengths))
        return res
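
The heart of A3C.train() above is an asynchronous update loop: ray.wait returns whichever gradient task finishes first, the driver applies that gradient, and fresh work is immediately resubmitted to the same agent with the updated weights. A stripped-down sketch of that pattern follows; compute_gradient here is a hypothetical stand-in function, not the Runner actor above.

    import numpy as np

    import ray

    ray.init(num_cpus=4)


    @ray.remote
    def compute_gradient(params):
        # Hypothetical stand-in for Runner.compute_gradient: pretend to
        # compute a gradient for the current parameters.
        return np.random.randn(*params.shape)


    params = np.zeros(10)
    pending = [compute_gradient.remote(params) for _ in range(4)]
    updates_left = 100  # plays the role of config["num_batches_per_iteration"]

    while pending:
        # ray.wait returns as soon as any one task finishes, so updates are
        # applied asynchronously as individual workers report in.
        [done], pending = ray.wait(pending)
        gradient = ray.get(done)
        params = params - 0.01 * gradient  # stand-in for policy.model_update
        if updates_left > 0:
            updates_left -= 1
            # Hand the newest parameters straight back to a worker.
            pending.append(compute_gradient.remote(params))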
File renamed without changes.
32 changes: 32 additions & 0 deletions python/ray/rllib/a3c/example.py
@@ -0,0 +1,32 @@
#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import ray
from ray.rllib.a3c import A3C, DEFAULT_CONFIG


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the A3C algorithm.")
    parser.add_argument("--environment", default="PongDeterministic-v3",
                        type=str, help="The gym environment to use.")
    parser.add_argument("--redis-address", default=None, type=str,
                        help="The Redis address of the cluster.")
    parser.add_argument("--num-workers", default=4, type=int,
                        help="The number of A3C workers to use.")

    args = parser.parse_args()
    ray.init(redis_address=args.redis_address, num_cpus=args.num_workers)

    config = DEFAULT_CONFIG.copy()
    config["num_workers"] = args.num_workers

    a3c = A3C(args.environment, config)

    while True:
        res = a3c.train()
        print("current status: {}".format(res))
File renamed without changes.