
[rllib] Fix DQN inefficiency, and cleanup for different modes of parallelism #1151

Merged · 11 commits merged into ray-project:master on Oct 29, 2017

Conversation

@ericl (Contributor) commented Oct 21, 2017

There was a regression introduced in the last PR that incurred extra object store copy overhead for every gradient, even when parallelism was disabled. Pong now trains within a couple of hours on a GPU again (expect perhaps 5-10x that on CPU):

[Screenshot: Pong training reward curves, 2017-10-21]

Note that this is with a downscaling filter applied, similar to what A3C uses. With the filter disabled, training is a bit slower but still much faster than before. The exact hyperparameters used are in rllib/tuned_examples/pong-dqn.yaml, and the curves in the figure above can be reproduced via ./train.py -f tuned_examples/pong-dqn.yaml --num-gpus=N.

I also took the chance to check in some experimental modes of DQN for multi-GPU and async updates. Note that those don't currently provide any speedup for Pong.

@pcmoritz @royf
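For readers skimming this later, here is a minimal sketch of the serial-vs-parallel gradient path the description refers to; the class and method names are hypothetical, not the PR's actual code. The point is that when there are no remote workers, the gradient is computed and applied in-process, so nothing round-trips through the Ray object store.

import ray


class DQNOptimizerSketch(object):
    """Hypothetical: skip the object store when parallelism is disabled."""

    def __init__(self, local_evaluator, remote_evaluators):
        self.local_evaluator = local_evaluator      # runs in this process
        self.remote_evaluators = remote_evaluators  # Ray actors; may be empty

    def step(self):
        if not self.remote_evaluators:
            # Serial mode: no ray.put()/ray.get(), hence no extra copy of
            # the gradient through the object store.
            grad = self.local_evaluator.compute_gradient()
            self.local_evaluator.apply_gradient(grad)
        else:
            # Parallel mode: gradients are shipped via the object store.
            grad_ids = [ev.compute_gradient.remote()
                        for ev in self.remote_evaluators]
            for grad in ray.get(grad_ids):
                self.local_evaluator.apply_gradient(grad)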

@AmplabJenkins
Merged build finished. Test PASSed.

@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/2192/

@ericl (Contributor, Author) commented Oct 27, 2017

ping @pcmoritz @richardliaw

@richardliaw (Contributor) left a comment

Not sure about leaving all the timing code.

Would be nice to avoid hardcoding times and use dicts (with keys).

Would be nice to see @pcmoritz's comments (but if there are none any time soon, we can merge first).

@@ -97,8 +98,50 @@ def _scope_vars(scope, trainable_only=False):
        scope=scope if isinstance(scope, str) else scope.name)


class ModelAndLoss(object):
    def __init__(
@richardliaw (Contributor):

Would be nice to have description of this object and what it's doing

@ericl (Contributor, Author):

Done

            obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights):
        # q network evaluation
        with tf.variable_scope("q_func", reuse=True):
            self.q_t = _build_q_network(obs_t, num_actions, config)
@richardliaw (Contributor):

q_t is a bit unintuitive...

@ericl (Contributor, Author):

Hm, it's from the existing code; I guess Q_t vs. Q_{t+1} makes sense from the formulation.
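For context, the q_t / q_tp1 names mirror the standard DQN TD-error formulation. Below is a self-contained sketch of that formulation, with illustrative placeholders and constants rather than this PR's actual graph code:

import tensorflow as tf

num_actions = 4
gamma = 0.99

q_t = tf.placeholder(tf.float32, [None, num_actions])    # Q(s_t, .) from the online network
q_tp1 = tf.placeholder(tf.float32, [None, num_actions])  # Q(s_{t+1}, .) from the target network
act_t = tf.placeholder(tf.int32, [None])                  # action taken at time t
rew_t = tf.placeholder(tf.float32, [None])                # reward received
done_mask = tf.placeholder(tf.float32, [None])            # 1.0 if the episode ended at t

# Q(s_t, a_t): value of the action actually taken.
q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t, num_actions), axis=1)

# Bootstrap target: r_t + gamma * max_a Q(s_{t+1}, a), zeroed at episode end.
q_tp1_best = tf.reduce_max(q_tp1, axis=1)
q_t_target = rew_t + gamma * (1.0 - done_mask) * q_tp1_best

# TD error that the loss and the replay priorities are based on.
td_error = q_t_selected - tf.stop_gradient(q_t_target)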

@@ -168,12 +181,22 @@ def step(self, cur_timestep):
        self.episode_lengths.append(0.0)
        return ret

    def collect_samples(self, num_steps, cur_timestep):
@richardliaw (Contributor):

Why not keep do_steps with a flag to store to the replay buffer, and otherwise not store and always return all states?

@richardliaw (Contributor):

def do_steps(self, num_steps, cur_timestep, store=False):

@ericl (Contributor, Author):

Done
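For illustration, the merged method being suggested here might look roughly like the following; the body is hypothetical, not the PR's exact implementation.

def do_steps(self, num_steps, cur_timestep, store=False):
    """Take num_steps environment steps.

    If store is True, transitions are added to the replay buffer;
    otherwise they are collected and returned to the caller.
    """
    output = []
    for _ in range(num_steps):
        result = self.step(cur_timestep)
        if store:
            obs, action, rew, new_obs, done = result
            self.replay_buffer.add(obs, action, rew, new_obs, done)
        else:
            output.append(result)
    if not store:
        return output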

        dt = time.time()
        for _ in range(self.config["num_sgd_iter"]):
            batches = list(range(num_batches))
            random.shuffle(batches)
@richardliaw (Contributor):

why not np.random.shuffle, as np is already imported?

@ericl (Contributor, Author):

Sure, I think it's the same.
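For reference, both calls shuffle a plain list of indices in place; they differ only in which random number generator (and seed state) they use. A trivial check, not from the PR:

import random
import numpy as np

batches = list(range(8))
random.shuffle(batches)     # Python's stdlib RNG
np.random.shuffle(batches)  # NumPy's RNG; also shuffles the list in place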

            np.abs(td_errors) + self.config["prioritized_replay_eps"])
        self.replay_buffer.update_priorities(
            batch_idxes, new_priorities)
        prioritization_time = (time.time() - dt)
@richardliaw (Contributor):

Are we sure we want to keep this timing code? If you really want it, I would put everything into an info dict and return that.

@ericl (Contributor, Author):

Yes, timing is an important part of an experimental feature. I changed it to a dict.

I was also debating leaving out the multi-GPU part, but it's probably better to leave it in; otherwise that stuff tends to rot.
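A minimal sketch of the dict-based timing info discussed here; the keys and helper names are illustrative, not the PR's actual ones.

import time

def _optimizer_step(self):
    """Hypothetical: return per-phase timings keyed by name, not by position."""
    info = {}

    dt = time.time()
    samples = self._sample_from_replay_buffer()
    info["sample_time"] = time.time() - dt

    dt = time.time()
    td_errors = self._compute_and_apply_gradient(samples)
    info["grad_time"] = time.time() - dt

    dt = time.time()
    self._update_priorities(samples, td_errors)
    info["prioritization_time"] = time.time() - dt

    return info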

            gradient = None
        return gradient, {"id": worker_id, "gradient_id": gradient_id}

    def get_gradient(self, cur_timestep):
@richardliaw (Contributor):

Can we change this to something a little more informative, e.g. sample_buffer_gradient?

@ericl (Contributor, Author):

Done

        stats = ray.get(
            [w.stats.remote(self.cur_timestep) for w in self.workers])
        for stat in stats:
            mean_100ep_reward += stat[0]
@richardliaw (Contributor):

the indexing is not that extensible; would be nice to use dicts instead

@ericl (Contributor, Author):

Added a TODO to clean this up.
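The TODO amounts to something like the following sketch, where each worker's stats come back as a dict rather than a positionally indexed tuple; the keys here are hypothetical.

# Each worker's stats.remote() would return, e.g.,
# {"mean_100ep_reward": ..., "mean_100ep_length": ..., "num_episodes": ...}
stats = ray.get(
    [w.stats.remote(self.cur_timestep) for w in self.workers])
mean_100ep_reward = sum(s["mean_100ep_reward"] for s in stats) / len(stats)
mean_100ep_length = sum(s["mean_100ep_length"] for s in stats) / len(stats)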

@AmplabJenkins
Merged build finished. Test PASSed.

@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/2226/

@richardliaw (Contributor) left a comment

Can approve after the last comment is addressed.

                self.replay_buffer.add(obs, action, rew, new_obs, done)
            else:
                output.append(result)
        if store:
@richardliaw (Contributor):

there is only a return value when store is true; but the docs say "otherwise"

@AmplabJenkins
Merged build finished. Test PASSed.

@AmplabJenkins
Test PASSed.
Refer to this link for build results (access rights to CI server needed):
https://amplab.cs.berkeley.edu/jenkins//job/Ray-PRB/2234/

@richardliaw merged commit 4cace09 into ray-project:master on Oct 29, 2017