Port in torchified PPO from sb3_contrib #5
Conversation
Can you add some lineage of what was copied from where? This is a big PR, so it's going to take me a while to review, and understanding what is new code vs copied code would be helpful.
Added some quick comments to the smaller files. Waiting to hear back on the sources of the larger files before reviewing them.
Thanks for the refactor in the tests to treat OnPolicy/OffPolicy algorithms separately rather than using sets; it seems much cleaner now!
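For illustration, a minimal sketch of the kind of refactor this comment refers to: parametrizing on-policy and off-policy algorithms as separate groups instead of checking membership in a shared set inside one test. The class lists, test names, and hyperparameters below are assumptions for the example, not code from this PR.

```python
# Hypothetical sketch: separate parametrized smoke tests per algorithm family.
import pytest
from stable_baselines3 import A2C, PPO, DQN

ON_POLICY_ALGOS = [A2C, PPO]
OFF_POLICY_ALGOS = [DQN]


@pytest.mark.parametrize("algo_cls", ON_POLICY_ALGOS)
def test_on_policy_smoke(algo_cls):
    model = algo_cls("MlpPolicy", "CartPole-v1", n_steps=64)
    model.learn(total_timesteps=128)


@pytest.mark.parametrize("algo_cls", OFF_POLICY_ALGOS)
def test_off_policy_smoke(algo_cls):
    model = algo_cls("MlpPolicy", "CartPole-v1", learning_starts=32)
    model.learn(total_timesteps=128)
```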
stable_baselines3/common/buffers.py
Outdated
@@ -457,7 +457,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: th.Tensor
                next_non_terminal = ~dones
                next_values = last_values
            else:
-               next_non_terminal = 1.0 - self.episode_starts[step + 1]
+               next_non_terminal = ~self.episode_starts[step + 1]
next_non_terminal seems like a confusing name here. Thoughts on is_next_non_terminal? next_non_terminal reads to me as referring to the next non-terminal value, but this variable seems to refer to whether the following value to be iterated through is non-terminal.
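For reference, here is a minimal sketch of the backward GAE pass in which this flag appears, following the shape of stable-baselines3's compute_returns_and_advantage. Variable names come from the diff above; the function wrapper is paraphrased for the example and is not copied from this PR.

```python
import torch as th


def compute_gae(rewards, values, episode_starts, last_values, dones,
                gamma=0.99, gae_lambda=0.95):
    """Backward GAE pass; boolean masks play the role of next_non_terminal."""
    buffer_size = rewards.shape[0]
    advantages = th.zeros_like(rewards)
    last_gae_lam = 0.0
    for step in reversed(range(buffer_size)):
        if step == buffer_size - 1:
            # Whether the step *after* the buffer is non-terminal -- hence the naming discussion.
            next_non_terminal = ~dones
            next_values = last_values
        else:
            next_non_terminal = ~episode_starts[step + 1]
            next_values = values[step + 1]
        # Multiplying by the boolean mask zeroes out bootstrapping across episode boundaries.
        delta = rewards[step] + gamma * next_values * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    return advantages
```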
LGTM, thanks for splitting out the other parts!
}[replay_buffer_cls]
env = make_vec_env(env)

buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device)
if replay_buffer_cls == RecurrentRolloutBuffer:
    buffer = RecurrentRolloutBuffer(
Optional: here and below, 100 comes off as a bit of a magic number; it's hard to see why it was picked (and what it does). Passing all of these in as kwargs would make it a little more readable. If this 100 is the "same" 100 that is used in line 153, it would be good to refactor them to use a shared constant.
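One possible shape for this suggestion is sketched below. The constant name is illustrative, and the keyword names match the stable-baselines3 base buffer signature; the PR's RecurrentRolloutBuffer may take additional arguments not shown here.

```python
# Hypothetical refactor: name the size once and pass arguments as keywords
# so the test reads without magic numbers.
BUFFER_SIZE = 100  # shared with the other use of 100 mentioned in the comment

buffer = replay_buffer_cls(
    buffer_size=BUFFER_SIZE,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=device,
)
```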
I started by copying over the numpy code and then fixing whatever broke. This led to working code, unlike previous iterations. I also added the PPORecurrent algorithm to a bunch of existing tests, which is the same thing that sb3_contrib does.

This PR contains lots of Mypy type errors. I want to fix them when I make things generic over the hidden state, as in #4, and not now. If you check the CircleCI job you'll see that the tests themselves pass.
Other changes: