Skip to content

Commit

Permalink
Using jax scan for PPO + atari + envpool XLA (vwxyzjn#328)
Browse files Browse the repository at this point in the history
jax.scan for ppo + atari + envpool and corresponding docs and tests
  • Loading branch information
51616 authored Dec 21, 2022
1 parent b558b2b commit 2dd73af
Show file tree
Hide file tree
Showing 9 changed files with 626 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ You may also use a prebuilt development environment hosted in Gitpod:
| | [`ppo_atari_lstm.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_lstm.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_lstmpy)
| | [`ppo_atari_envpool.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy)
| | [`ppo_atari_envpool_xla_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool_xla_jax.py), [docs](/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy)
| | [`ppo_atari_envpool_xla_jax_scan.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool_xla_jax_scan.py), [docs](/rl-algorithms/ppo/#ppo_atari_envpool_xla_jax_scanpy)
| | [`ppo_procgen.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_procgen.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_procgenpy)
| | [`ppo_atari_multigpu.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_multigpu.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_multigpupy)
| | [`ppo_pettingzoo_ma_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_pettingzoo_ma_ataripy)
Expand Down
8 changes: 7 additions & 1 deletion benchmark/ppo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,10 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--command "poetry run python cleanrl/gymnasium_support/ppo_continuous_action.py --cuda False --track" \
--num-seeds 3 \
--workers 9


poetry install --with envpool,jax
python -m cleanrl_utils.benchmark \
--env-ids Pong-v5 BeamRider-v5 Breakout-v5 \
--command "poetry run python cleanrl/ppo_atari_envpool_xla_jax_scan.py --track --capture-video" \
--num-seeds 3 \
--workers 1
Loading

0 comments on commit 2dd73af

Please sign in to comment.