From 545fb10d3327746a9caa08374254b9d3a3e3d232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com> Date: Wed, 29 Nov 2023 22:19:34 +0800 Subject: [PATCH] polish(pu): refine comments and render_eval configs for various common environments (#154) * polish(pu): polish comments and render in cartpole, pendulum, atari, lunarlander * polish(pu): polish comments and render in tictactoe, gomoku, connect4, 2048 --- lzero/policy/alphazero.py | 14 ++ zoo/atari/entry/atari_eval.py | 67 +++++--- zoo/atari/envs/atari_lightzero_env.py | 152 +++++++++++------ zoo/atari/envs/atari_wrappers.py | 74 ++++++--- .../connect4_alphazero_bot_mode_config.py | 19 +++ zoo/board_games/connect4/envs/connect4_env.py | 10 +- .../connect4/eval/connect4_alphazero_eval.py | 31 +++- .../connect4/eval/connect4_muzero_eval.py | 60 +++++++ .../gomoku/entry/gomoku_alphazero_eval.py | 39 +++-- zoo/board_games/gomoku/envs/gomoku_env.py | 30 +++- .../entry/tictactoe_alphazero_eval.py | 36 +++- .../tictactoe/entry/tictactoe_muzero_eval.py | 35 +++- .../tictactoe/envs/tictactoe_env.py | 139 +++++++++++++++- .../lunarlander/entry/lunarlander_eval.py | 72 ++++++-- .../envs/lunarlander_cont_disc_env.py | 154 +++++++----------- zoo/box2d/lunarlander/envs/lunarlander_env.py | 124 ++++++++------ .../lunarlander/envs/test_lunarlander_env.py | 10 +- .../cartpole/entry/cartpole_eval.py | 48 ++++-- .../cartpole/envs/cartpole_lightzero_env.py | 95 ++++++++--- .../pendulum/entry/pendulum_eval.py | 9 +- .../pendulum/envs/pendulum_lightzero_env.py | 103 +++++++----- zoo/game_2048/entry/2048_eval.py | 35 ++-- zoo/game_2048/envs/game_2048_env.py | 7 +- 23 files changed, 963 insertions(+), 400 deletions(-) create mode 100644 zoo/board_games/connect4/eval/connect4_muzero_eval.py diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py index 3967eb92f..38fd9d9c3 100644 --- a/lzero/policy/alphazero.py +++ b/lzero/policy/alphazero.py @@ -36,6 +36,11 @@ class AlphaZeroPolicy(Policy): # (int) The number of channels of hidden states in AlphaZero model. num_channels=32, ), + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, # (bool) Whether to use multi-gpu training. multi_gpu=False, # (bool) Whether to use cuda for network. @@ -350,6 +355,15 @@ def _get_simulation_env(self): else: raise NotImplementedError self.simulate_env = GomokuEnv(gomoku_alphazero_config.env) + elif self._cfg.simulation_env_name == 'connect4': + from zoo.board_games.connect4.envs.connect4_env import Connect4Env + if self._cfg.simulation_env_config_type == 'play_with_bot': + from zoo.board_games.connect4.config.connect4_alphazero_bot_mode_config import connect4_alphazero_config + elif self._cfg.simulation_env_config_type == 'self_play': + from zoo.board_games.connect4.config.connect4_alphazero_sp_mode_config import connect4_alphazero_config + else: + raise NotImplementedError + self.simulate_env = Connect4Env(connect4_alphazero_config.env) else: raise NotImplementedError diff --git a/zoo/atari/entry/atari_eval.py b/zoo/atari/entry/atari_eval.py index 824ffd1ab..eda80cb23 100644 --- a/zoo/atari/entry/atari_eval.py +++ b/zoo/atari/entry/atari_eval.py @@ -1,31 +1,56 @@ -# According to the model you want to evaluate, import the corresponding config. 
from lzero.entry import eval_muzero import numpy as np if __name__ == "__main__": - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - # Take the config of sampled efficientzero as an example - from zoo.atari.config.atari_sampled_efficientzero_config import main_config, create_config + Overview: + Main script to evaluate the MuZero model on Atari games. The script will loop over multiple seeds, + evaluating a certain number of episodes per seed. Results are aggregated and printed. - model_path = "/path/ckpt/ckpt_best.pth.tar" + Variables: + - model_path (:obj:`Optional[str]`): The pretrained model path, pointing to the ckpt file of the pretrained model. + The path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - seeds (:obj:`List[int]`): List of seeds to use for the evaluations. + - num_episodes_each_seed (:obj:`int`): Number of episodes to evaluate for each seed. + - total_test_episodes (:obj:`int`): Total number of test episodes, calculated as num_episodes_each_seed * len(seeds). + - returns_mean_seeds (:obj:`np.array`): Array of mean return values for each seed. + - returns_seeds (:obj:`np.array`): Array of all return values for each seed. + """ + # Importing the necessary configuration files from the atari muzero configuration in the zoo directory. + from zoo.atari.config.atari_muzero_config import main_config, create_config - returns_mean_seeds = [] - returns_seeds = [] + # model_path is the path to the trained MuZero model checkpoint. + # If no path is provided, the script will use the default model. + model_path = None + + # seeds is a list of seed values for the random number generator, used to initialize the environment. seeds = [0] + # num_episodes_each_seed is the number of episodes to run for each seed. num_episodes_each_seed = 1 + # total_test_episodes is the total number of test episodes, calculated as the product of the number of seeds and the number of episodes per seed total_test_episodes = num_episodes_each_seed * len(seeds) - create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base - main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 + + # Setting the type of the environment manager to 'base' for the visualization purposes. + create_config.env_manager.type = 'base' + # The number of environments to evaluate concurrently. Set to 1 for visualization purposes. + main_config.env.evaluator_env_num = 1 + # The total number of evaluation episodes that should be run. main_config.env.n_evaluator_episode = total_test_episodes - main_config.env.render_mode_human = True # Whether to enable real-time rendering - main_config.env.save_video = True # Whether to save the video, if save the video render_mode_human must to be True - main_config.env.save_path = '../config/' - main_config.env.eval_max_episode_steps = int(1e3) # Adjust according to different environments + # A boolean flag indicating whether to render the environments in real-time. + main_config.env.render_mode_human = False + # A boolean flag indicating whether to save the video of the environment. + main_config.env.save_replay = True + # The path where the recorded video will be saved. + main_config.env.save_path = './video' + # The maximum number of steps for each episode during evaluation. 
This may need to be adjusted based on the specific characteristics of the environment. + main_config.env.eval_max_episode_steps = int(20) + + # These lists will store the mean and total rewards for each seed. + returns_mean_seeds = [] + returns_seeds = [] + + # The main evaluation loop. For each seed, the MuZero model is evaluated and the mean and total rewards are recorded. for seed in seeds: returns_mean, returns = eval_muzero( [main_config, create_config], @@ -38,11 +63,13 @@ returns_mean_seeds.append(returns_mean) returns_seeds.append(returns) + # Convert the list of mean and total rewards into numpy arrays for easier statistical analysis. returns_mean_seeds = np.array(returns_mean_seeds) returns_seeds = np.array(returns_seeds) + # Printing the evaluation results. The average reward and the total reward for each seed are displayed, followed by the mean reward across all seeds. print("=" * 20) - print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.') - print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') - print('In all seeds, reward_mean:', returns_mean_seeds.mean()) - print("=" * 20) + print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).") + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) + print("=" * 20) \ No newline at end of file diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index 893c37e71..177c07625 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -1,6 +1,6 @@ import copy import sys -from typing import List +from typing import List, Any import gym import numpy as np @@ -14,56 +14,108 @@ @ENV_REGISTRY.register('atari_lightzero') class AtariLightZeroEnv(BaseEnv): + """ + Overview: + AtariLightZeroEnv is a derived class from BaseEnv and represents the environment for the Atari LightZero game. + This class provides the necessary interfaces to interact with the environment, including reset, step, seed, + close, etc. and manages the environment's properties such as observation_space, action_space, and reward_space. + Properties: + cfg, _init_flag, channel_last, clip_rewards, episode_life, _env, _observation_space, _action_space, + _reward_space, obs, _eval_episode_return, has_reset, _seed, _dynamic_seed + """ config = dict( + # (int) The number of environment instances used for data collection. collector_env_num=8, + # (int) The number of environment instances used for evaluator. evaluator_env_num=3, + # (int) The number of episodes to evaluate during each evaluation period. n_evaluator_episode=3, + # (str) The name of the Atari game environment. env_name='PongNoFrameskip-v4', + # (str) The type of the environment, here it's Atari. env_type='Atari', + # (tuple) The shape of the observation space, which is a stacked frame of 4 images each of 96x96 pixels. obs_shape=(4, 96, 96), + # (int) The maximum number of steps in each episode during data collection. collect_max_episode_steps=int(1.08e5), + # (int) The maximum number of steps in each episode during evaluation. eval_max_episode_steps=int(1.08e5), + # (bool) If True, the game is rendered in real-time. + render_mode_human=False, + # (bool) If True, a video of the game play is saved. + save_replay=False, + # (str) The path to save the video. 
+ replay_path='./video', + # (bool) If set to True, the game screen is converted to grayscale, reducing the complexity of the observation space. gray_scale=True, + # (int) The number of frames to skip between each action. Higher values result in faster simulation. frame_skip=4, + # (bool) If True, the game ends when the agent loses a life, otherwise, the game only ends when all lives are lost. episode_life=True, + # (bool) If True, the rewards are clipped to a certain range, usually between -1 and 1, to reduce variance. clip_rewards=True, + # (bool) If True, the channels of the observation images are placed last (e.g., height, width, channels). channel_last=True, - render_mode_human=False, + # (bool) If True, the pixel values of the game frames are scaled down to the range [0, 1]. scale=True, + # (bool) If True, the game frames are preprocessed by cropping irrelevant parts and resizing to a smaller resolution. warp_frame=True, - save_video=False, + # (bool) If True, the game state is transformed into a string before being returned by the environment. transform2string=False, + # (bool) If True, additional wrappers for the game environment are used. game_wrapper=True, + # (dict) The configuration for the environment manager. If shared_memory is set to False, each environment instance + # runs in the same process as the trainer, otherwise, they run in separate processes. manager=dict(shared_memory=False, ), + # (int) The value of the cumulative reward at which the training stops. stop_value=int(1e6), ) @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + Return the default configuration for the Atari LightZero environment. + Arguments: + - cls (:obj:`type`): The class AtariLightZeroEnv. + Returns: + - cfg (:obj:`EasyDict`): The default configuration dictionary. + """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' return cfg - def __init__(self, cfg=None): + def __init__(self, cfg: EasyDict) -> None: + """ + Overview: + Initialize the Atari LightZero environment with the given configuration. + Arguments: + - cfg (:obj:`EasyDict`): The configuration dictionary. + """ self.cfg = cfg self._init_flag = False self.channel_last = cfg.channel_last self.clip_rewards = cfg.clip_rewards self.episode_life = cfg.episode_life - def _make_env(self): - return wrap_lightzero(self.cfg, episode_life=self.cfg.episode_life, clip_rewards=self.cfg.clip_rewards) - - def reset(self): + def reset(self) -> dict: + """ + Overview: + Reset the environment and return the initial observation. + Returns: + - obs (:obj:`dict`): The initial observation after reset. + """ if not self._init_flag: - self._env = self._make_env() + # Create and return the wrapped environment for Atari LightZero. + self._env = wrap_lightzero(self.cfg, episode_life=self.cfg.episode_life, clip_rewards=self.cfg.clip_rewards) self._observation_space = self._env.env.observation_space self._action_space = self._env.env.action_space self._reward_space = gym.spaces.Box( - low=self._env.env.reward_range[0], high=self._env.env.reward_range[1], shape=(1, ), dtype=np.float32 + low=self._env.env.reward_range[0], high=self._env.env.reward_range[1], shape=(1,), dtype=np.float32 ) self._init_flag = True + if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: np_seed = 100 * np.random.randint(1, 1000) self._env.env.seed(self._seed + np_seed) @@ -73,29 +125,19 @@ def reset(self): obs = self._env.reset() self.obs = to_ndarray(obs) self._eval_episode_return = 0. 
- self.has_reset = True obs = self.observe() - # obs.shape: 96,96,1 return obs - def observe(self): + def step(self, action: int) -> BaseEnvTimestep: """ Overview: - add action_mask to obs to adapt with MCTS alg.. + Execute the given action and return the resulting environment timestep. + Arguments: + - action (:obj:`int`): The action to be executed. + Returns: + - timestep (:obj:`BaseEnvTimestep`): The environment timestep after executing the action. """ - observation = self.obs - - if not self.channel_last: - # move the channel dim to the fist axis - # (96, 96, 3) -> (3, 96, 96) - observation = np.transpose(observation, (2, 0, 1)) - - action_mask = np.ones(self._action_space.n, 'int8') - return {'observation': observation, 'action_mask': action_mask, 'to_play': -1} - - def step(self, action): obs, reward, done, info = self._env.step(action) - # self._env.render() self.obs = to_ndarray(obs) self.reward = np.array(reward).astype(np.float32) self._eval_episode_return += self.reward @@ -105,6 +147,23 @@ def step(self, action): return BaseEnvTimestep(observation, self.reward, done, info) + def observe(self) -> dict: + """ + Overview: + Return the current observation along with the action mask and to_play flag. + Returns: + - observation (:obj:`dict`): The dictionary containing current observation, action mask, and to_play flag. + """ + observation = self.obs + + if not self.channel_last: + # move the channel dim to the fist axis + # (96, 96, 3) -> (3, 96, 96) + observation = np.transpose(observation, (2, 0, 1)) + + action_mask = np.ones(self._action_space.n, 'int8') + return {'observation': observation, 'action_mask': action_mask, 'to_play': -1} + @property def legal_actions(self): return np.arange(self._action_space.n) @@ -113,52 +172,41 @@ def random_action(self): action_list = self.legal_actions return np.random.choice(action_list) - def render(self, mode='human'): - self._env.render() - - def human_to_action(self): - """ - Overview: - For multiplayer games, ask the user for a legal action - and return the corresponding action number. - Returns: - An integer from the action space. - """ - while True: - try: - print(f"Current available actions for the player are:{self.legal_actions}") - choice = int(input(f"Enter the index of next action: ")) - if choice in self.legal_actions: - break - else: - print("Wrong input, try again") - except KeyboardInterrupt: - print("exit") - sys.exit(0) - except Exception as e: - print("Wrong input, try again") - return choice - def close(self) -> None: + """ + Close the environment, and set the initialization flag to False. + """ if self._init_flag: self._env.close() self._init_flag = False def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Set the seed for the environment's random number generator. Can handle both static and dynamic seeding. + """ self._seed = seed self._dynamic_seed = dynamic_seed np.random.seed(self._seed) @property def observation_space(self) -> gym.spaces.Space: + """ + Property to access the observation space of the environment. + """ return self._observation_space @property def action_space(self) -> gym.spaces.Space: + """ + Property to access the action space of the environment. + """ return self._action_space @property def reward_space(self) -> gym.spaces.Space: + """ + Property to access the reward space of the environment. 
+ """ return self._reward_space def __repr__(self) -> str: diff --git a/zoo/atari/envs/atari_wrappers.py b/zoo/atari/envs/atari_wrappers.py index d16ff28cb..4254f16b3 100644 --- a/zoo/atari/envs/atari_wrappers.py +++ b/zoo/atari/envs/atari_wrappers.py @@ -1,5 +1,6 @@ -# Borrow a lot from openai baselines: -# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py +# Adapted from openai baselines: https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py +from datetime import datetime +from typing import Optional import cv2 import gym @@ -8,9 +9,11 @@ ScaledFloatFrameWrapper, \ ClipRewardWrapper, FrameStackWrapper from ding.utils.compression_helper import jpeg_data_compressor +from easydict import EasyDict from gym.wrappers import RecordVideo +# only for reference now def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True): """Configure environment for DeepMind-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). @@ -42,6 +45,7 @@ def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, frame_stack=4, s return env +# only for reference now def wrap_deepmind_mr(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True): """Configure environment for DeepMind-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). @@ -73,18 +77,17 @@ def wrap_deepmind_mr(env_id, episode_life=True, clip_rewards=True, frame_stack=4 return env -def wrap_lightzero(config, episode_life, clip_rewards): +def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> gym.Env: """ Overview: Configure environment for MuZero-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). Arguments: - - config (:obj:`Dict`): Dict containing configuration. - - wrap_frame (:obj:`bool`): - - save_video (:obj:`bool`): - - save_path (:obj:`bool`): + - config (:obj:`Dict`): Dict containing configuration parameters for the environment. + - episode_life (:obj:`bool`): If True, the agent starts with a set number of lives and loses them during the game. + - clip_rewards (:obj:`bool`): If True, the rewards are clipped to a certain range. Return: - - the wrapped atari environment. + - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations. """ if config.render_mode_human: env = gym.make(config.env_name, render_mode='human') @@ -103,13 +106,14 @@ def wrap_lightzero(config, episode_life, clip_rewards): env = ScaledFloatFrameWrapper(env) if clip_rewards: env = ClipRewardWrapper(env) - if config.save_video: - import random, string + if config.save_replay: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + video_name = f'{env.spec.id}-video-{timestamp}' env = RecordVideo( env, - video_folder=config.save_path, + video_folder=config.replay_path, episode_trigger=lambda episode_id: True, - name_prefix='rl-video-{}'.format(''.join(random.choice(string.ascii_lowercase) for i in range(5))), + name_prefix=video_name ) env = JpegWrapper(env, transform2string=config.transform2string) @@ -120,8 +124,17 @@ def wrap_lightzero(config, episode_life, clip_rewards): class TimeLimit(gym.Wrapper): + """ + Overview: + A wrapper that limits the maximum number of steps in an episode. + """ - def __init__(self, env, max_episode_steps=None): + def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None): + """ + Arguments: + - env (:obj:`gym.Env`): The environment to wrap. 
+ - max_episode_steps (:obj:`Optional[int]`): Maximum number of steps per episode. If None, no limit is applied. + """ super(TimeLimit, self).__init__(env) self._max_episode_steps = max_episode_steps self._elapsed_steps = 0 @@ -140,12 +153,20 @@ def reset(self, **kwargs): class WarpFrame(gym.ObservationWrapper): + """ + Overview: + A wrapper that warps frames to 84x84 as done in the Nature paper and later work. + """ - def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None): + def __init__(self, env: gym.Env, width: int = 84, height: int = 84, grayscale: bool = True, + dict_space_key: Optional[str] = None): """ - Warp frames to 84x84 as done in the Nature paper and later work. - If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which - observation should be warped. + Arguments: + - env (:obj:`gym.Env`): The environment to wrap. + - width (:obj:`int`): The width to which the frames are resized. + - height (:obj:`int`): The height to which the frames are resized. + - grayscale (:obj:`bool`): If True, convert frames to grayscale. + - dict_space_key (:obj:`Optional[str]`): If specified, indicates which observation should be warped. """ super().__init__(env) self._width = width @@ -192,10 +213,16 @@ def observation(self, obs): class JpegWrapper(gym.Wrapper): + """ + Overview: + A wrapper that converts the observation into a string to save memory. + """ - def __init__(self, env, transform2string=True): + def __init__(self, env: gym.Env, transform2string: bool = True): """ - Overview: convert the observation into string to save memory + Arguments: + - env (:obj:`gym.Env`): The environment to wrap. + - transform2string (:obj:`bool`): If True, transform the observations to string. """ super().__init__(env) self.transform2string = transform2string @@ -218,10 +245,15 @@ def reset(self, **kwargs): class GameWrapper(gym.Wrapper): + """ + Overview: + A wrapper to adapt the environment to the game interface. + """ - def __init__(self, env): + def __init__(self, env: gym.Env): """ - Overview: warp env to adapt the game interface + Arguments: + - env (:obj:`gym.Env`): The environment to wrap. 
""" super().__init__(env) diff --git a/zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py b/zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py index 60a8e299c..54c817768 100644 --- a/zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py +++ b/zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py @@ -11,6 +11,8 @@ batch_size = 256 max_env_step = int(1e6) model_path = None +mcts_ctree = False + # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -19,12 +21,29 @@ env=dict( battle_mode='play_with_bot_mode', bot_action_type='rule', + channel_last=False, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), + # ============================================================== + # for the creation of simulation env + agent_vs_human=False, + prob_random_agent=0, + prob_expert_agent=0, + scale=True, + screen_scaling=9, + render_mode=None, + alphazero_mcts_ctree=mcts_ctree, + # ============================================================== ), policy=dict( + mcts_ctree=mcts_ctree, + # ============================================================== + # for the creation of simulation env + simulation_env_name='connect4', + simulation_env_config_type='play_with_bot', + # ============================================================== model=dict( observation_shape=(3, 6, 7), action_space_size=7, diff --git a/zoo/board_games/connect4/envs/connect4_env.py b/zoo/board_games/connect4/envs/connect4_env.py index 9b73f9291..88e463d79 100644 --- a/zoo/board_games/connect4/envs/connect4_env.py +++ b/zoo/board_games/connect4/envs/connect4_env.py @@ -62,6 +62,8 @@ class Connect4Env(BaseEnv): # (str) The render mode. Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'. # If None, then the game will not be rendered. render_mode=None, + # (str) The suffix of the replay name. + replay_path='./video', # (str) The type of the bot of the environment. bot_action_type='rule', # (bool) Whether to let human to play with the agent when evaluating. If False, then use the bot to evaluate the agent. @@ -101,7 +103,7 @@ def __init__(self, cfg: dict = None) -> None: # options = {None, 'state_realtime_mode', 'image_realtime_mode', 'image_savefile_mode'} self.render_mode = cfg.render_mode self.replay_name_suffix = "test" - self.replay_path = None + self.replay_path = cfg.replay_path self.replay_format = 'gif' self.screen = None self.frames = [] @@ -473,9 +475,11 @@ def save_render_output(self, replay_name_suffix: str = '', replay_path: str = No """ # At the end of the episode, save the frames. if replay_path is None: - filename = f'game_connect4_{replay_name_suffix}.{format}' + filename = f'connect4_{replay_name_suffix}.{format}' else: - filename = f'{replay_path}.{format}' + if not os.path.exists(replay_path): + os.makedirs(replay_path) + filename = replay_path+f'/connect4_{replay_name_suffix}.{format}' if format == 'gif': # Save frames as a GIF with a duration of 0.1 seconds per frame. 
diff --git a/zoo/board_games/connect4/eval/connect4_alphazero_eval.py b/zoo/board_games/connect4/eval/connect4_alphazero_eval.py index d963bc386..720b00c83 100644 --- a/zoo/board_games/connect4/eval/connect4_alphazero_eval.py +++ b/zoo/board_games/connect4/eval/connect4_alphazero_eval.py @@ -4,18 +4,31 @@ from zoo.board_games.connect4.config.connect4_alphazero_bot_mode_config import main_config, create_config if __name__ == '__main__': - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ + Entry point for the evaluation of the AlphaZero model on the Connect4 environment. + + Variables: + - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the + pretrained model. An absolute path is recommended. In LightZero, the path is usually something like + ``exp_name/ckpt/ckpt_best.pth.tar``. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[float]`): List to store the returns for each seed. + - seeds (:obj:`List[int]`): List of seeds for the environment. + - num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. + - total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of + seeds and the number of episodes per seed. + """ + # model_path = './ckpt/ckpt_best.pth.tar' model_path = None seeds = [0] num_episodes_each_seed = 1 # If True, you can play with the agent. + # main_config.env.agent_vs_human = True main_config.env.agent_vs_human = False - # main_config.env.render_mode = 'image_savefile_mode' - main_config.env.render_mode = 'state_realtime_mode' + # main_config.env.render_mode = 'image_realtime_mode' + main_config.env.render_mode = 'image_savefile_mode' + main_config.env.replay_path = './video' + main_config.policy.mcts.num_simulations = 10 main_config.env.prob_random_action_in_bot = 0. main_config.env.bot_action_type = 'rule' @@ -40,9 +53,9 @@ returns_seeds = np.array(returns_seeds) print("=" * 20) - print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.') - print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') - print('In all seeds, reward_mean:', returns_mean_seeds.mean(), end='. ') + print(f"We evaluated a total of {len(seeds)} seeds. 
For each seed, we evaluated {num_episodes_each_seed} episode(s).") + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) print( f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}' ) diff --git a/zoo/board_games/connect4/eval/connect4_muzero_eval.py b/zoo/board_games/connect4/eval/connect4_muzero_eval.py new file mode 100644 index 000000000..084ca3ba5 --- /dev/null +++ b/zoo/board_games/connect4/eval/connect4_muzero_eval.py @@ -0,0 +1,60 @@ +from zoo.board_games.connect4.config.connect4_muzero_bot_mode_config import main_config, create_config +from lzero.entry import eval_muzero +import numpy as np + +if __name__ == '__main__': + """ + Entry point for the evaluation of the MuZero model on the Connect4 environment. + + Variables: + - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the + pretrained model. An absolute path is recommended. In LightZero, the path is usually something like + ``exp_name/ckpt/ckpt_best.pth.tar``. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[float]`): List to store the returns for each seed. + - seeds (:obj:`List[int]`): List of seeds for the environment. + - num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. + - total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of + seeds and the number of episodes per seed. + """ + # model_path = './ckpt/ckpt_best.pth.tar' + model_path = None + seeds = [0] + num_episodes_each_seed = 1 + # If True, you can play with the agent. + # main_config.env.agent_vs_human = True + main_config.env.agent_vs_human = False + # main_config.env.render_mode = 'image_realtime_mode' + main_config.env.render_mode = 'image_savefile_mode' + main_config.env.replay_path = './video' + + main_config.env.prob_random_action_in_bot = 0. + main_config.env.bot_action_type = 'rule' + create_config.env_manager.type = 'base' + main_config.env.evaluator_env_num = 1 + main_config.env.n_evaluator_episode = 1 + total_test_episodes = num_episodes_each_seed * len(seeds) + returns_mean_seeds = [] + returns_seeds = [] + for seed in seeds: + returns_mean, returns = eval_muzero( + [main_config, create_config], + seed=seed, + num_episodes_each_seed=num_episodes_each_seed, + print_seed_details=True, + model_path=model_path + ) + returns_mean_seeds.append(returns_mean) + returns_seeds.append(returns) + + returns_mean_seeds = np.array(returns_mean_seeds) + returns_seeds = np.array(returns_seeds) + + print("=" * 20) + print(f"We evaluated a total of {len(seeds)} seeds. 
For each seed, we evaluated {num_episodes_each_seed} episode(s).") + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) + print( + f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}' + ) + print("=" * 20) diff --git a/zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py b/zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py index e5ae336b0..1946b9694 100644 --- a/zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py +++ b/zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py @@ -3,18 +3,35 @@ import numpy as np if __name__ == '__main__': - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - model_path = './ckpt/ckpt_best.pth.tar' + Entry point for the evaluation of the AlphaZero model on the Gomoku environment. + + Variables: + - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the + pretrained model. An absolute path is recommended. In LightZero, the path is usually something like + ``exp_name/ckpt/ckpt_best.pth.tar``. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[float]`): List to store the returns for each seed. + - seeds (:obj:`List[int]`): List of seeds for the environment. + - num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. + - total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of + seeds and the number of episodes per seed. + """ + # model_path = './ckpt/ckpt_best.pth.tar' + model_path = None + seeds = [0] - num_episodes_each_seed = 5 + num_episodes_each_seed = 1 # If True, you can play with the agent. - main_config.env.agent_vs_human = True - main_config.env.render_mode = 'image_realtime_mode' + # main_config.env.agent_vs_human = True + main_config.env.agent_vs_human = False + # main_config.env.render_mode = 'image_realtime_mode' + main_config.env.render_mode = 'image_savefile_mode' + main_config.env.replay_path = './video' + create_config.env_manager.type = 'base' + main_config.env.alphazero_mcts_ctree = False + main_config.policy.mcts_ctree = False main_config.env.evaluator_env_num = 1 main_config.env.n_evaluator_episode = 1 total_test_episodes = num_episodes_each_seed * len(seeds) @@ -35,9 +52,9 @@ returns_seeds = np.array(returns_seeds) print("=" * 20) - print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.') - print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') - print('In all seeds, reward_mean:', returns_mean_seeds.mean(), end='. ') + print(f"We evaluated a total of {len(seeds)} seeds. 
For each seed, we evaluated {num_episodes_each_seed} episode(s).") + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) print( f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}' ) diff --git a/zoo/board_games/gomoku/envs/gomoku_env.py b/zoo/board_games/gomoku/envs/gomoku_env.py index a20f16654..44f24f3a8 100644 --- a/zoo/board_games/gomoku/envs/gomoku_env.py +++ b/zoo/board_games/gomoku/envs/gomoku_env.py @@ -54,6 +54,8 @@ class GomokuEnv(BaseEnv): # (str) The render mode. Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'. # If None, then the game will not be rendered. render_mode=None, + # (str) The suffix of the replay name. + replay_path='./video', # (float) The scale of the render screen. screen_scaling=9, # (bool) Whether to use the 'channel last' format for the observation space. If False, 'channel first' format is used. @@ -125,7 +127,7 @@ def __init__(self, cfg: dict = None): # options = {None, 'state_realtime_mode', 'image_realtime_mode', 'image_savefile_mode'} self.render_mode = cfg.render_mode self.replay_name_suffix = "test" - self.replay_path = None + self.replay_path = cfg.replay_path self.replay_format = 'gif' # 'mp4' # self.screen = None self.frames = [] @@ -259,8 +261,10 @@ def step(self, action): elif self.battle_mode == 'eval_mode': # player 1 battle with expert player 2 + self._env.render(self.render_mode) # player 1's turn timestep_player1 = self._player_step(action) + self._env.render(self.render_mode) if self.agent_vs_human: print('player 1 (agent): ' + self.action_to_string(action)) # Note: visualize self.render(mode="image_realtime_mode") @@ -278,6 +282,7 @@ def step(self, action): # bot_action = self.random_action() timestep_player2 = self._player_step(bot_action) + self._env.render(self.render_mode) if self.agent_vs_human: print('player 2 (human): ' + self.action_to_string(bot_action)) # Note: visualize self.render(mode="image_realtime_mode") @@ -315,14 +320,14 @@ def _player_step(self, action): """ self.current_player = self.to_play - # Render the new step. - # The following code is used to save the rendered images in both - # collect/eval step and the simulated mcts step. + # The following code will save the rendered images in both env step in collect/eval phase and the env step in + # simulated mcts. # if self.render_mode is not None: # self.render(self.render_mode) if done: info['eval_episode_return'] = reward + self._env.render(self.render_mode) if self.render_mode == 'image_savefile_mode': self.save_render_output(replay_name_suffix=self.replay_name_suffix, replay_path=self.replay_path, format=self.replay_format) @@ -578,6 +583,8 @@ def render(self, mode="state_realtime_mode"): - mode (str): Rendering mode, options are "state_realtime_mode", "image_realtime_mode", and "image_savefile_mode". """ + if mode is None: + return # Print the state of the board directly if mode == "state_realtime_mode": print(np.array(self.board).reshape(self.board_size, self.board_size)) @@ -607,7 +614,14 @@ def render(self, mode="state_realtime_mode"): # Save the current frame to the frames list. 
self.fig.canvas.draw() image = np.frombuffer(self.fig.canvas.tostring_rgb(), dtype='uint8') - image = image.reshape(self.fig.canvas.get_width_height()[::-1] + (3,)) + + # Get the width and height of the figure + width, height = self.fig.get_size_inches() * self.fig.get_dpi() + width = int(width) + height = int(height) + image = image.reshape(height, width, 3) + + # image = image.reshape(self.fig.canvas.get_width_height()[::-1] + (3,)) self.frames.append(image) def close(self): @@ -703,9 +717,11 @@ def save_render_output(self, replay_name_suffix: str = '', replay_path: str = No """ # At the end of the episode, save the frames. if replay_path is None: - filename = f'game_gomoku_{self.board_size}_{replay_name_suffix}.{format}' + filename = f'gomoku_{self.board_size}_{replay_name_suffix}.{format}' else: - filename = f'{replay_path}.{format}' + if not os.path.exists(replay_path): + os.makedirs(replay_path) + filename = replay_path+f'/gomoku_{self.board_size}_{replay_name_suffix}.{format}' if format == 'gif': # Save frames as a GIF with a duration of 0.1 seconds per frame. diff --git a/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py b/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py index 780cbe66f..1c6de7850 100644 --- a/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py +++ b/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py @@ -3,20 +3,37 @@ import numpy as np if __name__ == '__main__': - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - model_path = './ckpt/ckpt_best.pth.tar' + Entry point for the evaluation of the AlphaZero model on the TicTacToe environment. + + Variables: + - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the + pretrained model. An absolute path is recommended. In LightZero, the path is usually something like + ``exp_name/ckpt/ckpt_best.pth.tar``. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[float]`): List to store the returns for each seed. + - seeds (:obj:`List[int]`): List of seeds for the environment. + - num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. + - total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of + seeds and the number of episodes per seed. + """ + # model_path = './ckpt/ckpt_best.pth.tar' + model_path = None seeds = [0] - num_episodes_each_seed = 5 + num_episodes_each_seed = 1 + + # Enable saving of replay as a gif, specify the path to save the replay gif + main_config.env.save_replay_gif = True + main_config.env.replay_path_gif = './video' + + main_config.policy.mcts_ctree = False # If True, you can play with the agent. main_config.env.agent_vs_human = False create_config.env_manager.type = 'base' main_config.env.evaluator_env_num = 1 main_config.env.n_evaluator_episode = 1 total_test_episodes = num_episodes_each_seed * len(seeds) + returns_mean_seeds = [] returns_seeds = [] for seed in seeds: @@ -33,10 +50,11 @@ returns_mean_seeds = np.array(returns_mean_seeds) returns_seeds = np.array(returns_seeds) + # Print evaluation results print("=" * 20) - print(f'We eval total {len(seeds)} seeds. 
In each seed, we eval {num_episodes_each_seed} episodes.') - print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') - print('In all seeds, reward_mean:', returns_mean_seeds.mean(), end='. ') + print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).") + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) print( f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}' ) diff --git a/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py b/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py index 35e6a1d38..f2bebb0bd 100644 --- a/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py +++ b/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py @@ -3,21 +3,37 @@ import numpy as np if __name__ == "__main__": - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ + Entry point for the evaluation of the MuZero model on the TicTacToe environment. + + Variables: + - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the + pretrained model. An absolute path is recommended. In LightZero, the path is usually something like + ``exp_name/ckpt/ckpt_best.pth.tar``. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[float]`): List to store the returns for each seed. + - seeds (:obj:`List[int]`): List of seeds for the environment. + - num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. + - total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of + seeds and the number of episodes per seed. + """ + + # model_path = "./ckpt/ckpt_best.pth.tar" + model_path = None - model_path = "./ckpt/ckpt_best.pth.tar" seeds = [0] - num_episodes_each_seed = 5 + num_episodes_each_seed = 1 # If True, you can play with the agent. main_config.env.agent_vs_human = False create_config.env_manager.type = 'base' main_config.env.evaluator_env_num = 1 main_config.env.n_evaluator_episode = 1 total_test_episodes = num_episodes_each_seed * len(seeds) + + # Enable saving of replay as a gif, specify the path to save the replay gif + main_config.env.save_replay_gif = True + main_config.env.replay_path_gif = './video' + returns_mean_seeds = [] returns_seeds = [] for seed in seeds: @@ -34,10 +50,11 @@ returns_mean_seeds = np.array(returns_mean_seeds) returns_seeds = np.array(returns_seeds) + # Print evaluation results print("=" * 20) - print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.') - print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') - print('In all seeds, reward_mean:', returns_mean_seeds.mean(), end='. ') + print(f"We evaluated a total of {len(seeds)} seeds. 
For each seed, we evaluated {num_episodes_each_seed} episode(s).") + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) print( f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}' ) diff --git a/zoo/board_games/tictactoe/envs/tictactoe_env.py b/zoo/board_games/tictactoe/envs/tictactoe_env.py index df30e92dd..ffcd38b1c 100644 --- a/zoo/board_games/tictactoe/envs/tictactoe_env.py +++ b/zoo/board_games/tictactoe/envs/tictactoe_env.py @@ -1,18 +1,21 @@ import copy +import os import sys +from datetime import datetime from functools import lru_cache from typing import List import gym +import matplotlib.pyplot as plt import numpy as np from ding.envs.env.base_env import BaseEnv, BaseEnvTimestep from ding.utils.registry_factory import ENV_REGISTRY from ditk import logging from easydict import EasyDict +from zoo.board_games.tictactoe.envs.get_done_winner_cython import get_done_winner_cython +from zoo.board_games.tictactoe.envs.legal_actions_cython import legal_actions_cython from zoo.board_games.alphabeta_pruning_bot import AlphaBetaPruningBot -from zoo.board_games.tictactoe.envs.legal_actions_cython import legal_actions_cython -from zoo.board_games.tictactoe.envs.get_done_winner_cython import get_done_winner_cython @lru_cache(maxsize=512) @@ -78,6 +81,9 @@ def __init__(self, cfg=None): if 'alpha_beta_pruning' in self.bot_action_type: self.alpha_beta_pruning_player = AlphaBetaPruningBot(self, cfg, 'alpha_beta_pruning_player') self.alphazero_mcts_ctree = cfg.alphazero_mcts_ctree + self._replay_path_gif = cfg.replay_path_gif + self._save_replay_gif = cfg.save_replay_gif + self._save_replay_count = 0 @property def legal_actions(self): @@ -170,6 +176,9 @@ def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, 'current_player_index': self.start_player_index, 'to_play': self.current_player } + if self._save_replay_gif: + self._frames = [] + return obs def reset_v2(self, start_player_index=0, init_state=None): @@ -196,7 +205,8 @@ def step(self, action): timestep = self._player_step(action) if timestep.done: # The eval_episode_return is calculated from Player 1's perspective。 - timestep.info['eval_episode_return'] = -timestep.reward if timestep.obs['to_play'] == 1 else timestep.reward + timestep.info['eval_episode_return'] = -timestep.reward if timestep.obs[ + 'to_play'] == 1 else timestep.reward return timestep elif self.battle_mode == 'play_with_bot_mode': # player 1 battle with expert player 2 @@ -228,12 +238,27 @@ def step(self, action): # player 1 battle with expert player 2 # player 1's turn + if self._save_replay_gif: + self._frames.append(self._env.render(mode='rgb_array')) timestep_player1 = self._player_step(action) # self.env.render() if timestep_player1.done: # NOTE: in eval_mode, we must set to_play as -1, because we don't consider the alternation between players. # And the to_play is used in MCTS. 
timestep_player1.obs['to_play'] = -1 + + if self._save_replay_gif: + if not os.path.exists(self._replay_path_gif): + os.makedirs(self._replay_path_gif) + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + path = os.path.join( + self._replay_path_gif, + 'tictactoe_episode_{}_{}.gif'.format(self._save_replay_count, timestamp) + ) + self.display_frames_as_gif(self._frames, path) + print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') + self._save_replay_count += 1 + return timestep_player1 # player 2's turn @@ -242,7 +267,11 @@ def step(self, action): else: bot_action = self.bot_action() # print('player 2 (computer player): ' + self.action_to_string(bot_action)) + if self._save_replay_gif: + self._frames.append(self._env.render(mode='rgb_array')) timestep_player2 = self._player_step(bot_action) + if self._save_replay_gif: + self._frames.append(self._env.render(mode='rgb_array')) # the eval_episode_return is calculated from Player 1's perspective timestep_player2.info['eval_episode_return'] = -timestep_player2.reward timestep_player2 = timestep_player2._replace(reward=-timestep_player2.reward) @@ -252,9 +281,23 @@ def step(self, action): # And the to_play is used in MCTS. timestep.obs['to_play'] = -1 + if timestep_player2.done: + if self._save_replay_gif: + if not os.path.exists(self._replay_path_gif): + os.makedirs(self._replay_path_gif) + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + path = os.path.join( + self._replay_path_gif, + 'tictactoe_episode_{}_{}.gif'.format(self._save_replay_count, timestamp) + ) + self.display_frames_as_gif(self._frames, path) + print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') + self._save_replay_count += 1 + return timestep def _player_step(self, action): + if action in self.legal_actions: row, col = self.action_to_coord(action) self.board[row, col] = self.current_player @@ -280,7 +323,6 @@ def _player_step(self, action): if done: info['eval_episode_return'] = reward # print('tictactoe one episode done: ', info) - action_mask = np.zeros(self.total_num_actions, 'int8') action_mask[self.legal_actions] = 1 obs = { @@ -559,6 +601,92 @@ def simulate_action_v2(self, board, start_player_index, action): return new_board, new_legal_actions + def render(self, mode="human"): + """ + Render the game state, either as a string (mode='human') or as an RGB image (mode='rgb_array'). + + Arguments: + - mode (:obj:`str`): The mode to render with. Valid modes are: + - 'human': render to the current display or terminal and + - 'rgb_array': Return an numpy.ndarray with shape (x, y, 3), + representing RGB values for an image of the board + Returns: + if mode is: + - 'human': returns None + - 'rgb_array': return a numpy array representing the rendered image. + Raises: + ValueError: If the provided mode is unknown. 
+ """ + if mode == 'human': + print(self.board) + elif mode == 'rgb_array': + dpi = 80 + fig, ax = plt.subplots(figsize=(6, 6), dpi=dpi) + + # """Piece is in the cross point of row and col""" + # # Draw a black background, white grid + # ax.imshow(np.zeros((self.board_size, self.board_size, 3)), origin='lower') + # ax.grid(color='white', linewidth=2) + # + # # Draw the 'X' and 'O' symbols for each player + # for i in range(self.board_size): + # for j in range(self.board_size): + # if self.board[i, j] == 1: # Player 1 + # ax.text(j, i, 'X', ha='center', va='center', color='white', fontsize=24) + # elif self.board[i, j] == 2: # Player 2 + # ax.text(j, i, 'O', ha='center', va='center', color='white', fontsize=24) + + # # Setup the axes + # ax.set_xticks(np.arange(self.board_size)) + # ax.set_yticks(np.arange(self.board_size)) + + """Piece is in the center point of grid""" + # Draw a peachpuff background, black grid + ax.imshow(np.ones((self.board_size, self.board_size, 3)) * np.array([255, 218, 185]) / 255, origin='lower') + ax.grid(color='black', linewidth=2) + + # Draw the 'X' and 'O' symbols for each player + for i in range(self.board_size): + for j in range(self.board_size): + if self.board[i, j] == 1: # Player 1 + ax.text(j, i, 'X', ha='center', va='center', color='black', fontsize=24) + elif self.board[i, j] == 2: # Player 2 + ax.text(j, i, 'O', ha='center', va='center', color='white', fontsize=24) + + # Setup the axes + ax.set_xticks(np.arange(0.5, self.board_size, 1)) + ax.set_yticks(np.arange(0.5, self.board_size, 1)) + + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.xaxis.set_ticks_position('none') + ax.yaxis.set_ticks_position('none') + + # Set the title of the game + plt.title('TicTacToe: ' + ('Black Turn' if self.current_player == 1 else 'White Turn')) + + fig.canvas.draw() + + # Get the width and height of the figure + width, height = fig.get_size_inches() * fig.get_dpi() + width = int(width) + height = int(height) + + # Use the width and height values to reshape the numpy array + img = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8') + img = img.reshape(height, width, 3) + + plt.close(fig) + + return img + else: + raise ValueError(f"Unknown mode '{mode}', it should be either 'human' or 'rgb_array'.") + + @staticmethod + def display_frames_as_gif(frames: list, path: str) -> None: + import imageio + imageio.mimsave(path, frames, fps=20) + def clone(self): return copy.deepcopy(self) @@ -567,9 +695,6 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None: self._dynamic_seed = dynamic_seed np.random.seed(self._seed) - def render(self, mode="human"): - print(self.board) - @property def observation_space(self) -> gym.spaces.Space: return self._observation_space diff --git a/zoo/box2d/lunarlander/entry/lunarlander_eval.py b/zoo/box2d/lunarlander/entry/lunarlander_eval.py index b74516ed0..6a51ce5fa 100644 --- a/zoo/box2d/lunarlander/entry/lunarlander_eval.py +++ b/zoo/box2d/lunarlander/entry/lunarlander_eval.py @@ -1,38 +1,78 @@ -# According to the model you want to evaluate, import the corresponding config. +# Import the necessary libraries and configs based on the model you want to evaluate from zoo.box2d.lunarlander.config.lunarlander_disc_muzero_config import main_config, create_config from lzero.entry import eval_muzero import numpy as np if __name__ == "__main__": - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. 
- In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - model_path = './ckpt/ckpt_best.pth.tar' + Overview: + Evaluate the model performance by running multiple episodes with different seeds using the MuZero algorithm. + The evaluation results (returns and mean returns) are printed out for each seed and summarized for all seeds. + Variables: + - model_path (:obj:`str`): Path to the pretrained model's checkpoint file. Usually something like + "exp_name/ckpt/ckpt_best.pth.tar". Absolute path is recommended. + - seeds (:obj:`List[int]`): List of seeds to use for evaluation. Each seed will run for a specified number + of episodes. + - num_episodes_each_seed (:obj:`int`): Number of episodes to be run for each seed. + - main_config (:obj:`EasyDict`): Main configuration for the evaluation, imported from the model's config file. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[List[float]]`): List to store the returns for each episode from each seed. + Outputs: + Prints out the mean returns and returns for each seed, along with the overall mean return across all seeds. + + .. note:: + The eval_muzero function is used here for evaluation. For more details about this function and its parameters, + please refer to its own documentation. + """ + # model_path = './ckpt/ckpt_best.pth.tar' + model_path = None + + # Initialize a list with a single seed for the experiment seeds = [0] - num_episodes_each_seed = 5 + + # Set the number of episodes to run for each seed + num_episodes_each_seed = 1 + + # Specify the number of environments for the evaluator to use main_config.env.evaluator_env_num = 1 + + # Set the number of episodes for the evaluator to run main_config.env.n_evaluator_episode = 1 + + # The total number of test episodes is the product of the number of episodes per seed and the number of seeds total_test_episodes = num_episodes_each_seed * len(seeds) + + # Uncomment the following lines to save a replay of the episodes as an mp4 video + # main_config.env.replay_path = './video' + + # Enable saving of replay as a gif, specify the path to save the replay gif + main_config.env.save_replay_gif = True + main_config.env.replay_path_gif = './video' + + # Initialize lists to store the mean and total returns for each seed returns_mean_seeds = [] returns_seeds = [] + + # For each seed, run the evaluation function and store the resulting mean and total returns for seed in seeds: returns_mean, returns = eval_muzero( - [main_config, create_config], - seed=seed, - num_episodes_each_seed=num_episodes_each_seed, - print_seed_details=False, - model_path=model_path + [main_config, create_config], # Configuration parameters for the evaluation + seed=seed, # The seed for the random number generator + num_episodes_each_seed=num_episodes_each_seed, # The number of episodes to run for this seed + print_seed_details=False, # Whether to print detailed information for each seed + model_path=model_path # The path to the trained model to be evaluated ) + # Append the mean and total returns to their respective lists returns_mean_seeds.append(returns_mean) returns_seeds.append(returns) + # Convert the lists of returns to numpy arrays for easier statistical analysis returns_mean_seeds = np.array(returns_mean_seeds) returns_seeds = np.array(returns_seeds) + # Print evaluation results print("=" * 20) - print(f'We eval total {len(seeds)} seeds. 
In each seed, we eval {num_episodes_each_seed} episodes.') - print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') - print('In all seeds, reward_mean:', returns_mean_seeds.mean()) - print("=" * 20) + print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).") + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) + print("=" * 20) \ No newline at end of file diff --git a/zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py b/zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py index 4d03e78c7..bfb402f7f 100755 --- a/zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py +++ b/zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py @@ -1,45 +1,69 @@ -from typing import Any, List, Union, Optional +from datetime import datetime + import gym +import copy import os +from itertools import product + +import gym import numpy as np -from ding.envs import BaseEnv, BaseEnvTimestep -from ding.torch_utils import to_ndarray, to_list -from ding.utils import ENV_REGISTRY -from ding.envs.common import affine_transform +from ding.envs import BaseEnvTimestep from ding.envs import ObsPlusPrevActRewWrapper -from itertools import product +from ding.envs.common import affine_transform +from ding.torch_utils import to_ndarray +from ding.utils import ENV_REGISTRY from easydict import EasyDict -import copy + +from zoo.box2d.lunarlander.envs.lunarlander_env import LunarLanderEnv @ENV_REGISTRY.register('lunarlander_cont_disc') -class LunarLanderDiscEnv(BaseEnv): +class LunarLanderDiscEnv(LunarLanderEnv): """ - Overview: - The modified LunarLander environment with manually discretized action space. For each dimension, equally dividing the - original continuous action into ``each_dim_disc_size`` bins and using their Cartesian product to obtain - handcrafted discrete actions. + Overview: + The modified LunarLander environment with manually discretized action space. For each dimension, it equally divides the + original continuous action into ``each_dim_disc_size`` bins and uses their Cartesian product to obtain + handcrafted discrete actions. """ @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + Get the default configuration of the LunarLander environment. + Returns: + - cfg (:obj:`EasyDict`): Default configuration dictionary. + """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' return cfg config = dict( + # (str) The gym environment name. + env_name="LunarLander-v2", + # (int) The number of bins for each dimension of the action space. + each_dim_disc_size=4, + # (bool) If True, save the replay as a gif file. save_replay_gif=False, + # (str or None) The path to save the replay gif. If None, the replay gif will not be saved. replay_path_gif=None, + # (str or None) The path to save the replay. If None, the replay will not be saved. replay_path=None, - use_act_scale=False, - delay_reward_step=0, - prob_random_agent=0., + # (bool) If True, the action will be scaled. + act_scale=True, + # (int) The maximum number of steps for each episode during collection. collect_max_episode_steps=int(1.08e5), + # (int) The maximum number of steps for each episode during evaluation. 
eval_max_episode_steps=int(1.08e5), - each_dim_disc_size=4, ) def __init__(self, cfg: dict) -> None: + """ + Overview: + Initialize the LunarLander environment with the given config dictionary. + Arguments: + - cfg (:obj:`dict`): Configuration dictionary. + """ self._cfg = cfg self._init_flag = False # env_name: LunarLander-v2, LunarLanderContinuous-v2 @@ -56,19 +80,21 @@ def __init__(self, cfg: dict) -> None: def reset(self) -> np.ndarray: """ Overview: - During the reset phase, the original environment will be created, - and at the same time, the action space will be discretized into "each_dim_disc_size" bins. + Reset the environment. During the reset phase, the original environment will be created, + and at the same time, the action space will be discretized into "each_dim_disc_size" bins. Returns: - info_dict (:obj:`Dict[str, Any]`): Including observation, action_mask, and to_play label. """ if not self._init_flag: self._env = gym.make(self._cfg.env_name) if self._replay_path is not None: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + video_name = f'{self._env.spec.id}-video-{timestamp}' self._env = gym.wrappers.RecordVideo( self._env, video_folder=self._replay_path, episode_trigger=lambda episode_id: True, - name_prefix='rl-video-{}'.format(id(self)) + name_prefix=video_name ) if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward: self._env = ObsPlusPrevActRewWrapper(self._env) @@ -104,28 +130,15 @@ def reset(self) -> np.ndarray: obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} return obs - def close(self) -> None: - if self._init_flag: - self._env.close() - self._init_flag = False - - def render(self) -> None: - self._env.render() - - def seed(self, seed: int, dynamic_seed: bool = True) -> None: - self._seed = seed - self._dynamic_seed = dynamic_seed - np.random.seed(self._seed) - def step(self, action: np.ndarray) -> BaseEnvTimestep: """ Overview: - During the step phase, the environment first converts the discrete action into a continuous action, - and then passes it into the original environment. + Take an action in the environment. During the step phase, the environment first converts the discrete action into a continuous action, + and then passes it into the original environment. Arguments: - - action (:obj:`np.ndarray`): Discrete action + - action (:obj:`np.ndarray`): Discrete action to be taken in the environment. Returns: - - BaseEnvTimestep (:obj:`tuple`): Including observation, reward, done, and info. + - BaseEnvTimestep (:obj:`BaseEnvTimestep`): A tuple containing observation, reward, done, and info. 
""" action = [-1 + 2 / self.n * k for k in self.disc_to_cont[int(action)]] action = to_ndarray(action) @@ -143,67 +156,26 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: if done: info['eval_episode_return'] = self._eval_episode_return if self._save_replay_gif: - print(self._replay_path) - if not os.path.exists(self._replay_path): - os.makedirs(self._replay_path) + if not os.path.exists(self._replay_path_gif): + os.makedirs(self._replay_path_gif) + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( - self._replay_path, '{}_episode_{}.gif'.format(self._env_name, self._save_replay_count) + self._replay_path_gif, + '{}_episode_{}_seed{}_{}.gif'.format(self._env_name, self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) + print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') self._save_replay_count += 1 obs = to_ndarray(obs) - rew = to_ndarray([rew]).astype(np.float32) # wrapped to be transferred to a array with shape (1,) + rew = to_ndarray([rew]).astype(np.float32) # wrapped to be transferred to an array with shape (1,) return BaseEnvTimestep(obs, rew, done, info) - @property - def legal_actions(self): - return np.arange(self._action_space.n) - - def enable_save_replay(self, replay_path: Optional[str] = None) -> None: - if replay_path is None: - replay_path = './video' - self._replay_path = replay_path - self._save_replay_gif = True - self._save_replay_count = 0 - - @staticmethod - def display_frames_as_gif(frames: list, path: str) -> None: - import imageio - imageio.mimsave(path, frames, fps=20) - - def random_action(self) -> np.ndarray: - random_action = self.action_space.sample() - if isinstance(random_action, np.ndarray): - pass - elif isinstance(random_action, int): - random_action = to_ndarray([random_action], dtype=np.int64) - return random_action - - @property - def observation_space(self) -> gym.spaces.Space: - return self._observation_space - - @property - def action_space(self) -> gym.spaces.Space: - return self._action_space - - @property - def reward_space(self) -> gym.spaces.Space: - return self._reward_space - def __repr__(self) -> str: - return "DI-engine LunarLander Env" - - @staticmethod - def create_collector_env_cfg(cfg: dict) -> List[dict]: - collector_env_num = cfg.pop('collector_env_num') - cfg = copy.deepcopy(cfg) - cfg.max_episode_steps = cfg.collect_max_episode_steps - return [cfg for _ in range(collector_env_num)] + """ + Overview: + Represent the environment instance as a string. + Returns: + - repr_str (:obj:`str`): Representation string of the environment instance. 
+ """ + return "LightZero LunarLander Env (with manually discretized action space)" - @staticmethod - def create_evaluator_env_cfg(cfg: dict) -> List[dict]: - evaluator_env_num = cfg.pop('evaluator_env_num') - cfg = copy.deepcopy(cfg) - cfg.max_episode_steps = cfg.eval_max_episode_steps - return [cfg for _ in range(evaluator_env_num)] diff --git a/zoo/box2d/lunarlander/envs/lunarlander_env.py b/zoo/box2d/lunarlander/envs/lunarlander_env.py index a6827d2b7..8a7662613 100755 --- a/zoo/box2d/lunarlander/envs/lunarlander_env.py +++ b/zoo/box2d/lunarlander/envs/lunarlander_env.py @@ -1,68 +1,101 @@ import copy import os -from typing import List, Optional +from datetime import datetime +from typing import List, Optional, Dict import gym import numpy as np -from ding.envs import BaseEnv, BaseEnvTimestep +from ding.envs import BaseEnvTimestep from ding.envs import ObsPlusPrevActRewWrapper from ding.envs.common import affine_transform from ding.torch_utils import to_ndarray from ding.utils import ENV_REGISTRY from easydict import EasyDict +from zoo.classic_control.cartpole.envs.cartpole_lightzero_env import CartPoleEnv + @ENV_REGISTRY.register('lunarlander') -class LunarLanderEnv(BaseEnv): +class LunarLanderEnv(CartPoleEnv): + """ + Overview: + The LunarLander Environment class for LightZero algo.. This class is a wrapper of the gym LunarLander environment, with additional + functionalities like replay saving and seed setting. The class is registered in ENV_REGISTRY with the key 'lunarlander'. + """ config = dict( + # (str) The gym environment name. env_name="LunarLander-v2", + # (bool) If True, save the replay as a gif file. save_replay_gif=False, + # (str or None) The path to save the replay gif. If None, the replay gif will not be saved. replay_path_gif=None, + # (str or None) The path to save the replay. If None, the replay will not be saved. replay_path=None, + # (bool) If True, the action will be scaled. act_scale=True, - delay_reward_step=0, - prob_random_agent=0., + # (int) The maximum number of steps for each episode during collection. collect_max_episode_steps=int(1.08e5), + # (int) The maximum number of steps for each episode during evaluation. eval_max_episode_steps=int(1.08e5), ) @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + Return the default configuration of the class. + Returns: + - cfg (:obj:`EasyDict`): Default configuration dict. + """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' return cfg def __init__(self, cfg: dict) -> None: + """ + Overview: + Initialize the LunarLander environment. + Arguments: + - cfg (:obj:`dict`): Configuration dict. The dict should include keys like 'env_name', 'replay_path', etc. + """ self._cfg = cfg self._init_flag = False # env_name options = {'LunarLander-v2', 'LunarLanderContinuous-v2'} self._env_name = cfg.env_name - self._replay_path = cfg.get('replay_path', None) - self._replay_path_gif = cfg.get('replay_path_gif', None) - self._save_replay_gif = cfg.get('save_replay_gif', False) + self._replay_path = cfg.replay_path + self._replay_path_gif = cfg.replay_path_gif + self._save_replay_gif = cfg.save_replay_gif self._save_replay_count = 0 if 'Continuous' in self._env_name: self._act_scale = cfg.act_scale # act_scale only works in continuous env else: self._act_scale = False - def reset(self) -> np.ndarray: + def reset(self) -> Dict[str, np.ndarray]: + """ + Overview: + Reset the environment and return the initial observation. 
+ Returns: + - obs (:obj:`np.ndarray`): The initial observation after resetting. + """ if not self._init_flag: self._env = gym.make(self._cfg.env_name) if self._replay_path is not None: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + video_name = f'{self._env.spec.id}-video-{timestamp}' self._env = gym.wrappers.RecordVideo( self._env, video_folder=self._replay_path, episode_trigger=lambda episode_id: True, - name_prefix='rl-video-{}'.format(id(self)) + name_prefix=video_name ) if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward: self._env = ObsPlusPrevActRewWrapper(self._env) self._observation_space = self._env.observation_space self._action_space = self._env.action_space self._reward_space = gym.spaces.Box( - low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 + low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1,), dtype=np.float32 ) self._init_flag = True if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: @@ -83,21 +116,16 @@ def reset(self) -> np.ndarray: obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} return obs - def close(self) -> None: - if self._init_flag: - self._env.close() - self._init_flag = False - - def render(self) -> None: - self._env.render() - - def seed(self, seed: int, dynamic_seed: bool = True) -> None: - self._seed = seed - self._dynamic_seed = dynamic_seed - np.random.seed(self._seed) - def step(self, action: np.ndarray) -> BaseEnvTimestep: - if action.shape == (1, ): + """ + Overview: + Take a step in the environment with the given action. + Arguments: + - action (:obj:`np.ndarray`): The action to be taken. + Returns: + - timestep (:obj:`BaseEnvTimestep`): The timestep information including observation, reward, done flag, and info. + """ + if action.shape == (1,): action = action.item() # 0-dim array if self._act_scale: action = affine_transform(action, min_val=-1, max_val=1) @@ -117,9 +145,10 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: if self._save_replay_gif: if not os.path.exists(self._replay_path_gif): os.makedirs(self._replay_path_gif) + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}.gif'.format(self._env_name, self._save_replay_count, self._seed) + '{}_episode_{}_seed{}_{}.gif'.format(self._env_name, self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') @@ -129,16 +158,15 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: return BaseEnvTimestep(obs, rew, done, info) @property - def legal_actions(self): + def legal_actions(self) -> np.ndarray: + """ + Overview: + Get the legal actions in the environment. + Returns: + - legal_actions (:obj:`np.ndarray`): An array of legal actions. 
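
All of the single-agent envs in this patch return observations in the same dict layout so that MuZero-style policies can consume them uniformly: the raw observation, an all-ones action mask (every action legal), and ``to_play = -1`` (no turn taking). A standalone sketch of that wrapping, assuming an older gym API where ``reset()`` returns only the observation, as these envs do:

import gym
import numpy as np

env = gym.make('LunarLander-v2')          # requires gym[box2d]
raw_obs = env.reset()
obs = {
    'observation': np.asarray(raw_obs, dtype=np.float32),
    'action_mask': np.ones(env.action_space.n, dtype=np.int8),   # all 4 discrete actions legal
    'to_play': -1,                                               # single-agent marker
}
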
+ """ return np.arange(self._action_space.n) - def enable_save_replay(self, replay_path: Optional[str] = None) -> None: - if replay_path is None: - replay_path = './video' - self._replay_path = replay_path - self._save_replay_gif = True - self._save_replay_count = 0 - @staticmethod def display_frames_as_gif(frames: list, path: str) -> None: import imageio @@ -152,23 +180,19 @@ def random_action(self) -> np.ndarray: random_action = to_ndarray([random_action], dtype=np.int64) return random_action - @property - def observation_space(self) -> gym.spaces.Space: - return self._observation_space - - @property - def action_space(self) -> gym.spaces.Space: - return self._action_space - - @property - def reward_space(self) -> gym.spaces.Space: - return self._reward_space - def __repr__(self) -> str: return "LightZero LunarLander Env." @staticmethod def create_collector_env_cfg(cfg: dict) -> List[dict]: + """ + Overview: + Create a list of environment configurations for the collector. + Arguments: + - cfg (:obj:`dict`): The base configuration dict. + Returns: + - cfgs (:obj:`List[dict]`): The list of environment configurations. + """ collector_env_num = cfg.pop('collector_env_num') cfg = copy.deepcopy(cfg) cfg.max_episode_steps = cfg.collect_max_episode_steps @@ -176,6 +200,14 @@ def create_collector_env_cfg(cfg: dict) -> List[dict]: @staticmethod def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + """ + Overview: + Create a list of environment configurations for the evaluator. + Arguments: + - cfg (:obj:`dict`): The base configuration dict. + Returns: + - cfgs (:obj:`List[dict]`): The list of environment configurations. + """ evaluator_env_num = cfg.pop('evaluator_env_num') cfg = copy.deepcopy(cfg) cfg.max_episode_steps = cfg.eval_max_episode_steps diff --git a/zoo/box2d/lunarlander/envs/test_lunarlander_env.py b/zoo/box2d/lunarlander/envs/test_lunarlander_env.py index c9e76c846..f932f1de2 100755 --- a/zoo/box2d/lunarlander/envs/test_lunarlander_env.py +++ b/zoo/box2d/lunarlander/envs/test_lunarlander_env.py @@ -9,11 +9,17 @@ 'cfg', [ EasyDict({ 'env_name': 'LunarLander-v2', - 'act_scale': False + 'act_scale': False, + 'replay_path': None, + 'replay_path_gif': None, + 'save_replay_gif': False, }), EasyDict({ 'env_name': 'LunarLanderContinuous-v2', - 'act_scale': True + 'act_scale': True, + 'replay_path': None, + 'replay_path_gif': None, + 'save_replay_gif': False, }) ] ) diff --git a/zoo/classic_control/cartpole/entry/cartpole_eval.py b/zoo/classic_control/cartpole/entry/cartpole_eval.py index 32cb54907..b4e3c554e 100644 --- a/zoo/classic_control/cartpole/entry/cartpole_eval.py +++ b/zoo/classic_control/cartpole/entry/cartpole_eval.py @@ -1,22 +1,39 @@ -from cartpole_muzero_config import main_config, create_config +from zoo.classic_control.cartpole.config.cartpole_muzero_config import main_config, create_config from lzero.entry import eval_muzero import numpy as np if __name__ == "__main__": - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - model_path = "./ckpt/ckpt_best.pth.tar" - seeds = [0] - num_episodes_each_seed = 5 - main_config.env.evaluator_env_num = 1 - main_config.env.n_evaluator_episode = 1 - total_test_episodes = num_episodes_each_seed * len(seeds) + Entry point for the evaluation of the MuZero model on the CartPole environment. 
+ + Variables: + - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the + pretrained model. An absolute path is recommended. In LightZero, the path is usually something like + ``exp_name/ckpt/ckpt_best.pth.tar``. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[float]`): List to store the returns for each seed. + - seeds (:obj:`List[int]`): List of seeds for the environment. + - num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. + - total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of + seeds and the number of episodes per seed. + """ + # model_path = "./ckpt/ckpt_best.pth.tar" + model_path = None returns_mean_seeds = [] returns_seeds = [] + seeds = [0] + num_episodes_each_seed = 2 + total_test_episodes = num_episodes_each_seed * len(seeds) + create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base + main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 + main_config.env.n_evaluator_episode = total_test_episodes + main_config.env.replay_path = './video' + for seed in seeds: + """ + - returns_mean (:obj:`float`): The mean return of the evaluation. + - returns (:obj:`List[float]`): The returns of the evaluation. + """ returns_mean, returns = eval_muzero( [main_config, create_config], seed=seed, @@ -30,8 +47,9 @@ returns_mean_seeds = np.array(returns_mean_seeds) returns_seeds = np.array(returns_seeds) + # Print evaluation results print("=" * 20) - print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.') - print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') - print('In all seeds, reward_mean:', returns_mean_seeds.mean()) - print("=" * 20) + print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).") + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) + print("=" * 20) \ No newline at end of file diff --git a/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py b/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py index bf958ce1c..52d337189 100644 --- a/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py +++ b/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py @@ -1,8 +1,8 @@ -from typing import Union, Optional +from datetime import datetime +from typing import Union, Optional, Dict import gym import numpy as np - from ding.envs import BaseEnv, BaseEnvTimestep from ding.envs import ObsPlusPrevActRewWrapper from ding.torch_utils import to_ndarray @@ -11,31 +11,46 @@ @ENV_REGISTRY.register('cartpole_lightzero') class CartPoleEnv(BaseEnv): + """ + LightZero version of the classic CartPole environment. This class includes methods for resetting, closing, and + stepping through the environment, as well as seeding for reproducibility, saving replay videos, and generating random + actions. It also includes properties for accessing the observation space, action space, and reward space of the + environment. + """ def __init__(self, cfg: dict = {}) -> None: + """ + Initialize the environment with a configuration dictionary. Sets up spaces for observations, actions, and rewards. 
+ """ self._cfg = cfg self._init_flag = False - self._replay_path = None + self._continuous = False + self._replay_path = cfg.replay_path self._observation_space = gym.spaces.Box( low=np.array([-4.8, float("-inf"), -0.42, float("-inf")]), high=np.array([4.8, float("inf"), 0.42, float("inf")]), - shape=(4, ), + shape=(4,), dtype=np.float32 ) self._action_space = gym.spaces.Discrete(2) self._action_space.seed(0) # default seed - self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, ), dtype=np.float32) - self._continuous = False + self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32) - def reset(self) -> np.ndarray: + def reset(self) -> Dict[str, np.ndarray]: + """ + Reset the environment. If it hasn't been initialized yet, this method also handles that. It also handles seeding + if necessary. Returns the first observation. + """ if not self._init_flag: self._env = gym.make('CartPole-v0') if self._replay_path is not None: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + video_name = f'{self._env.spec.id}-video-{timestamp}' self._env = gym.wrappers.RecordVideo( self._env, video_folder=self._replay_path, episode_trigger=lambda episode_id: True, - name_prefix='rl-video-{}'.format(id(self)) + name_prefix=video_name ) if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward: self._env = ObsPlusPrevActRewWrapper(self._env) @@ -57,18 +72,26 @@ def reset(self) -> np.ndarray: return obs - def close(self) -> None: - if self._init_flag: - self._env.close() - self._init_flag = False - - def seed(self, seed: int, dynamic_seed: bool = True) -> None: - self._seed = seed - self._dynamic_seed = dynamic_seed - np.random.seed(self._seed) - def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: - if isinstance(action, np.ndarray) and action.shape == (1, ): + """ + Overview: + Perform a step in the environment using the provided action, and return the next state of the environment. + The next state is encapsulated in a BaseEnvTimestep object, which includes the new observation, reward, + done flag, and info dictionary. + Arguments: + - action (:obj:`Union[int, np.ndarray]`): The action to be performed in the environment. If the action is + a 1-dimensional numpy array, it is squeezed to a 0-dimension array. + Returns: + - timestep (:obj:`BaseEnvTimestep`): An object containing the new observation, reward, done flag, + and info dictionary. + .. note:: + - The cumulative reward (`_eval_episode_return`) is updated with the reward obtained in this step. + - If the episode ends (done is True), the total reward for the episode is stored in the info dictionary + under the key 'eval_episode_return'. + - An action mask is created with ones, which represents the availability of each action in the action space. + - Observations are returned in a dictionary format containing 'observation', 'action_mask', and 'to_play'. + """ + if isinstance(action, np.ndarray) and action.shape == (1,): action = action.squeeze() # 0-dim array obs, rew, done, info = self._env.step(action) @@ -82,27 +105,61 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: return BaseEnvTimestep(obs, rew, done, info) + def close(self) -> None: + """ + Close the environment, and set the initialization flag to False. + """ + if self._init_flag: + self._env.close() + self._init_flag = False + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Set the seed for the environment's random number generator. 
Can handle both static and dynamic seeding. + """ + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: + """ + Enable the saving of replay videos. If no replay path is given, a default is used. + """ if replay_path is None: replay_path = './video' self._replay_path = replay_path def random_action(self) -> np.ndarray: + """ + Generate a random action using the action space's sample method. Returns a numpy array containing the action. + """ random_action = self.action_space.sample() random_action = to_ndarray([random_action], dtype=np.int64) return random_action @property def observation_space(self) -> gym.spaces.Space: + """ + Property to access the observation space of the environment. + """ return self._observation_space @property def action_space(self) -> gym.spaces.Space: + """ + Property to access the action space of the environment. + """ return self._action_space @property def reward_space(self) -> gym.spaces.Space: + """ + Property to access the reward space of the environment. + """ return self._reward_space def __repr__(self) -> str: + """ + String representation of the environment. + """ return "LightZero CartPole Env" diff --git a/zoo/classic_control/pendulum/entry/pendulum_eval.py b/zoo/classic_control/pendulum/entry/pendulum_eval.py index c8e80731f..94bf729a8 100644 --- a/zoo/classic_control/pendulum/entry/pendulum_eval.py +++ b/zoo/classic_control/pendulum/entry/pendulum_eval.py @@ -11,7 +11,7 @@ """ model_path = "./ckpt/ckpt_best.pth.tar" seeds = [0] - num_episodes_each_seed = 5 + num_episodes_each_seed = 1 main_config.env.evaluator_env_num = 1 main_config.env.n_evaluator_episode = 1 total_test_episodes = num_episodes_each_seed * len(seeds) @@ -31,8 +31,9 @@ returns_mean_seeds = np.array(returns_mean_seeds) returns_seeds = np.array(returns_seeds) + # Print evaluation results print("=" * 20) - print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.') - print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') - print('In all seeds, reward_mean:', returns_mean_seeds.mean()) + print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).") + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) print("=" * 20) diff --git a/zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py b/zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py index 204ddbc3b..1ca23fb04 100644 --- a/zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py +++ b/zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py @@ -1,17 +1,26 @@ import copy -from typing import Optional +from datetime import datetime +from typing import Union, Dict import gym import numpy as np -from ding.envs import BaseEnv, BaseEnvTimestep +from ding.envs import BaseEnvTimestep from ding.envs.common.common_function import affine_transform from ding.torch_utils import to_ndarray from ding.utils import ENV_REGISTRY from easydict import EasyDict +from zoo.classic_control.cartpole.envs.cartpole_lightzero_env import CartPoleEnv + @ENV_REGISTRY.register('pendulum_lightzero') -class PendulumEnv(BaseEnv): +class PendulumEnv(CartPoleEnv): + """ + LightZero version of the classic Pendulum environment. 
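
The distinction between static and dynamic seeding mentioned above can be illustrated without the env: with a static seed every reset reuses the base seed, while with dynamic seeding each reset draws a fresh episode seed from the numpy RNG that was itself seeded with the base value. This is only an illustration of the idea; the exact per-reset seeding lines live in context code not shown in this diff.

import numpy as np

base_seed, dynamic_seed = 0, True
np.random.seed(base_seed)     # what seed() above does

def episode_seed() -> int:
    # Dynamic: a fresh but reproducible seed per episode. Static: always the base seed.
    return base_seed + np.random.randint(0, 100) if dynamic_seed else base_seed

print([episode_seed() for _ in range(3)])   # e.g. [44, 47, 64]
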
This class includes methods for resetting, closing, and + stepping through the environment, as well as seeding for reproducibility, saving replay videos, and generating random + actions. It also includes properties for accessing the observation space, action space, and reward space of the + environment. + """ @classmethod def default_config(cls: type) -> EasyDict: @@ -20,18 +29,18 @@ def default_config(cls: type) -> EasyDict: return cfg config = dict( + # (bool) Whether to use continuous action space continuous=True, - save_replay_gif=False, - replay_path_gif=None, + # (str) The path to save replay videos replay_path=None, + # (bool) Whether to scale action into [-2, 2] act_scale=True, - delay_reward_step=0, - prob_random_agent=0., - collect_max_episode_steps=int(1.08e5), - eval_max_episode_steps=int(1.08e5), ) def __init__(self, cfg: dict) -> None: + """ + Initialize the environment with a configuration dictionary. Sets up spaces for observations, actions, and rewards. + """ self._cfg = cfg self._act_scale = cfg.act_scale try: @@ -39,33 +48,39 @@ def __init__(self, cfg: dict) -> None: except: self._env = gym.make('Pendulum-v0') self._init_flag = False - self._replay_path = None + self._replay_path = cfg.replay_path self._continuous = cfg.get("continuous", True) self._observation_space = gym.spaces.Box( - low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3, ), dtype=np.float32 + low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3,), dtype=np.float32 ) if self._continuous: - self._action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1, ), dtype=np.float32) + self._action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32) else: self.discrete_action_num = 11 self._action_space = gym.spaces.Discrete(self.discrete_action_num) self._action_space.seed(0) # default seed self._reward_space = gym.spaces.Box( - low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1, ), dtype=np.float32 + low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1,), dtype=np.float32 ) - def reset(self) -> np.ndarray: + def reset(self) -> Dict[str, np.ndarray]: + """ + Reset the environment. If it hasn't been initialized yet, this method also handles that. It also handles seeding + if necessary. Returns the first observation. + """ if not self._init_flag: try: self._env = gym.make('Pendulum-v1') except: self._env = gym.make('Pendulum-v0') if self._replay_path is not None: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + video_name = f'{self._env.spec.id}-video-{timestamp}' self._env = gym.wrappers.RecordVideo( self._env, video_folder=self._replay_path, episode_trigger=lambda episode_id: True, - name_prefix='rl-video-{}'.format(id(self)) + name_prefix=video_name ) self._init_flag = True if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: @@ -87,29 +102,38 @@ def reset(self) -> np.ndarray: return obs - def close(self) -> None: - if self._init_flag: - self._env.close() - self._init_flag = False - - def seed(self, seed: int, dynamic_seed: bool = True) -> None: - self._seed = seed - self._dynamic_seed = dynamic_seed - np.random.seed(self._seed) - - def step(self, action: np.ndarray) -> BaseEnvTimestep: + def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: + """ + Overview: + Step the environment forward with the provided action. 
This method returns the next state of the environment + (observation, reward, done flag, and info dictionary) encapsulated in a BaseEnvTimestep object. + Arguments: + - action (:obj:`Union[int, np.ndarray]`): The action to be performed in the environment. + Returns: + - timestep (:obj:`BaseEnvTimestep`): An object containing the new observation, reward, done flag, + and info dictionary. + + .. note:: + - If the environment requires discrete actions, they are converted to float actions in the range [-1, 1]. + - If action scaling is enabled, continuous actions are scaled into the range [-2, 2]. + - For each step, the cumulative reward (`_eval_episode_return`) is updated. + - If the episode ends (done is True), the total reward for the episode is stored in the info dictionary + under the key 'eval_episode_return'. + - If the environment requires discrete actions, an action mask is created, otherwise, it's None. + - Observations are returned in a dictionary format containing 'observation', 'action_mask', and 'to_play'. + """ if isinstance(action, int): action = np.array(action) # if require discrete env, convert actions to [-1 ~ 1] float actions if not self._continuous: action = (action / (self.discrete_action_num - 1)) * 2 - 1 - # scale into [-2, 2] + # scale the continous action into [-2, 2] if self._act_scale: action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high) obs, rew, done, info = self._env.step(action) self._eval_episode_return += rew obs = to_ndarray(obs).astype(np.float32) - # wrapped to be transferred to a array with shape (1,) + # wrapped to be transferred to an array with shape (1,) rew = to_ndarray([rew]).astype(np.float32) if done: @@ -123,12 +147,10 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: return BaseEnvTimestep(obs, rew, done, info) - def enable_save_replay(self, replay_path: Optional[str] = None) -> None: - if replay_path is None: - replay_path = './video' - self._replay_path = replay_path - def random_action(self) -> np.ndarray: + """ + Generate a random action using the action space's sample method. Returns a numpy array containing the action. + """ if self._continuous: random_action = self.action_space.sample().astype(np.float32) else: @@ -136,17 +158,8 @@ def random_action(self) -> np.ndarray: random_action = to_ndarray([random_action], dtype=np.int64) return random_action - @property - def observation_space(self) -> gym.spaces.Space: - return self._observation_space - - @property - def action_space(self) -> gym.spaces.Space: - return self._action_space - - @property - def reward_space(self) -> gym.spaces.Space: - return self._reward_space - def __repr__(self) -> str: + """ + String representation of the environment. + """ return "LightZero Pendulum Env({})".format(self._cfg.env_id) diff --git a/zoo/game_2048/entry/2048_eval.py b/zoo/game_2048/entry/2048_eval.py index 2a9bb6f58..d77407806 100644 --- a/zoo/game_2048/entry/2048_eval.py +++ b/zoo/game_2048/entry/2048_eval.py @@ -6,30 +6,41 @@ from zoo.game_2048.config.stochastic_muzero_2048_config import main_config, create_config if __name__ == "__main__": - """ - model_path (:obj:`Optional[str]`): The pretrained model path, which should - point to the ckpt file of the pretrained model, and an absolute path is recommended. - In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + """ + Entry point for the evaluation of the muzero or stochastic_muzero model on the 2048 environment. 
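
The two-stage action mapping described in the note above is small enough to spell out standalone: a discrete index in ``{0, ..., discrete_action_num - 1}`` is first mapped onto ``[-1, 1]`` and then affinely rescaled to the Pendulum torque range ``[-2, 2]``, which is the role ``affine_transform`` plays in ``step()``. Names below are ours:

discrete_action_num = 11
low, high = -2.0, 2.0           # Pendulum torque bounds

def discrete_to_torque(k: int) -> float:
    x = (k / (discrete_action_num - 1)) * 2 - 1     # index -> [-1, 1]
    return low + (x + 1) * (high - low) / 2         # [-1, 1] -> [low, high]

print([discrete_to_torque(k) for k in (0, 5, 10)])  # [-2.0, 0.0, 2.0]
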
+ + Variables: + - model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the + pretrained model. An absolute path is recommended. In LightZero, the path is usually something like + ``exp_name/ckpt/ckpt_best.pth.tar``. + - returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. + - returns_seeds (:obj:`List[float]`): List to store the returns for each seed. + - seeds (:obj:`List[int]`): List of seeds for the environment. + - num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. + - total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of + seeds and the number of episodes per seed. """ - model_path = "./ckpt/ckpt_best.pth.tar" + # model_path = './ckpt/ckpt_best.pth.tar' + model_path = None returns_mean_seeds = [] returns_seeds = [] seeds = [0] num_episodes_each_seed = 1 - total_test_episodes = num_episodes_each_seed * len(seeds) - create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base - main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 - main_config.env.n_evaluator_episode = total_test_episodes - main_config.env.save_replay = True # Whether to save the replay, if save the video render_mode_human must to be True + + # main_config.env.render_mode = 'image_realtime_mode' + main_config.env.render_mode = 'image_savefile_mode' + main_config.env.replay_path = './video' main_config.env.replay_format = 'gif' main_config.env.replay_name_suffix = 'muzero_ns100_s0' # main_config.env.replay_name_suffix = 'stochastic_muzero_ns100_s0' - main_config.env.replay_path = None main_config.env.max_episode_steps = int(1e9) # Adjust according to different environments - + total_test_episodes = num_episodes_each_seed * len(seeds) + create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base + main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 + main_config.env.n_evaluator_episode = total_test_episodes for seed in seeds: returns_mean, returns = eval_muzero( [main_config, create_config], diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index e62d3acb7..00a143ed7 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -1,5 +1,6 @@ import copy import logging +import os import sys from typing import List @@ -713,9 +714,11 @@ def draw_tile(self, draw, x, y, o, fnt): def save_render_output(self, replay_name_suffix: str = '', replay_path=None, format='gif'): # At the end of the episode, save the frames to a gif or mp4 file if replay_path is None: - filename = f'game_2048_{replay_name_suffix}.{format}' + filename = f'2048_{replay_name_suffix}.{format}' else: - filename = f'{replay_path}.{format}' + if not os.path.exists(replay_path): + os.makedirs(replay_path) + filename = replay_path+f'/2048_{replay_name_suffix}.{format}' if format == 'gif': imageio.mimsave(filename, self.frames, 'GIF')
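
The ``save_render_output`` change above amounts to: make sure the replay directory exists, build the file name from a suffix, and hand the collected frames to imageio. A standalone sketch using ``os.path.join`` (equivalent to the string concatenation in the patch, just a little safer about separators); the dummy frames stand in for ``self.frames``:

import os

import imageio
import numpy as np

def save_frames_as_gif(frames, replay_path='./video', suffix='muzero_ns100_s0'):
    os.makedirs(replay_path, exist_ok=True)
    filename = os.path.join(replay_path, f'2048_{suffix}.gif')
    imageio.mimsave(filename, frames, 'GIF')   # same imageio call as the env code
    return filename

dummy_frames = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(4)]
print(save_frames_as_gif(dummy_frames))
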