Port in torchified PPO from sb3_contrib #5

Merged
merged 34 commits into main on Oct 7, 2023
Conversation

rhaps0dy
Collaborator

@rhaps0dy rhaps0dy commented Sep 21, 2023

I started by copying over the numpy code and then fixing whatever broke. This led to working code, unlike previous iterations. I also added the PPORecurrent algorithm to a bunch of existing tests, which is the same thing that sb3_contrib does.

This PR contains lots of Mypy type errors. I want to fix them when I make things generic over the hidden state, as in #4, and not now. If you check the CircleCI job you'll see that the tests themselves pass.

Other changes:

  • Increased test parallelism in CircleCI to speed up the test run (this PR adds some slow new tests)

@dan-pandori

Can you add some lineage of what was copied from where? This is a big PR so it's going to take me a while to review, and understanding what is new code vs copied code would be helpful.


@dan-pandori dan-pandori left a comment


Added some quick comments to the smaller files. Waiting to hear back on the sources of the larger files before reviewing them.

Thanks for refactoring the tests to treat OnPolicy/OffPolicy algorithms separately rather than using sets; it seems much cleaner now!

.circleci/config.yml (outdated, resolved)
stable_baselines3/common/buffers.py (outdated, resolved)
@@ -457,7 +457,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: th.Tensor
                 next_non_terminal = ~dones
                 next_values = last_values
             else:
-                next_non_terminal = 1.0 - self.episode_starts[step + 1]
+                next_non_terminal = ~self.episode_starts[step + 1]


next_non_terminal seems like a confusing name here. Thoughts on is_next_non_terminal?

next_non_terminal reads to me as referring to the next nonterminal value, but this variable actually indicates whether the next value in the iteration is nonterminal.
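For context, the change being discussed swaps a float mask (`1.0 - episode_starts`) for a boolean one (`~episode_starts`) inside the GAE backward recursion. A minimal pure-Python sketch of that recursion (function and argument names here are illustrative, not the actual buffer code):

```python
# Hypothetical minimal GAE sketch, not the real compute_returns_and_advantage.
# episode_starts[t] is True when step t begins a new episode, so the boolean
# `not episode_starts[t]` carries the same information as the old float
# expression `1.0 - episode_starts[t]` (bools multiply like 0.0/1.0).

def compute_advantages(rewards, values, episode_starts, last_value, done,
                       gamma=0.99, gae_lambda=0.95):
    n = len(rewards)
    advantages = [0.0] * n
    last_gae = 0.0
    for step in reversed(range(n)):
        if step == n - 1:
            next_non_terminal = not done          # was: 1.0 - done
            next_values = last_value
        else:
            # was: 1.0 - episode_starts[step + 1]
            next_non_terminal = not episode_starts[step + 1]
            next_values = values[step + 1]
        delta = rewards[step] + gamma * next_values * next_non_terminal - values[step]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[step] = last_gae
    return advantages
```

Since `bool` is an `int` subclass in Python (and `BoolTensor` broadcasts similarly in torch), the masking arithmetic is unchanged; only the dtype of the mask differs.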

@rhaps0dy rhaps0dy changed the base branch from master to just-copy-contrib September 21, 2023 21:51
@rhaps0dy rhaps0dy changed the base branch from just-copy-contrib to master September 22, 2023 00:02

@dan-pandori dan-pandori left a comment


LGTM, thanks for splitting out the other parts!

}[replay_buffer_cls]
env = make_vec_env(env)

buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device)
if replay_buffer_cls == RecurrentRolloutBuffer:
buffer = RecurrentRolloutBuffer(


Optional:
Here and below, 100 comes off as a bit of a magic number: it's hard to see why it was picked (and what it does). Passing all of these in as kwargs would make it a little more readable.

If this 100 is the 'same' 100 that is used in line 153, it would be good to refactor them to use a shared constant.

@rhaps0dy rhaps0dy merged commit fc9b730 into main Oct 7, 2023
0 of 3 checks passed