-
Notifications
You must be signed in to change notification settings - Fork 6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[rllib] Move a3c implementation from examples/ to python/ray/rllib/ (#…
…698) * rllib v0 * fix imports * lint * comments * update docs * a3c wip * a3c wip * report stats * update doc * name is too long * fix small bug * propagate exception on error * fetch metrics * fix lint
- Loading branch information
Showing
10 changed files
with
199 additions
and
133 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ray.rllib.a3c.a3c import A3C, DEFAULT_CONFIG | ||
|
||
__all__ = ["A3C", "DEFAULT_CONFIG"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
import six.moves.queue as queue | ||
import os | ||
|
||
import ray | ||
from ray.rllib.a3c.LSTM import LSTMPolicy | ||
from ray.rllib.a3c.runner import RunnerThread, process_rollout | ||
from ray.rllib.a3c.envs import create_env | ||
from ray.rllib.common import Algorithm, TrainingResult | ||
|
||
|
||
DEFAULT_CONFIG = { | ||
"num_workers": 4, | ||
"num_batches_per_iteration": 100, | ||
} | ||
|
||
|
||
@ray.remote | ||
class Runner(object): | ||
"""Actor object to start running simulation on workers. | ||
The gradient computation is also executed from this object. | ||
""" | ||
def __init__(self, env_name, actor_id, logdir="/tmp/ray/a3c/", start=True): | ||
env = create_env(env_name) | ||
self.id = actor_id | ||
num_actions = env.action_space.n | ||
self.policy = LSTMPolicy(env.observation_space.shape, num_actions, | ||
actor_id) | ||
self.runner = RunnerThread(env, self.policy, 20) | ||
self.env = env | ||
self.logdir = logdir | ||
if start: | ||
self.start() | ||
|
||
def pull_batch_from_queue(self): | ||
"""Take a rollout from the queue of the thread runner.""" | ||
rollout = self.runner.queue.get(timeout=600.0) | ||
if isinstance(rollout, BaseException): | ||
raise rollout | ||
while not rollout.terminal: | ||
try: | ||
part = self.runner.queue.get_nowait() | ||
if isinstance(part, BaseException): | ||
raise rollout | ||
rollout.extend(part) | ||
except queue.Empty: | ||
break | ||
return rollout | ||
|
||
def get_completed_rollout_metrics(self): | ||
"""Returns metrics on previously completed rollouts. | ||
Calling this clears the queue of completed rollout metrics. | ||
""" | ||
completed = [] | ||
while True: | ||
try: | ||
completed.append(self.runner.metrics_queue.get_nowait()) | ||
except queue.Empty: | ||
break | ||
return completed | ||
|
||
def start(self): | ||
summary_writer = tf.summary.FileWriter( | ||
os.path.join(self.logdir, "agent_%d" % self.id)) | ||
self.summary_writer = summary_writer | ||
self.runner.start_runner(self.policy.sess, summary_writer) | ||
|
||
def compute_gradient(self, params): | ||
self.policy.set_weights(params) | ||
rollout = self.pull_batch_from_queue() | ||
batch = process_rollout(rollout, gamma=0.99, lambda_=1.0) | ||
gradient = self.policy.get_gradients(batch) | ||
info = {"id": self.id, | ||
"size": len(batch.a)} | ||
return gradient, info | ||
|
||
|
||
class A3C(Algorithm): | ||
def __init__(self, env_name, config): | ||
Algorithm.__init__(self, env_name, config) | ||
self.env = create_env(env_name) | ||
self.policy = LSTMPolicy( | ||
self.env.observation_space.shape, self.env.action_space.n, 0) | ||
self.agents = [ | ||
Runner.remote(env_name, i) for i in range(config["num_workers"])] | ||
self.parameters = self.policy.get_weights() | ||
self.iteration = 0 | ||
|
||
def train(self): | ||
gradient_list = [ | ||
agent.compute_gradient.remote(self.parameters) | ||
for agent in self.agents] | ||
max_batches = self.config["num_batches_per_iteration"] | ||
batches_so_far = len(gradient_list) | ||
while gradient_list: | ||
done_id, gradient_list = ray.wait(gradient_list) | ||
gradient, info = ray.get(done_id)[0] | ||
self.policy.model_update(gradient) | ||
self.parameters = self.policy.get_weights() | ||
if batches_so_far < max_batches: | ||
batches_so_far += 1 | ||
gradient_list.extend( | ||
[self.agents[info["id"]].compute_gradient.remote(self.parameters)]) | ||
res = self.fetch_metrics_from_workers() | ||
self.iteration += 1 | ||
return res | ||
|
||
def fetch_metrics_from_workers(self): | ||
episode_rewards = [] | ||
episode_lengths = [] | ||
metric_lists = [ | ||
a.get_completed_rollout_metrics.remote() for a in self.agents] | ||
for metrics in metric_lists: | ||
for episode in ray.get(metrics): | ||
episode_lengths.append(episode.episode_length) | ||
episode_rewards.append(episode.episode_reward) | ||
res = TrainingResult( | ||
self.iteration, np.mean(episode_rewards), np.mean(episode_lengths)) | ||
return res |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#!/usr/bin/env python | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import argparse | ||
|
||
import ray | ||
from ray.rllib.a3c import A3C, DEFAULT_CONFIG | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Run the A3C algorithm.") | ||
parser.add_argument("--environment", default="PongDeterministic-v3", | ||
type=str, help="The gym environment to use.") | ||
parser.add_argument("--redis-address", default=None, type=str, | ||
help="The Redis address of the cluster.") | ||
parser.add_argument("--num-workers", default=4, type=int, | ||
help="The number of A3C workers to use>") | ||
|
||
args = parser.parse_args() | ||
ray.init(redis_address=args.redis_address, num_cpus=args.num_workers) | ||
|
||
config = DEFAULT_CONFIG.copy() | ||
config["num_workers"] = args.num_workers | ||
|
||
a3c = A3C(args.environment, config) | ||
|
||
while True: | ||
res = a3c.train() | ||
print("current status: {}".format(res)) |
File renamed without changes.
Oops, something went wrong.