PPO + JAX + EnvPool + Atari #227

Merged 37 commits on Oct 6, 2022
Commits
52e62b4
PPO + jax + envpool + atari
vwxyzjn Jul 7, 2022
0e66dd4
fix bug: only report metric when lives are used up
vwxyzjn Jul 11, 2022
2cf5c3a
pre-commit
vwxyzjn Jul 11, 2022
4c33285
quick fix
vwxyzjn Jul 11, 2022
9108810
Quick refactor
vwxyzjn Jul 11, 2022
cc1c7d9
Merge branch 'master' into jax-ppo-envpool-atari
vwxyzjn Aug 22, 2022
8449fe3
push changes
vwxyzjn Aug 22, 2022
73309a5
pre-commit and use EnvPool's new API
vwxyzjn Aug 22, 2022
2527d62
update envpool
vwxyzjn Aug 22, 2022
cc861e2
update docs
vwxyzjn Aug 24, 2022
1d57472
update ppo benchmark script
vwxyzjn Aug 24, 2022
beb82d2
update docs
vwxyzjn Aug 24, 2022
931593c
Merge branch 'master' into jax-ppo-envpool-atari
vwxyzjn Aug 25, 2022
3c5c499
Merge branch 'master' into jax-ppo-envpool-atari
vwxyzjn Aug 26, 2022
55c9a74
use the latest envpool interface
vwxyzjn Aug 26, 2022
5631fb5
update envpool to the latest version
vwxyzjn Aug 27, 2022
f1624cf
update pyproject.toml
vwxyzjn Aug 27, 2022
0fd6729
update lock files
vwxyzjn Aug 27, 2022
5c07c28
Quick clarification
vwxyzjn Aug 28, 2022
879965e
Merge branch 'jax-ppo-envpool-atari' of https://github.com/vwxyzjn/cl…
vwxyzjn Aug 28, 2022
ed1b15b
Update docs
vwxyzjn Sep 12, 2022
bc1b6d7
remove non benchmarked script
vwxyzjn Sep 12, 2022
0905b93
update docs
vwxyzjn Sep 12, 2022
217a108
revert poetry changes
vwxyzjn Oct 5, 2022
58318ec
Merge branch 'master' into jax-ppo-envpool-atari
vwxyzjn Oct 5, 2022
6a3c91a
docs fix
vwxyzjn Oct 5, 2022
5f2680f
remove unnecessary code, add docs
vwxyzjn Oct 5, 2022
19cc1fe
add a note on envpool
vwxyzjn Oct 5, 2022
2b10f75
update test cases
vwxyzjn Oct 5, 2022
918ccdf
explain `get_action_and_value`
vwxyzjn Oct 5, 2022
ffde717
fix indent
vwxyzjn Oct 5, 2022
6b67283
Fix weird error with `np.mean`. See below:
vwxyzjn Oct 5, 2022
6aa73d0
update docs
vwxyzjn Oct 5, 2022
c61a521
pre-commit
vwxyzjn Oct 5, 2022
0f72ce0
add note on `charts/avg_episodic_return`
vwxyzjn Oct 5, 2022
8cbb5b1
update reproducibility script
vwxyzjn Oct 5, 2022
ed877ad
add note on value function clipping
vwxyzjn Oct 6, 2022
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -193,7 +193,7 @@ jobs:
 
       # envpool tests
       - name: Install envpool dependencies
-        run: poetry install --with pytest,envpool
+        run: poetry install --with pytest,envpool,jax
       - name: Downgrade setuptools
         run: poetry run pip install setuptools==59.5.0
       - name: Run envpool tests
18 changes: 18 additions & 0 deletions benchmark/ppo.sh
@@ -71,3 +71,21 @@ xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
     --command "poetry run python cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py --track --capture-video --num-envs 8192 --num-steps 8 --update-epochs 5 --num-minibatches 4 --reward-scaler 0.01 --total-timesteps 600000000 --record-video-step-frequency 3660" \
     --num-seeds 3 \
     --workers 1
+
+
+poetry install --with envpool
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids Alien-v5 Amidar-v5 Assault-v5 Asterix-v5 Asteroids-v5 Atlantis-v5 BankHeist-v5 BattleZone-v5 BeamRider-v5 Berzerk-v5 Bowling-v5 Boxing-v5 Breakout-v5 Centipede-v5 ChopperCommand-v5 CrazyClimber-v5 Defender-v5 DemonAttack-v5 \
+    --command "poetry run python ppo_atari_envpool_xla_jax.py --track --wandb-project-name envpool-atari --wandb-entity openrlbenchmark" \
+    --num-seeds 3 \
+    --workers 1
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids DoubleDunk-v5 Enduro-v5 FishingDerby-v5 Freeway-v5 Frostbite-v5 Gopher-v5 Gravitar-v5 Hero-v5 IceHockey-v5 Jamesbond-v5 Kangaroo-v5 Krull-v5 KungFuMaster-v5 MontezumaRevenge-v5 MsPacman-v5 NameThisGame-v5 Phoenix-v5 Pitfall-v5 Pong-v5 \
+    --command "poetry run python ppo_atari_envpool_xla_jax.py --track --wandb-project-name envpool-atari --wandb-entity openrlbenchmark" \
+    --num-seeds 3 \
+    --workers 1
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids PrivateEye-v5 Qbert-v5 Riverraid-v5 RoadRunner-v5 Robotank-v5 Seaquest-v5 Skiing-v5 Solaris-v5 SpaceInvaders-v5 StarGunner-v5 Surround-v5 Tennis-v5 TimePilot-v5 Tutankham-v5 UpNDown-v5 Venture-v5 VideoPinball-v5 WizardOfWor-v5 YarsRevenge-v5 Zaxxon-v5 \
+    --command "poetry run python ppo_atari_envpool_xla_jax.py --track --wandb-project-name envpool-atari --wandb-entity openrlbenchmark" \
+    --num-seeds 3 \
+    --workers 1
16 changes: 2 additions & 14 deletions cleanrl/ppo_atari_envpool.py
@@ -84,13 +84,6 @@ def __init__(self, env, deque_size=100):
         self.num_envs = getattr(env, "num_envs", 1)
         self.episode_returns = None
         self.episode_lengths = None
-        # get if the env has lives
-        self.has_lives = False
-        env.reset()
-        info = env.step(np.zeros(self.num_envs, dtype=int))[-1]
-        if info["lives"].sum() > 0:
-            self.has_lives = True
-            print("env has lives")
 
     def reset(self, **kwargs):
         observations = super().reset(**kwargs)
@@ -107,13 +100,8 @@ def step(self, action):
         self.episode_lengths += 1
         self.returned_episode_returns[:] = self.episode_returns
         self.returned_episode_lengths[:] = self.episode_lengths
-        all_lives_exhausted = infos["lives"] == 0
-        if self.has_lives:
-            self.episode_returns *= 1 - all_lives_exhausted
-            self.episode_lengths *= 1 - all_lives_exhausted
-        else:
-            self.episode_returns *= 1 - dones
-            self.episode_lengths *= 1 - dones
+        self.episode_returns *= 1 - infos["terminated"]
+        self.episode_lengths *= 1 - infos["terminated"]
         infos["r"] = self.returned_episode_returns
         infos["l"] = self.returned_episode_lengths
         return (
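The `step` change in `cleanrl/ppo_atari_envpool.py` resets the per-environment accumulators by multiplying them with `1 - infos["terminated"]`. The following is a minimal sketch of that masking pattern on its own, not the wrapper from the PR. The class name `EpisodeStats`, its `update` method, and the assumption that `terminated` arrives as a 0/1 integer array (1 only when the whole game is over, 0 on an ordinary step or a lost life) are all hypothetical and chosen for illustration.

```python
import numpy as np


class EpisodeStats:
    """Hypothetical stand-in for the accumulators kept by the wrapper in the diff."""

    def __init__(self, num_envs):
        self.episode_returns = np.zeros(num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(num_envs, dtype=np.int32)
        self.returned_episode_returns = np.zeros(num_envs, dtype=np.float32)
        self.returned_episode_lengths = np.zeros(num_envs, dtype=np.int32)

    def update(self, rewards, terminated):
        # Accumulate this step's reward and length for every environment.
        self.episode_returns += rewards
        self.episode_lengths += 1
        # Snapshot the totals before any reset so they can be reported.
        self.returned_episode_returns[:] = self.episode_returns
        self.returned_episode_lengths[:] = self.episode_lengths
        # Assumption: terminated is a 0/1 integer array. Multiplying by
        # (1 - terminated) zeroes the finished environments and leaves
        # the others unchanged.
        self.episode_returns *= 1 - terminated
        self.episode_lengths *= 1 - terminated
        return (
            self.returned_episode_returns.copy(),
            self.returned_episode_lengths.copy(),
        )


stats = EpisodeStats(num_envs=2)
stats.update(np.array([1.0, 2.0]), np.array([0, 0], dtype=np.int32))
# Environment 1 terminates on this step; environment 0 keeps accumulating.
stats.update(np.array([3.0, 4.0]), np.array([0, 1], dtype=np.int32))
returns, lengths = stats.update(np.array([1.0, 1.0]), np.array([0, 0], dtype=np.int32))
```

After the third call, environment 0 has kept accumulating across all three steps while environment 1 has restarted from zero, so a single mask covers games with and without lives and the separate `has_lives` probe is not needed.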