Generic hidden state for RecurrentPPO #4
00ddbbd to 77508da
@@ -375,23 +371,21 @@ def train(self) -> None:
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()

# Convert mask from float to bool
mask = rollout_data.mask > 1e-8
The `rollout_data.mask` is now already a bool.
@@ -260,7 +257,7 @@ def collect_rollouts(  # type: ignore[override]

 callback.on_rollout_start()

-lstm_states = deepcopy(self._last_lstm_states)
+lstm_states = non_null(self._last_lstm_states)
It's not actually necessary to copy each tensor. They don't get overwritten.
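A minimal sketch of why the copy can be dropped, assuming the rollout loop rebinds `lstm_states` to a freshly built value on every step instead of mutating it in place. The `non_null` helper below is a hypothetical reconstruction (its real definition is not shown in this diff), and plain tuples stand in for tensors.

```python
from typing import Optional, Tuple, TypeVar

T = TypeVar("T")


def non_null(value: Optional[T]) -> T:
    """Hypothetical helper: narrow Optional[T] to T, failing loudly on None."""
    assert value is not None, "expected a non-None value"
    return value


State = Tuple[float, ...]


def step(state: State) -> State:
    # Builds and returns a new state; the input object is never mutated.
    return tuple(s + 1.0 for s in state)


last_states: Optional[State] = (0.0, 0.0)

# No deepcopy: `states` starts as a second reference to the same object.
states = non_null(last_states)
for _ in range(3):
    states = step(states)  # rebinding, not in-place modification

# `last_states` still holds the original values at this point.
```

If any step wrote into the state in place (for example with an in-place tensor operation), the shared reference would be affected and a copy would be needed again.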
.circleci/config.yml
Outdated
@@ -9,7 +9,7 @@ parameters:
 docker_img_version:
 # Docker image version for running tests.
 type: string
-default: "a0d53ea"
+default: "03a594c"
This is a more recent image -- a few dependencies were added, though they don't impact this codebase (only `learned-planners`).
.circleci/config.yml
Outdated
@@ -51,7 +51,7 @@ jobs:
 command: ruff .
 - run:
 name: Typecheck (mypy)
-command: mypy --exclude '^stable_baselines3/common/recurrent/policies\.py$' stable_baselines3/common tests
+command: mypy .
Start typechecking all the things again!
Sorry for the delayed review, I am still only firing on like half of my cylinders (so to speak).
"Get only the vf features, not advancing the hidden state"
if self.lstm_critic is None:
    if self.shared_lstm:
        with th.no_grad():
I would have thought that we need this `with th.no_grad():` at the top level (i.e., at line 257, applying to all parts of this function).
In particular, I'm wondering whether we might otherwise accidentally alter gradients on line 266.
buffer_size = self.env.num_envs * self.n_steps
assert buffer_size > 1 or (
    not normalize_advantage
), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
Suggested message: f"`n_steps * n_envs` must be greater than 1 when `normalize_advantage` is true. ..." etc.
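One way the check and the reworded message could fit together, as a sketch. `check_buffer_size` is a hypothetical standalone function written only for illustration; in the PR the assertion is inline and reads `self.n_steps` and `self.env.num_envs`.

```python
def check_buffer_size(n_steps: int, n_envs: int, normalize_advantage: bool) -> int:
    """Hypothetical helper mirroring the inline assertion under review."""
    buffer_size = n_envs * n_steps
    # A buffer of one sample is only rejected when advantages are normalized:
    # the standard deviation of a single advantage is not well defined.
    assert buffer_size > 1 or not normalize_advantage, (
        "`n_steps * n_envs` must be greater than 1 when `normalize_advantage` is true. "
        f"Currently n_steps={n_steps} and n_envs={n_envs}"
    )
    return buffer_size
```

With this wording, someone who hits the assertion with `n_steps=1, n_envs=1` can see that turning off advantage normalization is an alternative to enlarging the buffer.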
(This PR is after #7 and #8, which I factored out to make review easier.)
Add the remaining parts for generic-state recurrent PPO.
`common/recurrent/policies.py` is still based on LSTMs, but now uses the `recurrent_initial_state(...)` interface to indicate what its hidden state is. `RecurrentPPO` is fully generic over hidden state types.
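To make "generic over hidden state types" concrete, here is a toy sketch in plain Python. Everything in it is an assumption for illustration: the `RecurrentPolicy` protocol, the `step` method, the `n_envs` argument to `recurrent_initial_state`, and both toy policies are hypothetical and do not reproduce the PR's actual signatures.

```python
from typing import List, Protocol, Tuple, TypeVar

S = TypeVar("S")  # hidden state type, chosen by each policy


class RecurrentPolicy(Protocol[S]):
    """Hypothetical interface: the policy itself says what its state looks like."""

    def recurrent_initial_state(self, n_envs: int) -> S: ...

    def step(self, obs: List[float], state: S) -> Tuple[List[float], S]: ...


class LSTMLikePolicy:
    """Toy policy whose state is an (h, c) pair, as with an LSTM."""

    def recurrent_initial_state(self, n_envs: int) -> Tuple[List[float], List[float]]:
        return [0.0] * n_envs, [0.0] * n_envs

    def step(self, obs, state):
        _, c = state
        c = [ci + o for ci, o in zip(c, obs)]
        h = [0.5 * ci for ci in c]
        return h, (h, c)


class CounterPolicy:
    """Toy policy whose state is a single list of step counters."""

    def recurrent_initial_state(self, n_envs: int) -> List[int]:
        return [0] * n_envs

    def step(self, obs, state):
        new_state = [s + 1 for s in state]
        return [o * s for o, s in zip(obs, new_state)], new_state


def rollout(policy: RecurrentPolicy[S], observations: List[List[float]]):
    # The loop never inspects the state; it only threads it through `step`.
    state = policy.recurrent_initial_state(len(observations[0]))
    outputs = []
    for obs in observations:
        out, state = policy.step(obs, state)
        outputs.append(out)
    return outputs, state


observations = [[1.0, 2.0], [3.0, 4.0]]
lstm_outputs, lstm_state = rollout(LSTMLikePolicy(), observations)
counter_outputs, counter_state = rollout(CounterPolicy(), observations)
```

The same `rollout` function runs both policies unchanged, which is the property the description claims for `RecurrentPPO`: the algorithm treats the hidden state as opaque, and only the policy knows its structure.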