[rllib] Fix DQN inefficiency, and cleanup for different modes of parallelism #1151
Conversation
Merged build finished. Test PASSed.
Test PASSed.
ping @pcmoritz @richardliaw
Not sure about leaving all the timing code.
Would be nice to avoid hardcoding times and use dicts (with keys).
Would be nice to see @pcmoritz's comments (but if there are none any time soon, we can merge first).
python/ray/rllib/dqn/models.py (Outdated)
@@ -97,8 +98,50 @@ def _scope_vars(scope, trainable_only=False):
        scope=scope if isinstance(scope, str) else scope.name)


class ModelAndLoss(object):
    def __init__(
Would be nice to have a description of this object and what it's doing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
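For illustration, the kind of description requested might look like the sketch below; the wording is a guess from the surrounding diff, not the actual docstring added in the PR.

```python
class ModelAndLoss(object):
    """Builds the Q-network forward pass together with the DQN loss.

    Hypothetical docstring sketch: bundling the model and its loss in one
    object lets the whole graph be rebuilt under another variable scope
    (e.g. once per GPU tower) from the same loss inputs
    (obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights).
    """
```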
            obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights):
        # q network evaluation
        with tf.variable_scope("q_func", reuse=True):
            self.q_t = _build_q_network(obs_t, num_actions, config)
`q_t` is a bit unintuitive...
Hm, it is from the existing code; I guess Q_t vs. Q_t+1 makes sense from the formulation.
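For context, the naming comes from the DQN loss, where `q_t` is the value of the action taken at step t and the target is bootstrapped from step t+1. A toy NumPy illustration (made-up numbers, not the models.py code):

```python
import numpy as np

gamma = 0.99
q_t = np.array([1.0, 0.5])         # Q(s_t, a_t) for a batch of two transitions
q_tp1_best = np.array([1.2, 0.0])  # max_a Q_target(s_{t+1}, a)
rew_t = np.array([0.0, 1.0])
done_mask = np.array([0.0, 1.0])   # 1.0 if the episode ended at step t

# Bellman target and TD error, as in the standard DQN formulation.
q_t_target = rew_t + gamma * (1.0 - done_mask) * q_tp1_best
td_error = q_t - q_t_target
print(td_error)  # [-0.188 -0.5]
```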
python/ray/rllib/dqn/dqn.py (Outdated)
@@ -168,12 +181,22 @@ def step(self, cur_timestep):
            self.episode_lengths.append(0.0)
        return ret

    def collect_samples(self, num_steps, cur_timestep):
why not keep `do_steps` with a flag to store into the replay buffer (otherwise no storing), and always return all states?
def do_steps(self, num_steps, cur_timestep, store=False):
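A rough sketch of the suggested interface; `self._step_env` and the transition tuple layout are assumptions for illustration, not the actual dqn.py code:

```python
def do_steps(self, num_steps, cur_timestep, store=False):
    """Takes `num_steps` steps in the environment.

    If `store` is True, each transition is also added to the replay
    buffer. All transitions are returned either way.
    """
    output = []
    for _ in range(num_steps):
        obs, action, rew, new_obs, done = self._step_env(cur_timestep)
        if store:
            self.replay_buffer.add(obs, action, rew, new_obs, done)
        output.append((obs, action, rew, new_obs, done))
    return output
```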
Done
python/ray/rllib/dqn/dqn.py (Outdated)
        dt = time.time()
        for _ in range(self.config["num_sgd_iter"]):
            batches = list(range(num_batches))
            random.shuffle(batches)
why not `np.random.shuffle`, as `np` is already imported?
Sure, I think it's the same.
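For reference, both calls shuffle a list of indices in place, so the swap should be behavior-preserving; a minimal standalone check (not the dqn.py code):

```python
import random

import numpy as np

batches = list(range(8))
random.shuffle(batches)     # stdlib: in-place shuffle
print(batches)

batches = list(range(8))
np.random.shuffle(batches)  # NumPy: also shuffles the list in place
print(batches)
```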
                np.abs(td_errors) + self.config["prioritized_replay_eps"])
            self.replay_buffer.update_priorities(
                batch_idxes, new_priorities)
        prioritization_time = (time.time() - dt)
are we sure we want to keep this timing code? If you really want it, I would put everything into an `info` dict and return that.
Yes, timing is an important part of an experimental feature. I changed it to a dict.
I was also debating leaving out the multi-GPU part, but it's probably better to leave it in; otherwise that stuff tends to code-rot.
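A minimal sketch of the dict-based pattern being agreed on here; the phase names and helper callables are placeholders, not the actual dqn.py internals:

```python
import time


def train_step_with_timings(sample_batch, run_sgd, update_priorities):
    """Runs one training step and returns per-phase timings in an info dict."""
    info = {}

    t = time.time()
    batch = sample_batch()
    info["sample_time"] = time.time() - t

    t = time.time()
    td_errors = run_sgd(batch)
    info["sgd_time"] = time.time() - t

    t = time.time()
    update_priorities(batch, td_errors)
    info["prioritization_time"] = time.time() - t

    return info
```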
python/ray/rllib/dqn/dqn.py (Outdated)
            gradient = None
        return gradient, {"id": worker_id, "gradient_id": gradient_id}

    def get_gradient(self, cur_timestep):
can we change this to something a little more informative? e.g. `sample_buffer_gradient`?
Done
        stats = ray.get(
            [w.stats.remote(self.cur_timestep) for w in self.workers])
        for stat in stats:
            mean_100ep_reward += stat[0]
the indexing is not that extensible; would be nice to use dicts instead
Added a TODO to clean this up.
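For illustration, the dict-based alternative might look roughly like this (hypothetical keys; the actual stats schema in dqn.py may differ):

```python
# Each worker returns a dict instead of a positional tuple, so new fields can
# be added without breaking this aggregation code.
# "mean_100ep_length" is a made-up second field to show extensibility.
stats = ray.get(
    [w.stats.remote(self.cur_timestep) for w in self.workers])
mean_100ep_reward = sum(s["mean_100ep_reward"] for s in stats) / len(stats)
mean_100ep_length = sum(s["mean_100ep_length"] for s in stats) / len(stats)
```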
Merged build finished. Test PASSed.
Test PASSed.
Can approve after the last comment is addressed.
python/ray/rllib/dqn/dqn.py (Outdated)
                self.replay_buffer.add(obs, action, rew, new_obs, done)
            else:
                output.append(result)
        if store:
there is only a return value when store is true; but the docs say "otherwise"
Merged build finished. Test PASSed.
Test PASSed.
There was a regression introduced in the last PR which incurred extra object store copy overhead for every gradient, even when parallelism was disabled. Pong now trains within a couple of hours on GPU again (expect perhaps 5-10x that on CPU):

[training reward curves figure]

Note that this is with a downscaling filter applied, similar to what A3C is using. With the filter disabled, training gets a bit slower, but it's still much faster than before. The exact hyper-parameters used are in `rllib/tuned_examples/pong-dqn.yaml`, and the curves from the figure above can be reproduced via `./train.py -f tuned_examples/pong-dqn.yaml --num-gpus=N`.

I also took the chance to check in some experimental modes of DQN for multi-GPU and async updates. Note that those don't currently provide any speedups for Pong.
@pcmoritz @royf