[RLlib] Trajectory View API (preparatory cleanup and enhancements). #9678
Changes from all commits: e37a2df, 2a40feb, 3f242df, 0f827ac, 87adf93, 3b331fd, 545cdc5
```diff
@@ -39,10 +39,7 @@ def compute_advantages(rollout: SampleBatch,
             processed rewards.
     """

-    traj = {}
-    trajsize = len(rollout[SampleBatch.ACTIONS])
-    for key in rollout:
-        traj[key] = np.stack(rollout[key])
+    rollout_size = len(rollout[SampleBatch.ACTIONS])

     assert SampleBatch.VF_PREDS in rollout or not use_critic, \
         "use_critic=True but values not found"
```

Review comment: We shouldn't copy the entire batch here. Significantly speeds up postprocessing.
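The review note above is the crux of this hunk: instead of re-stacking every column of the incoming batch into a fresh `traj` dict, the derived columns are written straight back into `rollout`. A rough, purely illustrative sketch of the difference (plain dict and made-up column names, not RLlib code):

```python
import numpy as np

# Illustrative only: copying every column vs. attaching derived columns.
batch = {"actions": np.zeros(1000, dtype=np.int64),
         "rewards": np.ones(1000, dtype=np.float32)}

# Old pattern: re-stack each column into a new dict -> a full copy of the batch.
copied = {key: np.stack(batch[key]) for key in batch}

# New pattern: keep the existing columns untouched and only add the derived ones.
batch["advantages"] = np.zeros(1000, dtype=np.float32)
```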
```diff
@@ -54,13 +51,13 @@ def compute_advantages(rollout: SampleBatch,
             [rollout[SampleBatch.VF_PREDS],
              np.array([last_r])])
         delta_t = (
-            traj[SampleBatch.REWARDS] + gamma * vpred_t[1:] - vpred_t[:-1])
+            rollout[SampleBatch.REWARDS] + gamma * vpred_t[1:] - vpred_t[:-1])
         # This formula for the advantage comes from:
         # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
-        traj[Postprocessing.ADVANTAGES] = discount(delta_t, gamma * lambda_)
-        traj[Postprocessing.VALUE_TARGETS] = (
-            traj[Postprocessing.ADVANTAGES] +
-            traj[SampleBatch.VF_PREDS]).copy().astype(np.float32)
+        rollout[Postprocessing.ADVANTAGES] = discount(delta_t, gamma * lambda_)
+        rollout[Postprocessing.VALUE_TARGETS] = (
+            rollout[Postprocessing.ADVANTAGES] +
+            rollout[SampleBatch.VF_PREDS]).copy().astype(np.float32)
     else:
         rewards_plus_v = np.concatenate(
             [rollout[SampleBatch.REWARDS],
```
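For reference, a self-contained numpy sketch of the GAE math the hunk above operates on. The `rewards`, `values`, and `last_r` values are made up, and `discount` is a local re-implementation in the spirit of RLlib's helper, so treat this as illustrative only:

```python
import numpy as np
from scipy.signal import lfilter

def discount(x, gamma):
    # Reverse discounted cumulative sum: y[t] = x[t] + gamma * y[t + 1].
    return lfilter([1], [1, -gamma], x[::-1])[::-1]

rewards = np.array([1.0, 1.0, 1.0], dtype=np.float32)
values = np.array([0.5, 0.6, 0.7], dtype=np.float32)  # V(s_t) predictions
last_r = 0.0                                          # bootstrap value for the final state
gamma, lambda_ = 0.99, 0.95

vpred_t = np.concatenate([values, np.array([last_r])])
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
advantages = discount(delta_t, gamma * lambda_)       # GAE(lambda) advantages
value_targets = advantages + values                   # regression targets for the critic
```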
```diff
@@ -69,18 +66,18 @@ def compute_advantages(rollout: SampleBatch,
                                       gamma)[:-1].copy().astype(np.float32)

         if use_critic:
-            traj[Postprocessing.
-                 ADVANTAGES] = discounted_returns - rollout[SampleBatch.
-                                                            VF_PREDS]
-            traj[Postprocessing.VALUE_TARGETS] = discounted_returns
+            rollout[Postprocessing.
+                    ADVANTAGES] = discounted_returns - rollout[SampleBatch.
+                                                               VF_PREDS]
+            rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
         else:
-            traj[Postprocessing.ADVANTAGES] = discounted_returns
-            traj[Postprocessing.VALUE_TARGETS] = np.zeros_like(
-                traj[Postprocessing.ADVANTAGES])
+            rollout[Postprocessing.ADVANTAGES] = discounted_returns
+            rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
+                rollout[Postprocessing.ADVANTAGES])

-    traj[Postprocessing.ADVANTAGES] = traj[
+    rollout[Postprocessing.ADVANTAGES] = rollout[
         Postprocessing.ADVANTAGES].copy().astype(np.float32)

-    assert all(val.shape[0] == trajsize for val in traj.values()), \
+    assert all(val.shape[0] == rollout_size for key, val in rollout.items()), \
         "Rollout stacked incorrectly!"
-    return SampleBatch(traj)
+    return rollout
```
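The net effect of this file's changes: `compute_advantages()` now enriches and returns the very SampleBatch it was given instead of building a new one. A hedged usage sketch follows; the column values are made up and the call signature is assumed to be `compute_advantages(rollout, last_r, gamma, lambda_, ...)` as suggested by the hunks above:

```python
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import Postprocessing, compute_advantages

batch = SampleBatch({
    SampleBatch.ACTIONS: np.array([0, 1, 0]),
    SampleBatch.REWARDS: np.array([1.0, 1.0, 1.0], dtype=np.float32),
    SampleBatch.VF_PREDS: np.array([0.5, 0.6, 0.7], dtype=np.float32),
})
out = compute_advantages(batch, 0.0, gamma=0.99, lambda_=0.95)
assert out is batch  # after this PR: same object, now carrying the new columns
print(out[Postprocessing.ADVANTAGES], out[Postprocessing.VALUE_TARGETS])
```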
```diff
@@ -334,6 +334,8 @@ def wrap(env):
             # Deepmind wrappers already handle all preprocessing
             self.preprocessing_enabled = False

+            # If clip_rewards not explicitly set to False, switch it
+            # on here (clip between -1.0 and 1.0).
             if clip_rewards is None:
                 clip_rewards = True
```

Review thread:
"clip_rewards has been moved into base-Policy's …"
Reply: "Won't this break reward clipping for policies that don't call that method?"
Reply: "True, if users don't use the build_policy methods, then yes. The thought was: reward clipping is actually a post-processing step."
Reply: "Sounds good. Let's be careful about breaking changes like that."
Reply: "+1"
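To illustrate the "reward clipping is actually a post-processing step" point from the thread, here is a minimal sketch. It is not the PR's implementation; the helper name and the plain-dict batch are hypothetical:

```python
import numpy as np

def clip_batch_rewards(batch, limit=1.0):
    # Clip the rewards column of an already-collected trajectory batch,
    # rather than clipping rewards inside an env wrapper at sampling time.
    batch["rewards"] = np.clip(batch["rewards"], -limit, limit)
    return batch

batch = {"rewards": np.array([0.3, 5.0, -2.5], dtype=np.float32)}
print(clip_batch_rewards(batch)["rewards"])  # -> [ 0.3  1.  -1. ]
```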
```diff
@@ -1,6 +1,7 @@
 import gym

 from ray.rllib.env.multi_agent_env import MultiAgentEnv
+from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
 from ray.rllib.tests.test_rollout_worker import MockEnv, MockEnv2
```

```diff
@@ -164,3 +165,5 @@ def step(self, action_dict):
 MultiAgentCartPole = make_multiagent("CartPole-v0")
 MultiAgentMountainCar = make_multiagent("MountainCarContinuous-v0")
 MultiAgentPendulum = make_multiagent("Pendulum-v0")
+MultiAgentStatelessCartPole = make_multiagent(
+    lambda config: StatelessCartPole(config))
```

Review comment: Added for LSTM multi-agent testing purposes.
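A hedged usage sketch for the new env; the `num_agents` config key follows the `make_multiagent` convention used elsewhere in this file and is an assumption, not part of the PR:

```python
from ray.rllib.examples.env.multi_agent import MultiAgentStatelessCartPole

# Two independent StatelessCartPole copies behind one MultiAgentEnv interface.
env = MultiAgentStatelessCartPole({"num_agents": 2})
obs = env.reset()
print(obs)  # one observation per agent id, with the velocity terms removed,
            # which is what makes a memory (LSTM) model necessary to solve it
```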
Review comment: To ignore errors related to type annotations (those that put quotes around the actual class, because we want to avoid circular import errors).
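A short, generic illustration of the pattern that comment refers to (not code from this PR): quoting the class name in the annotation defers its resolution, so the providing module never has to be imported at runtime and the circular import disappears, but linters may flag the quoted name as undefined unless told to ignore it.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for static type checkers; skipped at runtime, so no
    # circular import between this module and the policy module.
    from ray.rllib.policy.policy import Policy

def postprocess(policy: "Policy") -> None:
    """Accepts a Policy without requiring the import at runtime."""
    pass
```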