Skip to content

Commit

Permalink
prototype jax with ddpg (vwxyzjn#187)
Browse files Browse the repository at this point in the history
* prototype jax with ddpg

* Quick fix

* quick fix

* Commit changes - successful prototype

* Remove scripts

* Simplify the implementation: careful with shape

* Format

* Remove code

* formatting changes

* formatting change

* bug fix

* correctly implementing keys

* these two lines are not necessary

target_params are initialized with the same RNG key

* Adapting to the `TrainState` API

* Simplify code

* use `optax.incremental_update`

* Also log q values

* Addresses vwxyzjn#211

* update docs

* Add jax benchmark experiments

* remove old files

* update benchmark scripts

* update lock files

* Handle action space bounds

* Add docs

* Typo

* update CI

* bug fix and add docs link

* Add a note explaining the speed

* Update ddpg docs
  • Loading branch information
vwxyzjn authored Jul 12, 2022
1 parent cd2011c commit 7eeb583
Show file tree
Hide file tree
Showing 36 changed files with 689 additions and 1,859 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ jobs:
run: poetry install -E pybullet
- name: Install mujoco dependencies
run: poetry install -E mujoco
- name: Install jax dependencies
run: poetry install -E jax
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: install mujoco dependencies
Expand Down
4 changes: 2 additions & 2 deletions benchmark/c51.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
poetry install
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--command "poetry run python cleanrl/c51.py --cuda False --track --capture-video" \
--num-seeds 3 \
--workers 9

poetry install -E atari
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/c51_atari.py --track --capture-video" \
--num-seeds 3 \
Expand Down
11 changes: 10 additions & 1 deletion benchmark/ddpg.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,13 @@ OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \
--command "poetry run python cleanrl/ddpg_continuous_action.py --track --capture-video" \
--num-seeds 3 \
--workers 3
--workers 3

poetry install -E "mujoco jax"
poetry run pip install --upgrade "jax[cuda]==0.3.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
--command "poetry run python cleanrl/ddpg_continuous_action_jax.py --track --capture-video" \
--num-seeds 3 \
--workers 1
4 changes: 2 additions & 2 deletions benchmark/dqn.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
poetry install
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--command "poetry run python cleanrl/dqn.py --cuda False --track --capture-video" \
--num-seeds 3 \
--workers 9

poetry install -E atari
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/dqn_atari.py --track --capture-video" \
--num-seeds 3 \
Expand Down
2 changes: 1 addition & 1 deletion benchmark/ppg.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# export WANDB_ENTITY=openrlbenchmark

poetry install -E procgen
xvfb-run -a python -m cleanrl_utils.benchmark \
xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids starpilot bossfight bigfish \
--command "poetry run python cleanrl/ppg_procgen.py --track --capture-video" \
--num-seeds 3 \
Expand Down
18 changes: 9 additions & 9 deletions benchmark/ppo.sh
Original file line number Diff line number Diff line change
@@ -1,58 +1,58 @@
# export WANDB_ENTITY=openrlbenchmark

poetry install
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--command "poetry run python cleanrl/ppo.py --cuda False --track --capture-video" \
--num-seeds 3 \
--workers 9

poetry install -E atari
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/ppo_atari.py --track --capture-video" \
--num-seeds 3 \
--workers 3

poetry install -E atari
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/ppo_atari_lstm.py --track --capture-video" \
--num-seeds 3 \
--workers 3

poetry install -E envpool
xvfb-run -a python -m cleanrl_utils.benchmark \
xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids Pong-v5 BeamRider-v5 Breakout-v5 \
--command "poetry run python cleanrl/ppo_atari_envpool.py --track --capture-video" \
--num-seeds 3 \
--workers 1

poetry install -E "mujoco pybullet"
python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
poetry run python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
--command "poetry run python cleanrl/ppo_continuous_action.py --cuda False --track --capture-video" \
--num-seeds 3 \
--workers 9

poetry install -E procgen
xvfb-run -a python -m cleanrl_utils.benchmark \
xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids starpilot bossfight bigfish \
--command "poetry run python cleanrl/ppo_procgen.py --track --capture-video" \
--num-seeds 3 \
--workers 1

poetry install -E atari
xvfb-run -a python -m cleanrl_utils.benchmark \
xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run torchrun --standalone --nnodes=1 --nproc_per_node=2 cleanrl/ppo_atari_multigpu.py --track --capture-video" \
--num-seeds 3 \
--workers 1

poetry install -E "pettingzoo atari"
poetry run AutoROM --accept-license
xvfb-run -a python -m cleanrl_utils.benchmark \
xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids pong_v3 surround_v2 tennis_v3 \
--command "poetry run python cleanrl/ppo_pettingzoo_ma_atari.py --track --capture-video" \
--num-seeds 3 \
Expand Down
4 changes: 2 additions & 2 deletions benchmark/sac.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
poetry install -E "mujoco pybullet"
python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
poetry run python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
--command "poetry run python cleanrl/sac_continuous_action.py --track --capture-video" \
--num-seeds 3 \
Expand Down
Loading

0 comments on commit 7eeb583

Please sign in to comment.