diff --git a/zoo/classic_control/cartpole/entry/cartpole_eval.py b/zoo/classic_control/cartpole/entry/cartpole_eval.py
index b4e3c554e..e574a80fa 100644
--- a/zoo/classic_control/cartpole/entry/cartpole_eval.py
+++ b/zoo/classic_control/cartpole/entry/cartpole_eval.py
@@ -27,7 +27,8 @@
     create_config.env_manager.type = 'base'  # Visualization requires the 'type' to be set as base
     main_config.env.evaluator_env_num = 1  # Visualization requires the 'env_num' to be set as 1
     main_config.env.n_evaluator_episode = total_test_episodes
-    main_config.env.replay_path = './video'
+    main_config.env.save_replay_gif = True
+    main_config.env.replay_path_gif = './cartpole_gif'
 
     for seed in seeds:
         """
diff --git a/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py b/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py
index 62bceb6b6..30f2963c7 100644
--- a/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py
+++ b/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py
@@ -1,4 +1,5 @@
 import copy
+import os
 from datetime import datetime
 from typing import Union, Optional, Dict
 
@@ -9,6 +10,8 @@
 from ding.torch_utils import to_ndarray
 from ding.utils import ENV_REGISTRY
 from easydict import EasyDict
+import matplotlib.pyplot as plt
+from matplotlib import animation
 
 
 @ENV_REGISTRY.register('cartpole_lightzero')
@@ -21,11 +24,12 @@ class CartPoleEnv(BaseEnv):
     """
     config = dict(
-        # env_id (str): The name of the environment.
+        # env_id (str): The name of the CartPole environment.
         env_id="CartPole-v0",
-        # replay_path (str): The path to save the replay video. If None, the replay will not be saved.
-        # Only effective when env_manager.type is 'base'.
-        replay_path=None,
+        # save_replay_gif (bool): If True, saves the replay as a gif.
+        save_replay_gif=False,
+        # replay_path_gif (str or None): The path to save the gif replay. If None, gif will not be saved.
+        replay_path_gif=None,
    )
 
     @classmethod
@@ -36,12 +40,18 @@ def default_config(cls: type) -> EasyDict:
 
     def __init__(self, cfg: dict = {}) -> None:
         """
-        Initialize the environment with a configuration dictionary. Sets up spaces for observations, actions, and rewards.
+        Initializes the CartPole environment with the given configuration.
+
+        Args:
+            cfg (dict): Configuration dict that includes `env_id`, `save_replay_gif`, and `replay_path_gif`.
         """
         self._cfg = cfg
         self._init_flag = False
-        self._continuous = False
-        self._replay_path = cfg.replay_path
+        self._replay_path_gif = cfg.get('replay_path_gif', None)
+        self._save_replay_gif = cfg.get('save_replay_gif', False)
+        self._save_replay_count = 0
+
+        # Define observation, action, and reward spaces.
         self._observation_space = gym.spaces.Box(
             low=np.array([-4.8, float("-inf"), -0.42, float("-inf")]),
             high=np.array([4.8, float("inf"), 0.42, float("inf")]),
@@ -49,42 +59,29 @@ def __init__(self, cfg: dict = {}) -> None:
             dtype=np.float32
         )
         self._action_space = gym.spaces.Discrete(2)
-        self._action_space.seed(0)  # default seed
         self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)
 
     def reset(self) -> Dict[str, np.ndarray]:
         """
-        Reset the environment. If it hasn't been initialized yet, this method also handles that. It also handles seeding
-        if necessary. Returns the first observation.
+        Reset the environment and return the initial observation.
+
+        Returns:
+            Dict[str, np.ndarray]: The initial observation from the environment.
""" if not self._init_flag: - self._env = gym.make('CartPole-v0', render_mode="rgb_array") - if self._replay_path is not None: - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") - video_name = f'{self._env.spec.id}-video-{timestamp}' - self._env = gym.wrappers.RecordVideo( - self._env, - video_folder=self._replay_path, - episode_trigger=lambda episode_id: True, - name_prefix=video_name - ) + self._env = gym.make(self._cfg['env_id'], render_mode="rgb_array") + # If replay saving as GIF is enabled, prepare for recording. + if self._save_replay_gif: + self._frames = [] if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward: self._env = ObsPlusPrevActRewWrapper(self._env) self._init_flag = True - if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: - np_seed = 100 * np.random.randint(1, 1000) - self._seed = self._seed + np_seed - self._action_space.seed(self._seed) - obs, _ = self._env.reset(seed=self._seed) - elif hasattr(self, '_seed'): - self._action_space.seed(self._seed) - obs, _ = self._env.reset(seed=self._seed) - else: - obs, _ = self._env.reset() - self._observation_space = self._env.observation_space + + obs, _ = self._env.reset() self._eval_episode_return = 0 obs = to_ndarray(obs) + # Initialize the action mask and return the observation. action_mask = np.ones(self.action_space.n, 'int8') obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} @@ -110,23 +107,64 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: - Observations are returned in a dictionary format containing 'observation', 'action_mask', and 'to_play'. """ if isinstance(action, np.ndarray) and action.shape == (1,): - action = action.squeeze() # 0-dim array + action = action.squeeze() # Handle 0-dim array obs, rew, terminated, truncated, info = self._env.step(action) done = terminated or truncated + # Record the frame if replay saving as GIF is enabled. + if self._save_replay_gif: + self._frames.append(self._env.render()) + + # Update rewards and check if the episode is done. self._eval_episode_return += rew if done: info['eval_episode_return'] = self._eval_episode_return + if self._save_replay_gif: + self.save_gif_replay() action_mask = np.ones(self.action_space.n, 'int8') obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} return BaseEnvTimestep(obs, rew, done, info) + def save_gif_replay(self) -> None: + """ + Save the recorded frames as a GIF replay. + """ + if not os.path.exists(self._replay_path_gif): + os.makedirs(self._replay_path_gif) + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + gif_filename = f'{self._cfg["env_id"]}_episode_{self._save_replay_count}_{timestamp}.gif' + gif_path = os.path.join(self._replay_path_gif, gif_filename) + + # Create the GIF using the recorded frames. + self.display_frames_as_gif(self._frames, gif_path) + print(f"Replay saved as {gif_path}") + self._save_replay_count += 1 + + @staticmethod + def display_frames_as_gif(frames: list, path: str) -> None: + """ + Convert a list of frames into a GIF and save it. + + Args: + frames (list): List of frames to be saved as a GIF. + path (str): Path where the GIF will be saved. + """ + patch = plt.imshow(frames[0]) + plt.axis('off') + + def animate(i): + patch.set_data(frames[i]) + + anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50) + anim.save(path, writer='imagemagick', fps=20) + def close(self) -> None: """ - Close the environment, and set the initialization flag to False. 
+        Close the environment and reset the initialization flag.
         """
         if self._init_flag:
             self._env.close()
         self._init_flag = False
 
@@ -134,24 +172,23 @@ def close(self) -> None:
 
     def seed(self, seed: int, dynamic_seed: bool = True) -> None:
         """
-        Set the seed for the environment's random number generator. Can handle both static and dynamic seeding.
+        Set the random seed for the environment.
+
+        Args:
+            seed (int): The seed value.
+            dynamic_seed (bool): Whether to use dynamic seed generation.
         """
         self._seed = seed
         self._dynamic_seed = dynamic_seed
         np.random.seed(self._seed)
 
-    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
-        """
-        Enable the saving of replay videos. If no replay path is given, a default is used.
+    def random_action(self) -> np.ndarray:
         """
-        if replay_path is None:
-            replay_path = './video'
-        self._replay_path = replay_path
+        Generate a random action from the action space.
 
-    def random_action(self) -> np.ndarray:
+        Returns:
+            np.ndarray: A random action.
         """
-        Generate a random action using the action space's sample method. Returns a numpy array containing the action.
-        """
         random_action = self.action_space.sample()
         random_action = to_ndarray([random_action], dtype=np.int64)
         return random_action
@@ -159,21 +196,21 @@ def random_action(self) -> np.ndarray:
 
     @property
     def observation_space(self) -> gym.spaces.Space:
         """
-        Property to access the observation space of the environment.
+        Returns the observation space of the environment.
         """
         return self._observation_space
 
     @property
     def action_space(self) -> gym.spaces.Space:
         """
-        Property to access the action space of the environment.
+        Returns the action space of the environment.
         """
         return self._action_space
 
     @property
     def reward_space(self) -> gym.spaces.Space:
         """
-        Property to access the reward space of the environment.
+        Returns the reward space of the environment.
         """
         return self._reward_space
 
@@ -181,4 +218,4 @@ def __repr__(self) -> str:
         """
         String representation of the environment.
         """
-        return "LightZero CartPole Env"
+        return f"LightZero CartPole Env({self._cfg['env_id']})"
\ No newline at end of file
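Usage sketch (not part of the patch): a minimal standalone rollout exercising the `save_replay_gif` / `replay_path_gif` options introduced above. The config values and the episode loop here are illustrative assumptions, not code from this PR; only `CartPoleEnv`, `random_action()`, and the new config keys come from the diff itself.

    from easydict import EasyDict
    from zoo.classic_control.cartpole.envs.cartpole_lightzero_env import CartPoleEnv

    # Hypothetical standalone rollout that records one episode as a GIF.
    cfg = EasyDict(dict(
        env_id='CartPole-v0',              # config key introduced in this patch
        save_replay_gif=True,              # enable per-episode frame recording
        replay_path_gif='./cartpole_gif',  # directory where save_gif_replay() writes
    ))
    env = CartPoleEnv(cfg)
    obs = env.reset()
    done = False
    while not done:
        # random_action() returns a shape-(1,) ndarray; step() squeezes it.
        timestep = env.step(env.random_action())
        done = timestep.done
    env.close()  # the GIF was already written inside step() when the episode ended

Saving goes through matplotlib's 'imagemagick' animation writer, so ImageMagick must be installed on the machine; matplotlib's built-in 'pillow' writer is a common substitute when it is not.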