PPO + JAX + EnvPool + Atari (vwxyzjn#227)
* PPO + jax + envpool + atari

* fix bug: only report metrics when lives are used up

* pre-commit

* quick fix

* Quick refactor

* push changes

* pre-commit and use EnvPool's new API

* update envpool

* update docs

* update ppo benchmark script

* update docs

* use the latest envpool interface

* update envpool to the latest version

* update pyproject.toml

* update lock files

* Quick clarification

* Update docs

* remove non benchmarked script

* update docs

* revert poetry changes

* docs fix

* remove unnecessary code, add docs

* add a note on envpool

* update test cases

* explain `get_action_and_value`

* fix indent

* Fix weird error with `np.mean`. See below:

We got the following message (see vwxyzjn#227 (comment)):
```
NotImplementedError: Got <class 'jaxlib.xla_extension.DeviceArray'>, but numpy array, torch tensor, or caffe2 blob name are expected.
```
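One way to avoid this error is to convert the JAX array to a plain NumPy value before handing it to TensorBoard. A minimal sketch, assuming the metric lives in a JAX `DeviceArray` (the helper name `to_numpy_scalar` is hypothetical, not part of the PR):

```python
import numpy as np

def to_numpy_scalar(x):
    # np.asarray materializes a JAX DeviceArray (or any array-like) on the
    # host; float() then yields a plain Python scalar that logging backends
    # such as TensorBoard's add_scalar accept without complaint.
    return float(np.asarray(x).mean())
```

With this conversion, `writer.add_scalar` receives an ordinary float instead of a `jaxlib.xla_extension.DeviceArray`.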

* update docs

* pre-commit

* add note on `charts/avg_episodic_return`

* update reproducibility script

* add note on value function clipping
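The value-function clipping mentioned in the last bullet is the PPO-style clipped value loss: the new value prediction is only allowed to move a bounded distance from the old one. A minimal NumPy sketch of the idea (function name and coefficient are illustrative, not taken from the PR):

```python
import numpy as np

def clipped_value_loss(new_values, old_values, returns, clip_coef=0.2):
    # Clip how far the new value estimate may move from the old one,
    # then take the elementwise max of the clipped and unclipped
    # squared errors, as in the standard PPO value-loss formulation.
    v_clipped = old_values + np.clip(new_values - old_values, -clip_coef, clip_coef)
    loss_unclipped = (new_values - returns) ** 2
    loss_clipped = (v_clipped - returns) ** 2
    return 0.5 * np.maximum(loss_unclipped, loss_clipped).mean()
```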
vwxyzjn authored Oct 6, 2022
1 parent c20c799 commit 42d21bd
Showing 22 changed files with 157,414 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
```diff
@@ -193,7 +193,7 @@ jobs:

       # envpool tests
       - name: Install envpool dependencies
-        run: poetry install --with pytest,envpool
+        run: poetry install --with pytest,envpool,jax
       - name: Downgrade setuptools
         run: poetry run pip install setuptools==59.5.0
       - name: Run envpool tests
```
18 changes: 18 additions & 0 deletions benchmark/ppo.sh
```diff
@@ -71,3 +71,21 @@ xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
     --command "poetry run python cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py --track --capture-video --num-envs 8192 --num-steps 8 --update-epochs 5 --num-minibatches 4 --reward-scaler 0.01 --total-timesteps 600000000 --record-video-step-frequency 3660" \
     --num-seeds 3 \
     --workers 1
+
+
+poetry install --with envpool
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids Alien-v5 Amidar-v5 Assault-v5 Asterix-v5 Asteroids-v5 Atlantis-v5 BankHeist-v5 BattleZone-v5 BeamRider-v5 Berzerk-v5 Bowling-v5 Boxing-v5 Breakout-v5 Centipede-v5 ChopperCommand-v5 CrazyClimber-v5 Defender-v5 DemonAttack-v5 \
+    --command "poetry run python ppo_atari_envpool_xla_jax.py --track --wandb-project-name envpool-atari --wandb-entity openrlbenchmark" \
+    --num-seeds 3 \
+    --workers 1
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids DoubleDunk-v5 Enduro-v5 FishingDerby-v5 Freeway-v5 Frostbite-v5 Gopher-v5 Gravitar-v5 Hero-v5 IceHockey-v5 Jamesbond-v5 Kangaroo-v5 Krull-v5 KungFuMaster-v5 MontezumaRevenge-v5 MsPacman-v5 NameThisGame-v5 Phoenix-v5 Pitfall-v5 Pong-v5 \
+    --command "poetry run python ppo_atari_envpool_xla_jax.py --track --wandb-project-name envpool-atari --wandb-entity openrlbenchmark" \
+    --num-seeds 3 \
+    --workers 1
+poetry run python -m cleanrl_utils.benchmark \
+    --env-ids PrivateEye-v5 Qbert-v5 Riverraid-v5 RoadRunner-v5 Robotank-v5 Seaquest-v5 Skiing-v5 Solaris-v5 SpaceInvaders-v5 StarGunner-v5 Surround-v5 Tennis-v5 TimePilot-v5 Tutankham-v5 UpNDown-v5 Venture-v5 VideoPinball-v5 WizardOfWor-v5 YarsRevenge-v5 Zaxxon-v5 \
+    --command "poetry run python ppo_atari_envpool_xla_jax.py --track --wandb-project-name envpool-atari --wandb-entity openrlbenchmark" \
+    --num-seeds 3 \
+    --workers 1
```
16 changes: 2 additions & 14 deletions cleanrl/ppo_atari_envpool.py
```diff
@@ -84,13 +84,6 @@ def __init__(self, env, deque_size=100):
         self.num_envs = getattr(env, "num_envs", 1)
         self.episode_returns = None
         self.episode_lengths = None
-        # get if the env has lives
-        self.has_lives = False
-        env.reset()
-        info = env.step(np.zeros(self.num_envs, dtype=int))[-1]
-        if info["lives"].sum() > 0:
-            self.has_lives = True
-            print("env has lives")

     def reset(self, **kwargs):
         observations = super().reset(**kwargs)
@@ -107,13 +100,8 @@ def step(self, action):
         self.episode_lengths += 1
         self.returned_episode_returns[:] = self.episode_returns
         self.returned_episode_lengths[:] = self.episode_lengths
-        all_lives_exhausted = infos["lives"] == 0
-        if self.has_lives:
-            self.episode_returns *= 1 - all_lives_exhausted
-            self.episode_lengths *= 1 - all_lives_exhausted
-        else:
-            self.episode_returns *= 1 - dones
-            self.episode_lengths *= 1 - dones
+        self.episode_returns *= 1 - infos["terminated"]
+        self.episode_lengths *= 1 - infos["terminated"]
         infos["r"] = self.returned_episode_returns
         infos["l"] = self.returned_episode_lengths
         return (
```
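The `infos["terminated"]` masking in this diff resets per-env running statistics without any Python-level branching: multiplying by `1 - terminated` zeroes exactly the envs whose episode just ended and leaves the rest untouched. A small standalone illustration with toy data (the arrays here are made up, not from the repository):

```python
import numpy as np

episode_returns = np.array([10.0, 5.0, 7.0])  # running return per env
terminated = np.array([0, 1, 0])              # env 1 just finished its episode

# Vectorized reset: finished envs go to 0, the others keep accumulating.
episode_returns *= 1 - terminated
```

After the update, `episode_returns` is `[10., 0., 7.]`.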
