From 9f8b64bd4bb0dc4db3557be251dcc1dc1b8bfe41 Mon Sep 17 00:00:00 2001
From: Arjun KG
Date: Thu, 4 May 2023 06:35:25 +0900
Subject: [PATCH] Gymnasium support for DDPG continuous (+Jax) (#371)

* ddpg continuous + jax
* fix video recording
* remove pybullet
* move to usage docs
* isort
* update lock files
* try trigger CI
* update ddpg default v4 environments
* trigger CI
* install jax dependency
* fix CI
* remove windows CI

---------

Co-authored-by: Costa Huang
---
 .github/workflows/tests.yaml          | 98 ++++++++++++++++++++++++++-
 README.md                             |  3 +-
 benchmark/ddpg.sh                     |  0
 cleanrl/ddpg_continuous_action.py     | 47 +++++++++----
 cleanrl/ddpg_continuous_action_jax.py | 45 ++++++++----
 docs/get-started/basic-usage.md       |  9 +++
 poetry.lock                           |  2 +-
 pyproject.toml                        |  2 +-
 tests/test_mujoco_gymnasium.py        | 17 +++++
 tests/test_mujoco_py.py               | 10 ---
 tests/test_mujoco_py_gymnasium.py     | 17 +++++
 tests/test_pybullet.py                |  5 --
 12 files changed, 206 insertions(+), 49 deletions(-)
 mode change 100644 => 100755 benchmark/ddpg.sh
 create mode 100644 tests/test_mujoco_gymnasium.py
 create mode 100644 tests/test_mujoco_py_gymnasium.py

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 7c493319..4558c02a 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -174,13 +174,13 @@ jobs:
         continue-on-error: true # MUJOCO_GL=osmesa results in `free(): invalid pointer`
         run: poetry run pytest tests/test_mujoco.py
 
-  test-mujoco-envs-windows-mac:
+  test-mujoco-gymnasium-envs:
     strategy:
       fail-fast: false
       matrix:
         python-version: [3.8]
         poetry-version: [1.3]
-        os: [macos-latest, windows-latest]
+        os: [ubuntu-22.04]
     runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v2
@@ -197,9 +197,68 @@ jobs:
         run: poetry install -E "pytest mujoco dm_control"
       - name: Downgrade setuptools
         run: poetry run pip install setuptools==59.5.0
+      - name: Run gymnasium migration dependencies
+        run: poetry run pip install "stable_baselines3==2.0.0a1"
+      - name: install mujoco dependencies
+        run: |
+          sudo apt-get update && sudo apt-get -y install libgl1-mesa-glx libosmesa6 libglfw3
+      - name: Run mujoco tests
+        continue-on-error: true # MUJOCO_GL=osmesa results in `free(): invalid pointer`
+        run: poetry run pytest tests/test_mujoco_gymnasium.py
+
+  test-mujoco-envs-mac:
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.8]
+        poetry-version: [1.3]
+        os: [macos-latest]
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Run image
+        uses: abatilo/actions-poetry@v2.0.0
+        with:
+          poetry-version: ${{ matrix.poetry-version }}
+
+      # mujoco tests
+      - name: Install dependencies
+        run: poetry install -E "pytest mujoco dm_control jax"
+      - name: Downgrade setuptools
+        run: poetry run pip install setuptools==59.5.0
       - name: Run mujoco tests
         run: poetry run pytest tests/test_mujoco.py
 
+  test-mujoco-gymnasium-mac:
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.8]
+        poetry-version: [1.3]
+        os: [macos-latest]
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Run image
+        uses: abatilo/actions-poetry@v2.0.0
+        with:
+          poetry-version: ${{ matrix.poetry-version }}
+
+      # mujoco tests
+      - name: Install dependencies
+        run: poetry install -E "pytest mujoco dm_control jax"
+      - name: Downgrade setuptools
+        run: poetry run pip install setuptools==59.5.0
+      - name: Run gymnasium migration dependencies
+        run: poetry run pip install "stable_baselines3==2.0.0a1"
+      - name: Run mujoco tests
+        run: poetry run pytest tests/test_mujoco_gymnasium.py
 
   test-mujoco_py-envs:
     strategy:
@@ -234,6 +293,41 @@
       - name: Run mujoco_py tests
         run: poetry run pytest tests/test_mujoco_py.py
 
+  test-mujoco_py-envs-gymnasium:
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.8]
+        poetry-version: [1.3]
+        os: [ubuntu-22.04]
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Run image
+        uses: abatilo/actions-poetry@v2.0.0
+        with:
+          poetry-version: ${{ matrix.poetry-version }}
+
+      # mujoco_py tests
+      - name: Install dependencies
+        run: poetry install -E "pytest pybullet mujoco_py mujoco jax"
+      - name: Run gymnasium migration dependencies
+        run: poetry run pip install "stable_baselines3==2.0.0a1"
+      - name: Downgrade setuptools
+        run: poetry run pip install setuptools==59.5.0
+      - name: install mujoco_py dependencies
+        run: |
+          sudo apt-get update && sudo apt-get -y install wget unzip software-properties-common \
+            libgl1-mesa-dev \
+            libgl1-mesa-glx \
+            libglew-dev \
+            libosmesa6-dev patchelf
+      - name: Run mujoco_py tests
+        run: poetry run pytest tests/test_mujoco_py_gymnasium.py
+
   test-envpool-envs:
     strategy:
       fail-fast: false
diff --git a/README.md b/README.md
index 2ec2611b..6d24cd20 100644
--- a/README.md
+++ b/README.md
@@ -30,8 +30,7 @@ You can read more about CleanRL in our [JMLR paper](https://www.jmlr.org/papers/
 CleanRL only contains implementations of **online** deep reinforcement learning algorithms. If you are looking for **offline** algorithms, please check out [tinkoff-ai/CORL](https://github.com/tinkoff-ai/CORL), which shares a similar design philosophy with CleanRL.
 
-> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277).
-
+> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277).
 > ⚠️ **NOTE**: CleanRL is *not* a modular library and therefore it is not meant to be imported. At the cost of duplicate code, we make all implementation details of a DRL algorithm variant easy to understand, so CleanRL comes with its own pros and cons. You should consider using CleanRL if you want to 1) understand all implementation details of an algorithm's variant or 2) prototype advanced features that other modular DRL libraries do not support (CleanRL has minimal lines of code, so it gives you a great debugging experience, and you don't have to do a lot of subclassing as you sometimes do in modular DRL libraries).
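The per-file diffs that follow all apply the same migration pattern. For orientation, here is a minimal sketch of the Gymnasium API the scripts move to (illustrative only, not part of the patch; the environment id is arbitrary):

```python
# Illustrative sketch of the Gymnasium API this patch migrates to (not part of the diff).
import gymnasium as gym

env = gym.make("Pendulum-v1")  # pass render_mode="rgb_array" when capturing video
obs, info = env.reset(seed=0)  # reset() now seeds the env and returns (obs, info)
for _ in range(200):
    action = env.action_space.sample()
    # step() returns a 5-tuple: termination (MDP end) and truncation (time limit) are separate
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```

The separation of `terminated` and `truncated` is what drives the replay-buffer changes in the DDPG scripts below.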
diff --git a/benchmark/ddpg.sh b/benchmark/ddpg.sh old mode 100644 new mode 100755 diff --git a/cleanrl/ddpg_continuous_action.py b/cleanrl/ddpg_continuous_action.py index 00a82191..14ccfd25 100644 --- a/cleanrl/ddpg_continuous_action.py +++ b/cleanrl/ddpg_continuous_action.py @@ -5,9 +5,8 @@ import time from distutils.util import strtobool -import gym +import gymnasium as gym import numpy as np -import pybullet_envs # noqa import torch import torch.nn as nn import torch.nn.functional as F @@ -37,7 +36,7 @@ def parse_args(): help="whether to capture videos of the agent performances (check out `videos` folder)") # Algorithm specific arguments - parser.add_argument("--env-id", type=str, default="HopperBulletEnv-v0", + parser.add_argument("--env-id", type=str, default="HalfCheetah-v4", help="the id of the environment") parser.add_argument("--total-timesteps", type=int, default=1000000, help="total timesteps of the experiments") @@ -66,12 +65,14 @@ def parse_args(): def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): - env = gym.make(env_id) + if capture_video: + env = gym.make(env_id, render_mode="rgb_array") + else: + env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) if capture_video: if idx == 0: env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") - env.seed(seed) env.action_space.seed(seed) env.observation_space.seed(seed) return env @@ -117,6 +118,15 @@ def forward(self, x): if __name__ == "__main__": + import stable_baselines3 as sb3 + + if sb3.__version__ < "2.0": + raise ValueError( + """Ongoing migration: run the following command to install the new dependencies: + +poetry run pip install "stable_baselines3==2.0.0a1" +""" + ) args = parse_args() run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" if args.track: @@ -128,7 +138,7 @@ def forward(self, x): sync_tensorboard=True, config=vars(args), name=run_name, - monitor_gym=True, + # monitor_gym=True, # no longer works for gymnasium save_code=True, ) writer = SummaryWriter(f"runs/{run_name}") @@ -164,12 +174,14 @@ def forward(self, x): envs.single_observation_space, envs.single_action_space, device, - handle_timeout_termination=True, + handle_timeout_termination=False, ) start_time = time.time() # TRY NOT TO MODIFY: start the game - obs = envs.reset() + obs, _ = envs.reset(seed=args.seed) + video_filenames = set() + for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: @@ -181,11 +193,12 @@ def forward(self, x): actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) # TRY NOT TO MODIFY: execute the game and log data. 
-        next_obs, rewards, dones, infos = envs.step(actions)
+        next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)
 
         # TRY NOT TO MODIFY: record rewards for plotting purposes
-        for info in infos:
-            if "episode" in info.keys():
+
+        if "final_info" in infos:
+            for info in infos["final_info"]:
                 print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                 writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                 writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
@@ -193,10 +206,10 @@
 
         # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
         real_next_obs = next_obs.copy()
-        for idx, d in enumerate(dones):
+        for idx, d in enumerate(truncateds):
             if d:
-                real_next_obs[idx] = infos[idx]["terminal_observation"]
-        rb.add(obs, real_next_obs, actions, rewards, dones, infos)
+                real_next_obs[idx] = infos["final_observation"][idx]
+        rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)
 
         # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
         obs = next_obs
@@ -237,4 +250,10 @@ def forward(self, x):
         writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
 
     envs.close()
+
+    if args.track and args.capture_video:
+        for filename in os.listdir(f"videos/{run_name}"):
+            if filename not in video_filenames and filename.endswith(".mp4"):
+                wandb.log({"videos": wandb.Video(f"videos/{run_name}/{filename}")})
+                video_filenames.add(filename)
     writer.close()
diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py
index b6291e4d..6ddb87ad 100644
--- a/cleanrl/ddpg_continuous_action_jax.py
+++ b/cleanrl/ddpg_continuous_action_jax.py
@@ -8,12 +8,11 @@
 
 import flax
 import flax.linen as nn
-import gym
+import gymnasium as gym
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
-import pybullet_envs  # noqa
 from flax.training.train_state import TrainState
 from stable_baselines3.common.buffers import ReplayBuffer
 from torch.utils.tensorboard import SummaryWriter
@@ -36,7 +35,7 @@ def parse_args():
         help="whether to capture videos of the agent performances (check out `videos` folder)")
 
     # Algorithm specific arguments
-    parser.add_argument("--env-id", type=str, default="HalfCheetah-v2",
+    parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
         help="the id of the environment")
     parser.add_argument("--total-timesteps", type=int, default=1000000,
         help="total timesteps of the experiments")
@@ -65,12 +64,14 @@ def parse_args():
 
 def make_env(env_id, seed, idx, capture_video, run_name):
     def thunk():
-        env = gym.make(env_id)
+        if capture_video:
+            env = gym.make(env_id, render_mode="rgb_array")
+        else:
+            env = gym.make(env_id)
         env = gym.wrappers.RecordEpisodeStatistics(env)
         if capture_video:
             if idx == 0:
                 env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
-        env.seed(seed)
         env.action_space.seed(seed)
         env.observation_space.seed(seed)
         return env
@@ -113,6 +114,15 @@ class TrainState(TrainState):
 
 
 if __name__ == "__main__":
+    import stable_baselines3 as sb3
+
+    if sb3.__version__ < "2.0":
+        raise ValueError(
+            """Ongoing migration: run the following command to install the new dependencies:
+
+poetry run pip install "stable_baselines3==2.0.0a1"
+"""
+        )
     args = parse_args()
     run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
     if args.track:
@@ -124,7 +134,7 @@ class TrainState(TrainState):
             sync_tensorboard=True,
             config=vars(args),
             name=run_name,
-            monitor_gym=True,
+            # monitor_gym=True,  # does not work on gymnasium
             save_code=True,
         )
     writer = SummaryWriter(f"runs/{run_name}")
@@ -150,11 +160,12 @@ class TrainState(TrainState):
         envs.single_observation_space,
         envs.single_action_space,
         device="cpu",
-        handle_timeout_termination=True,
+        handle_timeout_termination=False,
     )
 
     # TRY NOT TO MODIFY: start the game
-    obs = envs.reset()
+    obs, _ = envs.reset(seed=args.seed)
+    video_filenames = set()
     action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0)
     action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0)
     actor = Actor(
@@ -235,11 +246,11 @@ def actor_loss(params):
         )
 
         # TRY NOT TO MODIFY: execute the game and log data.
-        next_obs, rewards, dones, infos = envs.step(actions)
+        next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)
 
         # TRY NOT TO MODIFY: record rewards for plotting purposes
-        for info in infos:
-            if "episode" in info.keys():
+        if "final_info" in infos:
+            for info in infos["final_info"]:
                 print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                 writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                 writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
@@ -247,10 +258,10 @@ def actor_loss(params):
 
         # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
         real_next_obs = next_obs.copy()
-        for idx, d in enumerate(dones):
+        for idx, d in enumerate(truncateds):
             if d:
-                real_next_obs[idx] = infos[idx]["terminal_observation"]
-        rb.add(obs, real_next_obs, actions, rewards, dones, infos)
+                real_next_obs[idx] = infos["final_observation"][idx]
+        rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)
 
         # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
         obs = next_obs
@@ -282,4 +293,10 @@ def actor_loss(params):
         writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
 
     envs.close()
+    if args.track and args.capture_video:
+        for filename in os.listdir(f"videos/{run_name}"):
+            if filename not in video_filenames and filename.endswith(".mp4"):
+                wandb.log({"videos": wandb.Video(f"videos/{run_name}/{filename}")})
+                video_filenames.add(filename)
+
     writer.close()
diff --git a/docs/get-started/basic-usage.md b/docs/get-started/basic-usage.md
index 0d2fe1ba..5571c3e7 100644
--- a/docs/get-started/basic-usage.md
+++ b/docs/get-started/basic-usage.md
@@ -44,6 +44,15 @@ the CleanRL script under the poetry virtual environments.
 
 **We assume other commands in this documentation (e.g. `tensorboard`) are also run within poetry's shell.**
 
+!!! note
+
+    Currently, `ddpg_continuous_action.py` and `ddpg_continuous_action_jax.py` have been ported to gymnasium.
+
+    Note that `stable-baselines3` versions below `2.0` do not support `gymnasium`. To run these scripts, install the alpha release:
+
+    ```
+    poetry run pip install "stable_baselines3==2.0.0a1"
+    ```
+
 !!! warning
 
     If you are using NVIDIA Ampere GPUs (e.g., 3060 Ti), you might encounter the following error
diff --git a/poetry.lock b/poetry.lock
index 3d5d9bfc..d0ad2dec 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -4732,4 +4732,4 @@ pytest = ["pytest"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.7.1,<3.10"
-content-hash = "76c5bc466eff3e90d989584942ec32be81f6ef3f7b5c9b137775e8d58efd6f0b"
+content-hash = "76c5bc466eff3e90d989584942ec32be81f6ef3f7b5c9b137775e8d58efd6f0b"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index b2879e57..6248399d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -99,4 +99,4 @@ ppo_atari_envpool_xla_jax_scan = [
     "ale-py", "AutoROM", "opencv-python", # atari
     "jax", "jaxlib", "flax", # jax
     "envpool", # envpool
-]
+]
\ No newline at end of file
diff --git a/tests/test_mujoco_gymnasium.py b/tests/test_mujoco_gymnasium.py
new file mode 100644
index 00000000..a887dc0a
--- /dev/null
+++ b/tests/test_mujoco_gymnasium.py
@@ -0,0 +1,17 @@
+import subprocess
+
+
+def test_mujoco():
+    """
+    Test mujoco
+    """
+    subprocess.run(
+        "python cleanrl/ddpg_continuous_action.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105",
+        shell=True,
+        check=True,
+    )
+    subprocess.run(
+        "python cleanrl/ddpg_continuous_action_jax.py --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105",
+        shell=True,
+        check=True,
+    )
diff --git a/tests/test_mujoco_py.py b/tests/test_mujoco_py.py
index c0238900..882e1680 100644
--- a/tests/test_mujoco_py.py
+++ b/tests/test_mujoco_py.py
@@ -10,16 +10,6 @@ def test_mujoco_py():
         shell=True,
         check=True,
     )
-    subprocess.run(
-        "python cleanrl/ddpg_continuous_action.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105",
-        shell=True,
-        check=True,
-    )
-    subprocess.run(
-        "python cleanrl/ddpg_continuous_action_jax.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105",
-        shell=True,
-        check=True,
-    )
     subprocess.run(
         "python cleanrl/td3_continuous_action_jax.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105",
         shell=True,
diff --git a/tests/test_mujoco_py_gymnasium.py b/tests/test_mujoco_py_gymnasium.py
new file mode 100644
index 00000000..6474c0e5
--- /dev/null
+++ b/tests/test_mujoco_py_gymnasium.py
@@ -0,0 +1,17 @@
+import subprocess
+
+
+def test_mujoco_py():
+    """
+    Test mujoco_py
+    """
+    subprocess.run(
+        "python cleanrl/ddpg_continuous_action.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105",
+        shell=True,
+        check=True,
+    )
+    subprocess.run(
+        "python cleanrl/ddpg_continuous_action_jax.py --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105",
+        shell=True,
+        check=True,
+    )
diff --git a/tests/test_pybullet.py b/tests/test_pybullet.py
index c9fabf70..365f71cc 100644
--- a/tests/test_pybullet.py
+++ b/tests/test_pybullet.py
@@ -2,11 +2,6 @@
 
 
 def test_pybullet():
-    subprocess.run(
-        "python cleanrl/ddpg_continuous_action.py --learning-starts 100 --batch-size 32 --total-timesteps 105",
-        shell=True,
-        check=True,
-    )
     subprocess.run(
         "python cleanrl/td3_continuous_action.py --learning-starts 100 --batch-size 32 --total-timesteps 105",
         shell=True,
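The subtlest change in this patch is the replay-buffer bookkeeping once `handle_timeout_termination=False`. Below is a self-contained sketch of the pattern both DDPG scripts now follow; `Pendulum-v1`, the buffer size, and the step count are illustrative stand-ins, and it assumes the `stable_baselines3==2.0.0a1` dependency installed above:

```python
# Self-contained sketch of the truncation handling used by both migrated DDPG scripts.
# Assumptions (not from the diff): Pendulum-v1 as a stand-in env, toy buffer size and step count.
import gymnasium as gym
import numpy as np
from stable_baselines3.common.buffers import ReplayBuffer  # requires stable_baselines3==2.0.0a1

envs = gym.vector.SyncVectorEnv([lambda: gym.make("Pendulum-v1")])
rb = ReplayBuffer(
    buffer_size=1000,
    observation_space=envs.single_observation_space,
    action_space=envs.single_action_space,
    device="cpu",
    handle_timeout_termination=False,  # truncation is handled manually below
)
obs, _ = envs.reset(seed=0)
for _ in range(300):
    actions = np.array([envs.single_action_space.sample()])
    next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)
    # On autoreset, next_obs already holds the first observation of the new episode;
    # the true terminal observation is stashed in infos["final_observation"].
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncateds):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]
    rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)
    obs = next_obs
envs.close()
```

Only `terminateds` is stored as the done flag: truncated (time-limit) transitions are bootstrapped through the recovered `final_observation`, which is why SB3's own timeout handling is disabled.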