diff --git a/benchmark/c51.sh b/benchmark/c51.sh index 4cd8215e4..6aba77810 100644 --- a/benchmark/c51.sh +++ b/benchmark/c51.sh @@ -13,7 +13,7 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ --workers 1 poetry install -E "jax" -poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html CUDA_VISIBLE_DEVICES=-1 xvfb-run -a python -m cleanrl_utils.benchmark \ --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ --command "poetry run python cleanrl/c51_jax.py --track --capture_video" \ @@ -21,7 +21,7 @@ CUDA_VISIBLE_DEVICES=-1 xvfb-run -a python -m cleanrl_utils.benchmark \ --workers 1 poetry install -E "atari jax" -poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html xvfb-run -a python -m cleanrl_utils.benchmark \ --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ --command "poetry run python cleanrl/c51_atari_jax.py --track --capture_video" \ diff --git a/benchmark/ddpg.sh b/benchmark/ddpg.sh index f22d8496f..3746b4d99 100755 --- a/benchmark/ddpg.sh +++ b/benchmark/ddpg.sh @@ -1,16 +1,22 @@ -poetry install -E "mujoco_py" -python -c "import mujoco_py" -xvfb-run -a python -m cleanrl_utils.benchmark \ - --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \ - --command "poetry run python cleanrl/ddpg_continuous_action.py --track --capture_video" \ +poetry install -E "mujoco" +python -m cleanrl_utils.benchmark \ + --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ + --command "poetry run python cleanrl/ddpg_continuous_action.py --track" \ --num-seeds 3 \ - --workers 1 + --workers 18 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 10 \ + --slurm-template-path benchmark/cleanrl_1gpu.slurm_template -poetry install -E "mujoco_py jax" -poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -poetry run python -c "import mujoco_py" -xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ - --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \ - --command "poetry run python cleanrl/ddpg_continuous_action_jax.py --track --capture_video" \ +poetry install -E "mujoco jax" +poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +poetry run python -m cleanrl_utils.benchmark \ + --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ + --command "poetry run python cleanrl/ddpg_continuous_action_jax.py --track" \ --num-seeds 3 \ - --workers 1 + --workers 18 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 10 \ + --slurm-template-path benchmark/cleanrl_1gpu.slurm_template diff --git a/benchmark/ddpg_plot.sh b/benchmark/ddpg_plot.sh new file mode 100755 index 000000000..d36db199e --- /dev/null +++ b/benchmark/ddpg_plot.sh @@ -0,0 +1,20 @@ +python -m openrlbenchmark.rlops \ + --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ + 'ddpg_continuous_action?tag=pr-424' \ + --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ + --no-check-empty-runs \ + --pc.ncols 3 \ + --pc.ncols-legend 2 \ + --output-filename benchmark/cleanrl/ddpg \ + --scan-history + +python -m openrlbenchmark.rlops \ + --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ + 'ddpg_continuous_action?tag=pr-424' \ + 'ddpg_continuous_action_jax?tag=pr-424' \ + --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ + --no-check-empty-runs \ + --pc.ncols 3 \ + --pc.ncols-legend 2 \ + --output-filename benchmark/cleanrl/ddpg_jax \ + --scan-history diff --git a/benchmark/dqn.sh b/benchmark/dqn.sh index 213966ae8..dcd90446b 100644 --- a/benchmark/dqn.sh +++ b/benchmark/dqn.sh @@ -13,7 +13,7 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ --workers 1 poetry install -E jax -poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html xvfb-run -a python -m cleanrl_utils.benchmark \ --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ --command "poetry run python cleanrl/dqn_jax.py --track --capture_video" \ @@ -21,7 +21,7 @@ xvfb-run -a python -m cleanrl_utils.benchmark \ --workers 1 poetry install -E "atari jax" -poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html xvfb-run -a python -m cleanrl_utils.benchmark \ --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ --command "poetry run python cleanrl/dqn_atari_jax.py --track --capture_video" \ diff --git a/benchmark/qdagger.sh b/benchmark/qdagger.sh index 0a28fd933..dc7851fb3 100644 --- a/benchmark/qdagger.sh +++ b/benchmark/qdagger.sh @@ -7,7 +7,7 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ poetry install -E "atari jax" -poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ --command "poetry run python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --track --capture_video" \ diff --git a/benchmark/td3.sh b/benchmark/td3.sh index 97aa5406d..e68004c73 100644 --- a/benchmark/td3.sh +++ b/benchmark/td3.sh @@ -1,16 +1,22 @@ -poetry install -E mujoco_py -python -c "import mujoco_py" -OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \ - --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \ - --command "poetry run python cleanrl/td3_continuous_action.py --track --capture_video" \ +poetry install -E "mujoco" +python -m cleanrl_utils.benchmark \ + --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ + --command "poetry run python cleanrl/td3_continuous_action.py --track" \ --num-seeds 3 \ - --workers 1 + --workers 18 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 10 \ + --slurm-template-path benchmark/cleanrl_1gpu.slurm_template -poetry install -E "mujoco_py jax" -poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -poetry run python -c "import mujoco_py" -xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ - --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \ - --command "poetry run python cleanrl/td3_continuous_action_jax.py --track --capture_video" \ +poetry install -E "mujoco jax" +poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +poetry run python -m cleanrl_utils.benchmark \ + --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ + --command "poetry run python cleanrl/td3_continuous_action_jax.py --track" \ --num-seeds 3 \ - --workers 1 + --workers 18 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 10 \ + --slurm-template-path benchmark/cleanrl_1gpu.slurm_template diff --git a/benchmark/td3_plot.sh b/benchmark/td3_plot.sh new file mode 100644 index 000000000..ad37305cc --- /dev/null +++ b/benchmark/td3_plot.sh @@ -0,0 +1,21 @@ +python -m openrlbenchmark.rlops \ + --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ + 'td3_continuous_action?tag=pr-424' \ + 'td3_continuous_action_jax?tag=pr-424' \ + --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ + --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ + --no-check-empty-runs \ + --pc.ncols 3 \ + --pc.ncols-legend 2 \ + --output-filename benchmark/cleanrl/td3 \ + --scan-history + +python -m openrlbenchmark.rlops \ + --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \ + 'sac_continuous_action?tag=pr-424' \ + --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4 Pusher-v4 \ + --no-check-empty-runs \ + --pc.ncols 3 \ + --pc.ncols-legend 2 \ + --output-filename benchmark/cleanrl/sac \ + --scan-history diff --git a/docs/rl-algorithms/ddpg.md b/docs/rl-algorithms/ddpg.md index 060a3c06d..671d4e903 100644 --- a/docs/rl-algorithms/ddpg.md +++ b/docs/rl-algorithms/ddpg.md @@ -257,7 +257,6 @@ Below are the average episodic returns for [`ddpg_continuous_action.py`](https:/ Learning curves: - ``` title="benchmark/ddpg_plot.sh" linenums="1" --8<-- "benchmark/ddpg_plot.sh::9" ``` @@ -343,7 +342,7 @@ Learning curves: ???+ info - + These are some previous experiments with TPUs. Note the results are very similar to the ones above, but the runtime can be different due to different hardware used. diff --git a/docs/rl-algorithms/td3.md b/docs/rl-algorithms/td3.md index a105de300..6bf4494f9 100644 --- a/docs/rl-algorithms/td3.md +++ b/docs/rl-algorithms/td3.md @@ -124,7 +124,9 @@ Additionally, when drawing exploration noise that is added to the actions produc To run benchmark experiments, see :material-github: [benchmark/td3.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/td3.sh). Specifically, execute the following command: - +``` title="benchmark/td3.sh" linenums="1" +--8<-- "benchmark/td3.sh::7" +``` Below are the average episodic returns for [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py) (3 random seeds). To ensure the quality of the implementation, we compared the results against (Fujimoto et al., 2018)[^2]. @@ -150,6 +152,10 @@ Below are the average episodic returns for [`td3_continuous_action.py`](https:// Learning curves: +``` title="benchmark/td3_plot.sh" linenums="1" +--8<-- "benchmark/td3_plot.sh::9" +``` + @@ -203,7 +209,9 @@ See [related docs](/rl-algorithms/td3/#implementation-details) for `td3_continuo To run benchmark experiments, see :material-github: [benchmark/td3.sh](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/td3.sh). Specifically, execute the following command: - +``` title="benchmark/td3.sh" linenums="1" +--8<-- "benchmark/td3.sh:12:19" +``` Below are the average episodic returns for [`td3_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action_jax.py) (3 random seeds). @@ -211,6 +219,11 @@ Below are the average episodic returns for [`td3_continuous_action_jax.py`](http Learning curves: + +``` title="benchmark/td3_plot.sh" linenums="1" +--8<-- "benchmark/td3_plot.sh:11:20" +``` +