[rllib] PPO and A3C unification #1253
Conversation
Merged build finished. Test FAILed.
Merged build finished. Test PASSed.
Nice refactoring. To make sure PPO performance hasn't regressed, can you run the tuned humanoid example?
@@ -105,6 +105,7 @@ def _fetch_metrics_from_workers(self):
        return result

    def _save(self):
        # TODO(rliaw): extend to also support saving worker state?
Yeah that's probably required for advanced hypertune algorithms to work well.
@@ -118,6 +119,8 @@ def _restore(self, checkpoint_path):
        self.rew_filter = objects[2]
        self.policy.set_weights(self.parameters)

    # TODO(rliaw): augment to support LSTM
This could be a general TODO on agents.
python/ray/rllib/a3c/runner.py
Outdated
-class Runner(object):
+class Runner(Evaluator):
A3CEvaluator(Evaluator)
python/ray/rllib/a3c/shared_model.py
Outdated
    def value(self, ob, *args):
        vf = self.sess.run(self.vf, {self.x: [ob]})
        return vf[0]

    def get_initial_features(self):
        # TODO(rliaw): make sure this is right
Anything is fine since this isn't an LSTM, right? So it could return None.
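The distinction the reviewer is drawing can be sketched like this: a feedforward model carries no state between steps, so its initial features are trivial, while an LSTM model must return its zeroed recurrent state. This is a generic sketch (the class names and `cell_size` parameter are illustrative), not rllib's actual API:

```python
import numpy as np


class FeedForwardModel:
    # Non-recurrent model: there is no hidden state to carry across
    # timesteps, so the initial features can be empty (or None).
    def get_initial_features(self):
        return []


class LSTMModel:
    # Recurrent model: the initial features are the zeroed (c, h) states.
    def __init__(self, cell_size=256):
        self.cell_size = cell_size

    def get_initial_features(self):
        c = np.zeros((1, self.cell_size), dtype=np.float32)
        h = np.zeros((1, self.cell_size), dtype=np.float32)
        return [c, h]
```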
python/ray/rllib/ppo/runner.py
Outdated
            dummy],
            full_trace=full_trace)
        use_gae = self.config["use_gae"]
        dummy = np.zeros((trajectories["observations"].shape[0],))
np.zeros_like?
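For context on the suggestion: `np.zeros_like` builds the zero array from a source array's shape and dtype directly, which avoids spelling out the shape by hand. A small illustration (the `observations` array here is made up):

```python
import numpy as np

observations = np.ones((5, 4))  # illustrative trajectory batch

# Explicit-shape version, as in the diff:
dummy_a = np.zeros((observations.shape[0],))

# zeros_like on a 1-D slice keeps length and dtype in sync with the
# source array automatically:
dummy_b = np.zeros_like(observations[:, 0])

assert dummy_a.shape == dummy_b.shape == (5,)
```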
python/ray/rllib/ppo/runner.py
Outdated
            self.config["horizon"], self.config["horizon"])
        if not is_remote:
            # local model needs obs_filter for compute
            self.obs_filter = obs_filter
Should we just use the obs filter in the sampler?
Any preference on keeping the (global/master) observation_filter in the model, or can I move it into ppo.py? (Similar to A3C.)
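For readers following the filter discussion: the observation filter in question normalizes observations with running statistics. A minimal sketch of such a filter, using Welford's online algorithm (this is a generic sketch, not rllib's actual filter class):

```python
import numpy as np


class RunningObsFilter:
    """Normalizes observations with a running mean/std (Welford's algorithm)."""

    def __init__(self, shape):
        self.n = 0
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)  # running sum of squared deviations

    def __call__(self, obs):
        obs = np.asarray(obs, dtype=np.float64)
        # Update running statistics.
        self.n += 1
        delta = obs - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (obs - self.mean)
        # Normalize with the current estimates (epsilon avoids divide-by-zero).
        std = np.sqrt(self.m2 / max(self.n - 1, 1)) + 1e-8
        return (obs - self.mean) / std
```

Whether this object lives on the model or in ppo.py is purely an ownership question; the key constraint is that remote workers and the master must keep their statistics in sync.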
python/ray/rllib/utils/common.py
Outdated
@@ -0,0 +1,40 @@
from __future__ import absolute_import
this file should be called process_rollout or something
python/ray/rllib/utils/sampler.py
Outdated
@@ -38,6 +43,9 @@ def __init__(self, extra_fields=None):

    def add(self, **kwargs):
        for k, v in kwargs.items():
            if (k not in ["observations", "features"]
                    and hasattr(v, "squeeze")):
This is kind of fishy, why is it needed?
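One plausible reading of why the exclusion exists (the shapes below are made up for illustration): `squeeze()` drops every size-1 axis, which is harmless for scalar-per-step fields but destructive for fields with meaningful trailing axes.

```python
import numpy as np

# squeeze() drops *all* size-1 axes. For a per-step scalar field like
# rewards this is the desired (T, 1) -> (T,) cleanup:
rewards = np.zeros((10, 1))
assert rewards.squeeze().shape == (10,)

# But a single-channel image observation silently loses its channel axis:
obs = np.zeros((10, 84, 84, 1))
assert obs.squeeze().shape == (10, 84, 84)

# ...which would explain excluding "observations" and "features".
```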
Merged build finished. Test FAILed.
Merged build finished. Test PASSed.
What do these changes do?
A variety of changes are introduced:
- advantages + vf_preds are used instead of MC returns
- PartialRollouts squeezes everything except for observations and features
TODOS:
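For reference, the "advantages + vf_preds instead of MC returns" item typically means Generalized Advantage Estimation. A generic sketch of that computation (the function name and gamma/lam defaults are illustrative, not taken from this PR):

```python
import numpy as np


def compute_gae(rewards, vf_preds, gamma=0.99, lam=0.95, last_value=0.0):
    """Generalized Advantage Estimation over one trajectory.

    rewards:    (T,) per-step rewards
    vf_preds:   (T,) value-function predictions for each state
    last_value: bootstrap value for the state after the trajectory
    """
    T = len(rewards)
    values = np.append(vf_preds, last_value)
    advantages = np.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        # TD residual at step t, then exponentially weighted accumulation.
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    # Value targets are advantages plus the baseline predictions.
    returns = advantages + vf_preds
    return advantages, returns
```

With gamma = lam = 1 and a zero baseline this reduces to plain Monte Carlo returns, which is exactly the behavior being replaced.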