[rllib] Pull out multi-gpu optimizer as a generic class #1313
Changes from 42 commits
@@ -8,6 +8,7 @@
 from ray.rllib.dqn.base_evaluator import DQNEvaluator
 from ray.rllib.dqn.common.schedules import LinearSchedule
 from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
+from ray.rllib.optimizers import SampleBatch


 class DQNReplayEvaluator(DQNEvaluator):

@@ -63,38 +64,44 @@ def sample(self, no_replay=False):
         samples = [DQNEvaluator.sample(self)]

         for s in samples:
-            for obs, action, rew, new_obs, done in s:
-                self.replay_buffer.add(obs, action, rew, new_obs, done)
+            for row in s.rows():
+                self.replay_buffer.add(
+                    row["obs"], row["actions"], row["rewards"], row["new_obs"],
+                    row["dones"])

         if no_replay:
             return samples

         # Then return a batch sampled from the buffer
         if self.config["prioritized_replay"]:
Review comment: I hope later on all the ReplayBuffer casework can get pushed into the ReplayBuffer.

Reply: It's difficult. The prioritization requires extra computation that has to be part of the loss.
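To illustrate that suggestion (a sketch only, not code from this PR or from rllib), the prioritized/uniform casework could in principle move behind a small wrapper that always hands back a SampleBatch. The class and method names below are hypothetical, and, as the reply notes, the priority update still needs TD errors computed as part of the loss, so that step cannot be hidden inside the buffer:

import numpy as np

from ray.rllib.optimizers import SampleBatch


class ReplayBufferWithBatches(object):
    """Hypothetical wrapper hiding the prioritized/uniform casework."""

    def __init__(self, buf, prioritized, train_batch_size, beta_schedule=None):
        self.buf = buf
        self.prioritized = prioritized
        self.train_batch_size = train_batch_size
        self.beta_schedule = beta_schedule

    def sample_batch(self, timestep):
        # Always return a SampleBatch, regardless of the buffer type.
        if self.prioritized:
            (obs, actions, rewards, new_obs, dones, weights,
             batch_indexes) = self.buf.sample(
                 self.train_batch_size,
                 beta=self.beta_schedule.value(timestep))
            return SampleBatch({
                "obs": obs, "actions": actions, "rewards": rewards,
                "new_obs": new_obs, "dones": dones, "weights": weights,
                "batch_indexes": batch_indexes})
        obs, actions, rewards, new_obs, dones = self.buf.sample(
            self.train_batch_size)
        return SampleBatch({
            "obs": obs, "actions": actions, "rewards": rewards,
            "new_obs": new_obs, "dones": dones,
            "weights": np.ones_like(rewards)})

    def update_priorities(self, batch, td_errors, eps):
        # Priorities depend on TD errors produced by the loss, so this step
        # still has to be driven from outside the buffer.
        if self.prioritized:
            self.buf.update_priorities(
                batch["batch_indexes"], np.abs(td_errors) + eps)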
-            experience = self.replay_buffer.sample(
-                self.config["train_batch_size"],
-                beta=self.beta_schedule.value(self.global_timestep))
             (obses_t, actions, rewards, obses_tp1,
-                dones, _, batch_idxes) = experience
+                dones, weights, batch_indexes) = self.replay_buffer.sample(
+                    self.config["train_batch_size"],
+                    beta=self.beta_schedule.value(self.global_timestep))
             self._update_priorities_if_needed()
-            self.samples_to_prioritize = (
-                obses_t, actions, rewards, obses_tp1, dones, batch_idxes)
+            batch = SampleBatch({
+                "obs": obses_t, "actions": actions, "rewards": rewards,
+                "new_obs": obses_tp1, "dones": dones, "weights": weights,
+                "batch_indexes": batch_indexes})
+            self.samples_to_prioritize = batch

Review comment: You could just have the …

Reply: Yes, but I'm less concerned about that for now -- I'd rather keep replay_buffer unchanged as long as possible since it's still basically similar to the baselines code.

         else:
             obses_t, actions, rewards, obses_tp1, dones = \
                 self.replay_buffer.sample(self.config["train_batch_size"])
-            batch_idxes = None
-
-        return self.samples_to_prioritize
+            batch = SampleBatch({
+                "obs": obses_t, "actions": actions, "rewards": rewards,
+                "new_obs": obses_tp1, "dones": dones,
+                "weights": np.ones_like(rewards)})
+        return batch

     def compute_gradients(self, samples):
-        obses_t, actions, rewards, obses_tp1, dones, batch_indxes = samples
         td_errors, grad = self.dqn_graph.compute_gradients(
-            self.sess, obses_t, actions, rewards, obses_tp1, dones,
-            np.ones_like(rewards))
+            self.sess, samples["obs"], samples["actions"], samples["rewards"],
+            samples["new_obs"], samples["dones"], samples["weights"])
         if self.config["prioritized_replay"]:
             new_priorities = (
                 np.abs(td_errors) + self.config["prioritized_replay_eps"])
-            self.replay_buffer.update_priorities(batch_indxes, new_priorities)
+            self.replay_buffer.update_priorities(
+                samples["batch_indexes"], new_priorities)
             self.samples_to_prioritize = None
         return grad

@@ -109,14 +116,15 @@ def _update_priorities_if_needed(self):
         if not self.samples_to_prioritize:
             return

-        obses_t, actions, rewards, obses_tp1, dones, batch_idxes = \
-            self.samples_to_prioritize
+        batch = self.samples_to_prioritize
         td_errors = self.dqn_graph.compute_td_error(
-            self.sess, obses_t, actions, rewards, obses_tp1, dones,
-            np.ones_like(rewards))
+            self.sess, batch["obs"], batch["actions"], batch["rewards"],
+            batch["new_obs"], batch["dones"], batch["weights"])

         new_priorities = (
             np.abs(td_errors) + self.config["prioritized_replay_eps"])
-        self.replay_buffer.update_priorities(batch_idxes, new_priorities)
+        self.replay_buffer.update_priorities(
+            batch["batch_indexes"], new_priorities)
         self.samples_to_prioritize = None

     def stats(self):
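The diff above replaces raw experience tuples with SampleBatch objects, which hold experience as named columns ("obs", "actions", "rewards", "new_obs", "dones", "weights", "batch_indexes") and can be iterated row by row via rows(). As a rough, simplified sketch of that idea only (this is not rllib's actual sample_batch implementation), such a container might look like:

import numpy as np


class SampleBatchSketch(object):
    """Illustrative columnar batch container (hypothetical, simplified)."""

    def __init__(self, data):
        self.data = data  # dict: column name -> array, all of equal length
        lengths = {len(v) for v in data.values()}
        assert len(lengths) == 1, "All columns must have the same length"
        self.count = lengths.pop()

    def __getitem__(self, key):
        # Column access, e.g. batch["obs"] or batch["weights"]
        return self.data[key]

    def rows(self):
        # Row-wise iteration, as used when adding samples to the replay buffer
        for i in range(self.count):
            yield {k: v[i] for k, v in self.data.items()}


# Example usage mirroring the diff above (toy data):
batch = SampleBatchSketch({
    "obs": np.zeros((4, 2)), "actions": np.zeros(4), "rewards": np.ones(4),
    "new_obs": np.zeros((4, 2)), "dones": np.zeros(4),
    "weights": np.ones(4)})
for row in batch.rows():
    pass  # e.g. replay_buffer.add(row["obs"], row["actions"], ...)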
Another file in this diff was deleted; its contents are not shown.
@@ -1,6 +1,10 @@
 from ray.rllib.optimizers.async import AsyncOptimizer
 from ray.rllib.optimizers.local_sync import LocalSyncOptimizer
 from ray.rllib.optimizers.multi_gpu import LocalMultiGPUOptimizer
+from ray.rllib.optimizers.sample_batch import SampleBatch
+from ray.rllib.optimizers.evaluator import Evaluator, TFMultiGPUSupport


-__all__ = ["AsyncOptimizer", "LocalSyncOptimizer", "LocalMultiGPUOptimizer"]
+__all__ = [
+    "AsyncOptimizer", "LocalSyncOptimizer", "LocalMultiGPUOptimizer",
+    "SampleBatch", "Evaluator", "TFMultiGPUSupport"]
Review comment: Does it make sense to replace this with a SyncSampler? A little modification would need to be made to support exploration, but this class will end up cleaner. Can be done in another PR, I guess.

Reply: Sure, I think it should work.
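For context on that suggestion, a synchronous sampler in this setting would step the environment for a fixed number of steps and hand the result back as a batch. The sketch below is purely hypothetical (the class name, compute_action, and get_data are assumptions, not rllib's actual SyncSampler API), and as the comment notes, exploration would still need to be hooked in:

class SyncSamplerSketch(object):
    """Hypothetical synchronous sampler returning batches of experience."""

    def __init__(self, env, policy, num_local_steps):
        self.env = env
        self.policy = policy
        self.num_local_steps = num_local_steps
        self.last_obs = env.reset()

    def get_data(self):
        # Roll out num_local_steps environment steps and return them.
        rows = []
        for _ in range(self.num_local_steps):
            action = self.policy.compute_action(self.last_obs)
            new_obs, reward, done, _ = self.env.step(action)
            rows.append({
                "obs": self.last_obs, "actions": action, "rewards": reward,
                "new_obs": new_obs, "dones": done})
            self.last_obs = self.env.reset() if done else new_obs
        # A real implementation would build a SampleBatch from these rows and
        # would also apply exploration (e.g. epsilon-greedy) when choosing
        # the action above.
        return rows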