diff --git a/benchmark/c51.sh b/benchmark/c51.sh
index 4cd8215e4..6aba77810 100644
--- a/benchmark/c51.sh
+++ b/benchmark/c51.sh
@@ -13,7 +13,7 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--workers 1
poetry install -E "jax"
-poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
CUDA_VISIBLE_DEVICES=-1 xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--command "poetry run python cleanrl/c51_jax.py --track --capture_video" \
@@ -21,7 +21,7 @@ CUDA_VISIBLE_DEVICES=-1 xvfb-run -a python -m cleanrl_utils.benchmark \
--workers 1
poetry install -E "atari jax"
-poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/c51_atari_jax.py --track --capture_video" \
diff --git a/benchmark/ddpg.sh b/benchmark/ddpg.sh
index f22d8496f..3746b4d99 100755
--- a/benchmark/ddpg.sh
+++ b/benchmark/ddpg.sh
@@ -1,16 +1,22 @@
-poetry install -E "mujoco_py"
-python -c "import mujoco_py"
-xvfb-run -a python -m cleanrl_utils.benchmark \
- --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \
- --command "poetry run python cleanrl/ddpg_continuous_action.py --track --capture_video" \
+poetry install -E "mujoco"
+python -m cleanrl_utils.benchmark \
+ --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
+ --command "poetry run python cleanrl/ddpg_continuous_action.py --track" \
--num-seeds 3 \
- --workers 1
+ --workers 18 \
+ --slurm-gpus-per-task 1 \
+ --slurm-ntasks 1 \
+ --slurm-total-cpus 10 \
+ --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
-poetry install -E "mujoco_py jax"
-poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-poetry run python -c "import mujoco_py"
-xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
- --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
- --command "poetry run python cleanrl/ddpg_continuous_action_jax.py --track --capture_video" \
+poetry install -E "mujoco jax"
+poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+poetry run python -m cleanrl_utils.benchmark \
+ --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
+ --command "poetry run python cleanrl/ddpg_continuous_action_jax.py --track" \
--num-seeds 3 \
- --workers 1
+ --workers 18 \
+ --slurm-gpus-per-task 1 \
+ --slurm-ntasks 1 \
+ --slurm-total-cpus 10 \
+ --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
diff --git a/benchmark/ddpg_plot.sh b/benchmark/ddpg_plot.sh
new file mode 100755
index 000000000..d36db199e
--- /dev/null
+++ b/benchmark/ddpg_plot.sh
@@ -0,0 +1,20 @@
+python -m openrlbenchmark.rlops \
+ --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
+ 'ddpg_continuous_action?tag=pr-424' \
+ --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
+ --no-check-empty-runs \
+ --pc.ncols 3 \
+ --pc.ncols-legend 2 \
+ --output-filename benchmark/cleanrl/ddpg \
+ --scan-history
+
+python -m openrlbenchmark.rlops \
+ --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
+ 'ddpg_continuous_action?tag=pr-424' \
+ 'ddpg_continuous_action_jax?tag=pr-424' \
+ --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
+ --no-check-empty-runs \
+ --pc.ncols 3 \
+ --pc.ncols-legend 2 \
+ --output-filename benchmark/cleanrl/ddpg_jax \
+ --scan-history
diff --git a/benchmark/dqn.sh b/benchmark/dqn.sh
index 213966ae8..dcd90446b 100644
--- a/benchmark/dqn.sh
+++ b/benchmark/dqn.sh
@@ -13,7 +13,7 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--workers 1
poetry install -E jax
-poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--command "poetry run python cleanrl/dqn_jax.py --track --capture_video" \
@@ -21,7 +21,7 @@ xvfb-run -a python -m cleanrl_utils.benchmark \
--workers 1
poetry install -E "atari jax"
-poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/dqn_atari_jax.py --track --capture_video" \
diff --git a/benchmark/qdagger.sh b/benchmark/qdagger.sh
index 0a28fd933..dc7851fb3 100644
--- a/benchmark/qdagger.sh
+++ b/benchmark/qdagger.sh
@@ -7,7 +7,7 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
poetry install -E "atari jax"
-poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --track --capture_video" \
diff --git a/benchmark/td3.sh b/benchmark/td3.sh
index 97aa5406d..e68004c73 100644
--- a/benchmark/td3.sh
+++ b/benchmark/td3.sh
@@ -1,16 +1,22 @@
-poetry install -E mujoco_py
-python -c "import mujoco_py"
-OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
- --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \
- --command "poetry run python cleanrl/td3_continuous_action.py --track --capture_video" \
+poetry install -E "mujoco"
+python -m cleanrl_utils.benchmark \
+ --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
+ --command "poetry run python cleanrl/td3_continuous_action.py --track" \
--num-seeds 3 \
- --workers 1
+ --workers 18 \
+ --slurm-gpus-per-task 1 \
+ --slurm-ntasks 1 \
+ --slurm-total-cpus 10 \
+ --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
-poetry install -E "mujoco_py jax"
-poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-poetry run python -c "import mujoco_py"
-xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
- --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
- --command "poetry run python cleanrl/td3_continuous_action_jax.py --track --capture_video" \
+poetry install -E "mujoco jax"
+poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+poetry run python -m cleanrl_utils.benchmark \
+ --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
+ --command "poetry run python cleanrl/td3_continuous_action_jax.py --track" \
--num-seeds 3 \
- --workers 1
+ --workers 18 \
+ --slurm-gpus-per-task 1 \
+ --slurm-ntasks 1 \
+ --slurm-total-cpus 10 \
+ --slurm-template-path benchmark/cleanrl_1gpu.slurm_template
diff --git a/benchmark/td3_plot.sh b/benchmark/td3_plot.sh
new file mode 100644
index 000000000..ad37305cc
--- /dev/null
+++ b/benchmark/td3_plot.sh
@@ -0,0 +1,21 @@
+python -m openrlbenchmark.rlops \
+ --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
+ 'td3_continuous_action?tag=pr-424' \
+ 'td3_continuous_action_jax?tag=pr-424' \
+ --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
+ --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
+ --no-check-empty-runs \
+ --pc.ncols 3 \
+ --pc.ncols-legend 2 \
+ --output-filename benchmark/cleanrl/td3 \
+ --scan-history
+
+python -m openrlbenchmark.rlops \
+ --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
+ 'sac_continuous_action?tag=pr-424' \
+ --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \
+ --no-check-empty-runs \
+ --pc.ncols 3 \
+ --pc.ncols-legend 2 \
+ --output-filename benchmark/cleanrl/sac \
+ --scan-history
diff --git a/docs/rl-algorithms/ddpg.md b/docs/rl-algorithms/ddpg.md
index 060a3c06d..671d4e903 100644
--- a/docs/rl-algorithms/ddpg.md
+++ b/docs/rl-algorithms/ddpg.md
@@ -257,7 +257,6 @@ Below are the average episodic returns for [`ddpg_continuous_action.py`](https:/
Learning curves:
-
``` title="benchmark/ddpg_plot.sh" linenums="1"
--8<-- "benchmark/ddpg_plot.sh::9"
```
@@ -343,7 +342,7 @@ Learning curves:
???+ info
-
+
These are some previous experiments with TPUs. Note the results are very similar to the ones above, but the runtime can be different due to different hardware used.
diff --git a/docs/rl-algorithms/td3.md b/docs/rl-algorithms/td3.md
index a105de300..6bf4494f9 100644
--- a/docs/rl-algorithms/td3.md
+++ b/docs/rl-algorithms/td3.md
@@ -124,7 +124,9 @@ Additionally, when drawing exploration noise that is added to the actions produc
To run benchmark experiments, see :material-github: [benchmark/td3.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/td3.sh). Specifically, execute the following command:
-
+``` title="benchmark/td3.sh" linenums="1"
+--8<-- "benchmark/td3.sh::7"
+```
Below are the average episodic returns for [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py) (3 random seeds). To ensure the quality of the implementation, we compared the results against (Fujimoto et al., 2018)[^2].
@@ -150,6 +152,10 @@ Below are the average episodic returns for [`td3_continuous_action.py`](https://
Learning curves:
+``` title="benchmark/td3_plot.sh" linenums="1"
+--8<-- "benchmark/td3_plot.sh::9"
+```
+
@@ -203,7 +209,9 @@ See [related docs](/rl-algorithms/td3/#implementation-details) for `td3_continuo
To run benchmark experiments, see :material-github: [benchmark/td3.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/td3.sh). Specifically, execute the following command:
-
+``` title="benchmark/td3.sh" linenums="1"
+--8<-- "benchmark/td3.sh:12:19"
+```
Below are the average episodic returns for [`td3_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action_jax.py) (3 random seeds).
@@ -211,6 +219,11 @@ Below are the average episodic returns for [`td3_continuous_action_jax.py`](http
Learning curves:
+
+``` title="benchmark/td3_plot.sh" linenums="1"
+--8<-- "benchmark/td3_plot.sh:11:20"
+```
+