[RLLib] DDPG #1685
Merged
Changes from all commits (38 commits)
f46c8bb fixed style issues (alvkao58)
aa13d9c removed gae, filter, clipping, value estimator for simplification pur… (alvkao58)
a6809f7 added test entry in test_supported_spaces; minor fixes (alvkao58)
391afb2 added description of PGAgent (alvkao58)
39926b1 minor cosmetic changes (alvkao58)
dd12882 eliminated several unnecessary parts of code (alvkao58)
8fce9fc added jenkins tests, horizon (alvkao58)
38a5720 initial commit for ddpg (alvkao58)
dbfd1db safety commit before restructuring models (alvkao58)
f2cf37a fixed network, need to move critic and actor back into separate classes (alvkao58)
09d7221 did a lot of restructuring; first version that runs (alvkao58)
b356a8e added target update (alvkao58)
c8b9ffc switched to using SyncLocalReplayOptimizer and pre-existing replay bu… (alvkao58)
ee07ab4 fixed some algorithmic errors, cleaned up code (alvkao58)
bbcfe3a added fix to actor gradients (alvkao58)
fdf5c29 richard (richardliaw)
fc9fbba some small fixes (richardliaw)
6bd2880 updated stats to match training process, added option for parameter s… (alvkao58)
391ea14 nit changes (richardliaw)
395cf35 updated actor and critic networks; now learns on Pendulum (alvkao58)
cf3e2fe style fixes (alvkao58)
66d7f6f switching from tflearn to slim, making sampler consistent (alvkao58)
6964231 nits (richardliaw)
2eaa114 more fixes to actor/critic network (alvkao58)
ae208c5 moved tensorflow-specific stuff out of the evaluator (alvkao58)
d5f9aca fixed getting/setting weights (alvkao58)
995adc3 formatting fixes (alvkao58)
667b597 added descriptions to config items, moved noise process parameters int… (alvkao58)
870ea84 updated stats collecting (alvkao58)
2f24434 more clean up (alvkao58)
9c9b61b changed stats to support remote evaluators (alvkao58)
fc7b9d3 made requested formatting changes (alvkao58)
a759eb2 moved actor, critic networks into model directory (alvkao58)
90fad1e Merge branch 'master' into ddpg (richardliaw)
1d76eac fix from merging (alvkao58)
e5d0649 added back changes that were removed when fixing rebasing (alvkao58)
aab9570 made requested touch ups (alvkao58)
831c333 fix test syntax (alvkao58)
File: ray/rllib/ddpg/__init__.py (new file; module paths inferred from the imports)

```python
from ray.rllib.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG

__all__ = ["DDPGAgent", "DEFAULT_CONFIG"]
```
File: ray/rllib/ddpg/ddpg.py (new file)

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import ray
from ray.rllib.agent import Agent
from ray.rllib.ddpg.ddpg_evaluator import DDPGEvaluator, RemoteDDPGEvaluator
from ray.rllib.optimizers import LocalSyncReplayOptimizer
from ray.tune.result import TrainingResult

DEFAULT_CONFIG = {
    # Actor learning rate
    "actor_lr": 0.0001,
    # Critic learning rate
    "critic_lr": 0.001,
    # Arguments to pass in to env creator
    "env_config": {},
    # MDP discount factor
    "gamma": 0.99,
    # Number of steps after which the rollout gets cut
    "horizon": 500,

    # Whether to include parameter noise
    "noise_add": True,
    # Linear decay of exploration policy
    "noise_epsilon": 0.0002,
    # Parameters for noise process
    "noise_parameters": {
        "mu": 0,
        "sigma": 0.2,
        "theta": 0.15,
    },

    # Number of local steps taken for each call to sample
    "num_local_steps": 1,
    # Number of workers (excluding master)
    "num_workers": 0,

    "optimizer": {
        # Replay buffer size
        "buffer_size": 10000,
        # Number of steps in warm-up phase before learning starts
        "learning_starts": 500,
        # Whether to clip rewards
        "clip_rewards": False,
        # Whether to use prioritized replay
        "prioritized_replay": False,
        # Size of batch sampled from replay buffer
        "train_batch_size": 64,
    },

    # Controls how fast target networks move
    "tau": 0.001,
    # Number of steps taken per training iteration
    "train_steps": 600,
}


class DDPGAgent(Agent):
    _agent_name = "DDPG"
    _default_config = DEFAULT_CONFIG

    def _init(self):
        self.local_evaluator = DDPGEvaluator(
            self.registry, self.env_creator, self.config)
        self.remote_evaluators = [
            RemoteDDPGEvaluator.remote(
                self.registry, self.env_creator, self.config)
            for _ in range(self.config["num_workers"])]
        self.optimizer = LocalSyncReplayOptimizer(
            self.config["optimizer"], self.local_evaluator,
            self.remote_evaluators)

    def _train(self):
        for _ in range(self.config["train_steps"]):
            self.optimizer.step()
            # update target
            if self.optimizer.num_steps_trained > 0:
                self.local_evaluator.update_target()

        # generate training result
        return self._fetch_metrics()

    def _fetch_metrics(self):
        episode_rewards = []
        episode_lengths = []
        if self.config["num_workers"] > 0:
            metric_lists = [a.get_completed_rollout_metrics.remote()
                            for a in self.remote_evaluators]
            for metrics in metric_lists:
                for episode in ray.get(metrics):
                    episode_lengths.append(episode.episode_length)
                    episode_rewards.append(episode.episode_reward)
        else:
            metrics = self.local_evaluator.get_completed_rollout_metrics()
            for episode in metrics:
                episode_lengths.append(episode.episode_length)
                episode_rewards.append(episode.episode_reward)

        avg_reward = np.mean(episode_rewards)
        avg_length = np.mean(episode_lengths)
        timesteps = np.sum(episode_lengths)

        result = TrainingResult(
            episode_reward_mean=avg_reward,
            episode_len_mean=avg_length,
            timesteps_this_iter=timesteps,
            info={})

        return result
```
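The `tau` option above controls the soft target update that `update_target()` triggers each optimizer step. `DDPGModel` itself is not part of this diff, but the standard DDPG rule that `tau` parameterizes is a Polyak average of online weights into target weights; a minimal sketch with hypothetical variable names:

```python
import numpy as np


def soft_update(target_weights, online_weights, tau=0.001):
    """Move each target weight a fraction tau toward its online counterpart.

    A small tau (0.001 in DEFAULT_CONFIG) makes the targets trail the
    online networks slowly, which is what "controls how fast target
    networks move" refers to.
    """
    return [tau * w + (1.0 - tau) * tw
            for w, tw in zip(online_weights, target_weights)]


if __name__ == "__main__":
    targets = [np.zeros(3)]
    online = [np.ones(3)]
    print(soft_update(targets, online))  # -> [array([0.001, 0.001, 0.001])]
```

Unlike DQN, which copies weights to the target network wholesale every N steps, DDPG applies this small blend continuously, which is why `_train` calls `update_target()` on every optimizer step once training has begun.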
File: ray/rllib/ddpg/ddpg_evaluator.py (new file)

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import ray
from ray.rllib.ddpg.models import DDPGModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.optimizers import PolicyEvaluator
from ray.rllib.utils.filter import NoFilter
from ray.rllib.utils.process_rollout import process_rollout
from ray.rllib.utils.sampler import SyncSampler


class DDPGEvaluator(PolicyEvaluator):

    def __init__(self, registry, env_creator, config):
        self.env = ModelCatalog.get_preprocessor_as_wrapper(
            registry, env_creator(config["env_config"]))

        # contains model, target_model
        self.model = DDPGModel(registry, self.env, config)

        self.sampler = SyncSampler(
            self.env, self.model.model, NoFilter(),
            config["num_local_steps"], horizon=config["horizon"])

    def sample(self):
        """Returns a batch of samples."""
        rollout = self.sampler.get_data()
        rollout.data["weights"] = np.ones_like(rollout.data["rewards"])

        # since each sample is one step, no discounting needs to be applied;
        # this does not involve config["gamma"]
        samples = process_rollout(
            rollout, NoFilter(),
            gamma=1.0, use_gae=False)

        return samples

    def update_target(self):
        """Updates target critic and target actor."""
        self.model.update_target()

    def compute_gradients(self, samples):
        """Returns critic, actor gradients."""
        return self.model.compute_gradients(samples)

    def apply_gradients(self, grads):
        """Applies gradients to evaluator weights."""
        self.model.apply_gradients(grads)

    def compute_apply(self, samples):
        grads, _ = self.compute_gradients(samples)
        self.apply_gradients(grads)

    def get_weights(self):
        """Returns model weights."""
        return self.model.get_weights()

    def set_weights(self, weights):
        """Sets model weights."""
        self.model.set_weights(weights)

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        return self.sampler.get_metrics()


RemoteDDPGEvaluator = ray.remote(DDPGEvaluator)
```
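One piece this diff does not show is how the `mu`, `theta`, and `sigma` entries of `DEFAULT_CONFIG["noise_parameters"]` are consumed inside `DDPGModel`. Those names match the Ornstein-Uhlenbeck exploration noise used in the original DDPG paper, so the following is only a sketch under that assumption:

```python
import numpy as np


class OrnsteinUhlenbeckNoise(object):
    """Temporally correlated noise added to the actor's actions.

    Implements x += theta * (mu - x) + sigma * N(0, I), using the
    mu/theta/sigma values from DEFAULT_CONFIG["noise_parameters"].
    """

    def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.state = np.full(action_dim, mu, dtype=np.float64)

    def sample(self):
        # Mean-reverting step toward mu plus a Gaussian perturbation.
        self.state += (self.theta * (self.mu - self.state) +
                       self.sigma * np.random.randn(len(self.state)))
        return self.state
```

The `noise_epsilon` option ("linear decay of exploration policy") would presumably anneal this noise over time, but since that logic is also outside this diff, it is left out of the sketch.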
Conversations
Nice. It would be good to also
@alvkao58 can you take care of these comments by Eric? (see PR)
any particular sanity checks you want to see added to multi_node_tests.sh?