Skip to content

Commit

Permalink
update td3 docs
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn committed Nov 9, 2023
1 parent b2542e0 commit 33a5609
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 35 deletions.
4 changes: 2 additions & 2 deletions benchmark/c51.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--workers 1

poetry install -E "jax"
poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
CUDA_VISIBLE_DEVICES=-1 xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--command "poetry run python cleanrl/c51_jax.py --track --capture_video" \
--num-seeds 3 \
--workers 1

poetry install -E "atari jax"
poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/c51_atari_jax.py --track --capture_video" \
Expand Down
32 changes: 19 additions & 13 deletions benchmark/ddpg.sh
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
poetry install -E "mujoco_py"
python -c "import mujoco_py"
xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \
--command "poetry run python cleanrl/ddpg_continuous_action.py --track --capture_video" \
poetry install -E "mujoco"
python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
--command "poetry run python cleanrl/ddpg_continuous_action.py --track" \
--num-seeds 3 \
--workers 1
--workers 18 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 10 \
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template

poetry install -E "mujoco_py jax"
poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run python -c "import mujoco_py"
xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
--command "poetry run python cleanrl/ddpg_continuous_action_jax.py --track --capture_video" \
poetry install -E "mujoco jax"
poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
--command "poetry run python cleanrl/ddpg_continuous_action_jax.py --track" \
--num-seeds 3 \
--workers 1
--workers 18 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 10 \
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template
20 changes: 20 additions & 0 deletions benchmark/ddpg_plot.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
python -m openrlbenchmark.rlops \
--filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
'ddpg_continuous_action?tag=pr-424' \
--env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
--no-check-empty-runs \
--pc.ncols 3 \
--pc.ncols-legend 2 \
--output-filename benchmark/cleanrl/ddpg \
--scan-history

python -m openrlbenchmark.rlops \
--filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
'ddpg_continuous_action?tag=pr-424' \
'ddpg_continuous_action_jax?tag=pr-424' \
--env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
--no-check-empty-runs \
--pc.ncols 3 \
--pc.ncols-legend 2 \
--output-filename benchmark/cleanrl/ddpg_jax \
--scan-history
4 changes: 2 additions & 2 deletions benchmark/dqn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--workers 1

poetry install -E jax
poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--command "poetry run python cleanrl/dqn_jax.py --track --capture_video" \
--num-seeds 3 \
--workers 1

poetry install -E "atari jax"
poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/dqn_atari_jax.py --track --capture_video" \
Expand Down
2 changes: 1 addition & 1 deletion benchmark/qdagger.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \


poetry install -E "atari jax"
poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --track --capture_video" \
Expand Down
32 changes: 19 additions & 13 deletions benchmark/td3.sh
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
poetry install -E mujoco_py
python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \
--command "poetry run python cleanrl/td3_continuous_action.py --track --capture_video" \
poetry install -E "mujoco"
python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
--command "poetry run python cleanrl/td3_continuous_action.py --track" \
--num-seeds 3 \
--workers 1
--workers 18 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 10 \
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template

poetry install -E "mujoco_py jax"
poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run python -c "import mujoco_py"
xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
--command "poetry run python cleanrl/td3_continuous_action_jax.py --track --capture_video" \
poetry install -E "mujoco jax"
poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
--command "poetry run python cleanrl/td3_continuous_action_jax.py --track" \
--num-seeds 3 \
--workers 1
--workers 18 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 10 \
--slurm-template-path benchmark/cleanrl_1gpu.slurm_template
21 changes: 21 additions & 0 deletions benchmark/td3_plot.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
python -m openrlbenchmark.rlops \
--filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
'td3_continuous_action?tag=pr-424' \
'td3_continuous_action_jax?tag=pr-424' \
--filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
--env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
--no-check-empty-runs \
--pc.ncols 3 \
--pc.ncols-legend 2 \
--output-filename benchmark/cleanrl/td3 \
--scan-history

python -m openrlbenchmark.rlops \
--filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
'sac_continuous_action?tag=pr-424' \
--env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
--no-check-empty-runs \
--pc.ncols 3 \
--pc.ncols-legend 2 \
--output-filename benchmark/cleanrl/sac \
--scan-history
3 changes: 1 addition & 2 deletions docs/rl-algorithms/ddpg.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ Below are the average episodic returns for [`ddpg_continuous_action.py`](https:/

Learning curves:


``` title="benchmark/ddpg_plot.sh" linenums="1"
--8<-- "benchmark/ddpg_plot.sh::9"
```
Expand Down Expand Up @@ -343,7 +342,7 @@ Learning curves:


???+ info

These are some previous experiments with TPUs. Note the results are very similar to the ones above, but the runtime can be different due to different hardware used.


Expand Down
17 changes: 15 additions & 2 deletions docs/rl-algorithms/td3.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ Additionally, when drawing exploration noise that is added to the actions produc

To run benchmark experiments, see :material-github: [benchmark/td3.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/td3.sh). Specifically, execute the following command:

<script src="https://emgithub.com/embed.js?target=https%3A%2F%2Fgithub.com%2Fvwxyzjn%2Fcleanrl%2Fblob%2Fmaster%2Fbenchmark%2Ftd3.sh%23L1-L7&style=github&type=code&showBorder=on&showLineNumbers=on&showFileMeta=on&showFullPath=on&showCopy=on"></script>
``` title="benchmark/td3.sh" linenums="1"
--8<-- "benchmark/td3.sh::7"
```


Below are the average episodic returns for [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py) (3 random seeds). To ensure the quality of the implementation, we compared the results against (Fujimoto et al., 2018)[^2].
Expand All @@ -150,6 +152,10 @@ Below are the average episodic returns for [`td3_continuous_action.py`](https://

Learning curves:

``` title="benchmark/td3_plot.sh" linenums="1"
--8<-- "benchmark/td3_plot.sh::9"
```

<img loading="lazy" src="https://huggingface.co/datasets/cleanrl/benchmark/resolve/main/benchmark/pr-424/td3.png">
<img loading="lazy" src="https://huggingface.co/datasets/cleanrl/benchmark/resolve/main/benchmark/pr-424/td3-time.png">

Expand Down Expand Up @@ -203,14 +209,21 @@ See [related docs](/rl-algorithms/td3/#implementation-details) for `td3_continuo

To run benchmark experiments, see :material-github: [benchmark/td3.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/td3.sh). Specifically, execute the following command:

<script src="https://emgithub.com/embed.js?target=https%3A%2F%2Fgithub.com%2Fvwxyzjn%2Fcleanrl%2Fblob%2Fmaster%2Fbenchmark%2Ftd3.sh%23L9-L16&style=github&type=code&showBorder=on&showLineNumbers=on&showFileMeta=on&showFullPath=on&showCopy=on"></script>
``` title="benchmark/td3.sh" linenums="1"
--8<-- "benchmark/td3.sh:12:19"
```

Below are the average episodic returns for [`td3_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action_jax.py) (3 random seeds).

{!benchmark/td3.md!}

Learning curves:


``` title="benchmark/td3_plot.sh" linenums="1"
--8<-- "benchmark/td3_plot.sh:11:20"
```

<img loading="lazy" src="https://huggingface.co/datasets/cleanrl/benchmark/resolve/main/benchmark/pr-424/td3.png">
<img loading="lazy" src="https://huggingface.co/datasets/cleanrl/benchmark/resolve/main/benchmark/pr-424/td3-time.png">

Expand Down

0 comments on commit 33a5609

Please sign in to comment.