Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve gym speedup #210

Merged
merged 16 commits into from
Nov 1, 2022
Prev Previous commit
Next Next commit
vizdoom
  • Loading branch information
Trinkle23897 committed Oct 31, 2022
commit 642eb04f75b2fa5b61dcc2fac23cbe2f3279c97e
55 changes: 12 additions & 43 deletions envpool/mujoco/gym/mujoco_gym_align_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run_align_check(
def test_ant(self) -> None:
assert version.parse(gym.__version__) >= version.parse("0.26.0")
env0 = gym.make("Ant-v4")
env1 = make_gym("Ant-v4", gym_reset_return_info=True)
env1 = make_gym("Ant-v4")
self.run_space_check(env0, env1)
self.run_align_check(env0, env1)
env0 = gym.make(
Expand All @@ -93,14 +93,13 @@ def test_ant(self) -> None:
terminate_when_unhealthy=False,
exclude_current_positions_from_observation=False,
max_episode_steps=100,
gym_reset_return_info=True,
)
self.run_space_check(env0, env1)
self.run_align_check(env0, env1, no_time_limit=True)

def test_half_cheetah(self) -> None:
env0 = gym.make("HalfCheetah-v4")
env1 = make_gym("HalfCheetah-v4", gym_reset_return_info=True)
env1 = make_gym("HalfCheetah-v4")
self.run_space_check(env0, env1)
self.run_align_check(env0, env1, no_time_limit=True)
env0 = gym.make(
Expand All @@ -109,16 +108,12 @@ def test_half_cheetah(self) -> None:
env1 = make_gym(
"HalfCheetah-v4",
exclude_current_positions_from_observation=True,
gym_reset_return_info=True,
max_episode_steps=1000,
)
self.run_space_check(env0, env1)

def test_hopper(self) -> None:
env0 = gym.make("Hopper-v4")
env1 = make_gym(
"Hopper-v4", gym_reset_return_info=True, max_episode_steps=1000
)
env1 = make_gym("Hopper-v4")
self.run_space_check(env0, env1)
self.run_align_check(env0, env1)
env0 = gym.make(
Expand All @@ -127,19 +122,16 @@ def test_hopper(self) -> None:
exclude_current_positions_from_observation=False,
)
env1 = make_gym(
"Hopper-v4",
terminate_when_unhealthy=False,
exclude_current_positions_from_observation=False,
gym_reset_return_info=True,
max_episode_steps=1000,
)
self.run_space_check(env0, env1)
self.run_align_check(env0, env1, no_time_limit=True)

def test_humanoid(self) -> None:
env0 = gym.make("Humanoid-v4")
env1 = make_gym(
"Humanoid-v4", gym_reset_return_info=True, max_episode_steps=1000
)
env1 = make_gym("Humanoid-v4")
self.run_space_check(env0, env1)
self.run_align_check(env0, env1)
env0 = gym.make(
Expand All @@ -151,61 +143,43 @@ def test_humanoid(self) -> None:
"Humanoid-v4",
terminate_when_unhealthy=False,
exclude_current_positions_from_observation=False,
gym_reset_return_info=True,
max_episode_steps=1000,
)
self.run_space_check(env0, env1)
self.run_align_check(env0, env1, no_time_limit=True)

def test_humanoid_standup(self) -> None:
env0 = gym.make("HumanoidStandup-v4")
env1 = make_gym(
"HumanoidStandup-v4", gym_reset_return_info=True, max_episode_steps=1000
)
env1 = make_gym("HumanoidStandup-v4")
self.run_space_check(env0, env1)
self.run_align_check(env0, env1, no_time_limit=True)

def test_inverted_double_pendulum(self) -> None:
env0 = gym.make("InvertedDoublePendulum-v4")
env1 = make_gym(
"InvertedDoublePendulum-v4",
gym_reset_return_info=True,
max_episode_steps=1000,
)
env1 = make_gym("InvertedDoublePendulum-v4",)
self.run_space_check(env0, env1)
self.run_align_check(env0, env1)

def test_inverted_pendulum(self) -> None:
env0 = gym.make("InvertedPendulum-v4")
env1 = make_gym(
"InvertedPendulum-v4",
gym_reset_return_info=True,
max_episode_steps=1000
)
env1 = make_gym("InvertedPendulum-v4",)
self.run_space_check(env0, env1)
self.run_align_check(env0, env1)

def test_pusher(self) -> None:
env0 = gym.make("Pusher-v4")
env1 = make_gym(
"Pusher-v4", gym_reset_return_info=True, max_episode_steps=100
)
env1 = make_gym("Pusher-v4")
self.run_space_check(env0, env1)
self.run_align_check(env0, env1, no_time_limit=True)

def test_reacher(self) -> None:
env0 = gym.make("Reacher-v4")
env1 = make_gym(
"Reacher-v4", gym_reset_return_info=True, max_episode_steps=50
)
env1 = make_gym("Reacher-v4")
self.run_space_check(env0, env1)
self.run_align_check(env0, env1, no_time_limit=True)

def test_swimmer(self) -> None:
env0 = gym.make("Swimmer-v4")
env1 = make_gym(
"Swimmer-v4", gym_reset_return_info=True, max_episode_steps=1000
)
env1 = make_gym("Swimmer-v4")
self.run_space_check(env0, env1)
self.run_align_check(env0, env1, no_time_limit=True)
env0 = gym.make(
Expand All @@ -214,16 +188,12 @@ def test_swimmer(self) -> None:
env1 = make_gym(
"Swimmer-v4",
exclude_current_positions_from_observation=False,
gym_reset_return_info=True,
max_episode_steps=1000,
)
self.run_space_check(env0, env1)

def test_walker2d(self) -> None:
env0 = gym.make("Walker2d-v4")
env1 = make_gym(
"Walker2d-v4", gym_reset_return_info=True, max_episode_steps=1000
)
env1 = make_gym("Walker2d-v4")
self.run_space_check(env0, env1)
self.run_align_check(env0, env1)
env0 = gym.make(
Expand All @@ -236,7 +206,6 @@ def test_walker2d(self) -> None:
terminate_when_unhealthy=False,
exclude_current_positions_from_observation=False,
max_episode_steps=100,
gym_reset_return_info=True,
)
self.run_space_check(env0, env1)
self.run_align_check(env0, env1, no_time_limit=True)
Expand Down
14 changes: 8 additions & 6 deletions envpool/vizdoom/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,18 @@ py_library(
deps = ["//envpool/python:api"],
)

py_library(
name = "vizdoom_registration",
srcs = ["registration.py"],
deps = ["//envpool:registration"],
)

py_test(
name = "vizdoom_test",
srcs = ["vizdoom_test.py"],
deps = [
":vizdoom",
":vizdoom_registration",
requirement("numpy"),
requirement("absl-py"),
requirement("opencv-python-headless"),
Expand All @@ -84,16 +91,11 @@ py_test(
data = [":gen_pretrain_weight"],
deps = [
":vizdoom",
":vizdoom_registration",
"//envpool/atari:atari_network",
requirement("numpy"),
requirement("absl-py"),
requirement("tianshou"),
requirement("torch"),
],
)

py_library(
name = "vizdoom_registration",
srcs = ["registration.py"],
deps = ["//envpool:registration"],
)
9 changes: 5 additions & 4 deletions envpool/vizdoom/vizdoom_pretrain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from tianshou.policy import C51Policy

from envpool.atari.atari_network import C51
from envpool.vizdoom import VizdoomEnvSpec, VizdoomGymEnvPool

import envpool.vizdoom.registration # noqa: F401
from envpool.registration import make_gym

# try:
# import cv2
Expand All @@ -47,6 +49,7 @@ def eval_c51(
cfg_path: Optional[str] = None,
reward_config: Optional[dict] = None,
) -> Tuple[np.ndarray, np.ndarray]:
task_id = "".join([g.capitalize() for g in task.split("_")]) + "-v1"
kwargs = {
"num_envs": num_envs,
"seed": seed,
Expand All @@ -59,9 +62,7 @@ def eval_c51(
kwargs.update(cfg_path=cfg_path)
if reward_config is not None:
kwargs.update(reward_config=reward_config)
env = VizdoomGymEnvPool(
VizdoomEnvSpec(VizdoomEnvSpec.gen_config(**kwargs))
)
env = make_gym(task_id, **kwargs)

state_shape = env.observation_space.shape
action_shape = env.action_space.n
Expand Down
Loading