polish(pu): polish comments and render in tictactoe, gomoku, connect4, 2048
puyuan1996 committed Nov 29, 2023
1 parent 9e97184 commit d3aaccd
Showing 19 changed files with 532 additions and 200 deletions.
14 changes: 14 additions & 0 deletions lzero/policy/alphazero.py
@@ -36,6 +36,11 @@ class AlphaZeroPolicy(Policy):
# (int) The number of channels of hidden states in AlphaZero model.
num_channels=32,
),
# (bool) Whether to enable a sampled-based algorithm (e.g., Sampled EfficientZero).
# This variable is used in ``collector``.
sampled_algo=False,
# (bool) Whether to enable a Gumbel-based algorithm (e.g., Gumbel MuZero).
gumbel_algo=False,
# (bool) Whether to use multi-gpu training.
multi_gpu=False,
# (bool) Whether to use cuda for network.
@@ -350,6 +355,15 @@ def _get_simulation_env(self):
else:
raise NotImplementedError
self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
elif self._cfg.simulation_env_name == 'connect4':
from zoo.board_games.connect4.envs.connect4_env import Connect4Env
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.connect4.config.connect4_alphazero_bot_mode_config import connect4_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.connect4.config.connect4_alphazero_sp_mode_config import connect4_alphazero_config
else:
raise NotImplementedError
self.simulate_env = Connect4Env(connect4_alphazero_config.env)
else:
raise NotImplementedError
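For context, the new Connect4 branch above is selected entirely by two policy-config fields. A minimal sketch of the relevant config fragment (field names are taken from this diff; the EasyDict wrapper is an assumption for illustration):

from easydict import EasyDict

# Hypothetical policy-config fragment: these two fields drive the dispatch
# in ``_get_simulation_env`` shown above.
policy_cfg = EasyDict(dict(
    simulation_env_name='connect4',              # selects the Connect4Env branch
    simulation_env_config_type='play_with_bot',  # imports the bot-mode config
))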

36 changes: 26 additions & 10 deletions zoo/atari/entry/atari_eval.py
@@ -16,27 +16,41 @@
- returns_mean_seeds (:obj:`np.array`): Array of mean return values for each seed.
- returns_seeds (:obj:`np.array`): Array of all return values for each seed.
"""
# Take the MuZero config as an example: import the necessary configuration
# from the Atari MuZero config file in the zoo directory.
from zoo.atari.config.atari_muzero_config import main_config, create_config

# model_path = "/path/ckpt/ckpt_best.pth.tar"
# model_path is the path to the trained MuZero model checkpoint.
# If no path is provided, the script will use the default model.
model_path = None

# seeds is a list of seed values for the random number generator, used to initialize the environment.
seeds = [0]
# num_episodes_each_seed is the number of episodes to run for each seed.
num_episodes_each_seed = 1
# total_test_episodes is the total number of test episodes, calculated as the product of the number of seeds and the number of episodes per seed.
total_test_episodes = num_episodes_each_seed * len(seeds)
create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base
main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1

# Setting the type of the environment manager to 'base' for the visualization purposes.
create_config.env_manager.type = 'base'
# The number of environments to evaluate concurrently. Set to 1 for visualization purposes.
main_config.env.evaluator_env_num = 1
# The total number of evaluation episodes that should be run.
main_config.env.n_evaluator_episode = total_test_episodes
main_config.env.render_mode_human = False # Whether to enable real-time rendering
# A boolean flag indicating whether to render the environments in real-time.
main_config.env.render_mode_human = False

main_config.env.save_replay = True # Whether to save the video
# A boolean flag indicating whether to save the video of the environment.
main_config.env.save_replay = True
# The path where the recorded video will be saved.
main_config.env.save_path = './video'
main_config.env.eval_max_episode_steps = int(20) # Adjust according to different environments
# The maximum number of steps for each episode during evaluation. This may need to be adjusted based on the specific characteristics of the environment.
main_config.env.eval_max_episode_steps = int(20)

# These lists will store the mean and total rewards for each seed.
returns_mean_seeds = []
returns_seeds = []

# The main evaluation loop. For each seed, the MuZero model is evaluated and the mean and total rewards are recorded.
for seed in seeds:
returns_mean, returns = eval_muzero(
[main_config, create_config],
@@ -49,11 +63,13 @@
returns_mean_seeds.append(returns_mean)
returns_seeds.append(returns)

# Convert the list of mean and total rewards into numpy arrays for easier statistical analysis.
returns_mean_seeds = np.array(returns_mean_seeds)
returns_seeds = np.array(returns_seeds)

# Printing the evaluation results. The average reward and the total reward for each seed are displayed, followed by the mean reward across all seeds.
print("=" * 20)
print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.')
print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}')
print('In all seeds, reward_mean:', returns_mean_seeds.mean())
print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).")
print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.")
print("Across all seeds, the mean reward is:", returns_mean_seeds.mean())
print("=" * 20)
zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py
@@ -11,6 +11,8 @@
batch_size = 256
max_env_step = int(1e6)
model_path = None
mcts_ctree = False

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
@@ -19,12 +21,29 @@
env=dict(
battle_mode='play_with_bot_mode',
bot_action_type='rule',
channel_last=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
# ==============================================================
# for the creation of simulation env
agent_vs_human=False,
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
screen_scaling=9,
render_mode=None,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
policy=dict(
mcts_ctree=mcts_ctree,
# ==============================================================
# for the creation of simulation env
simulation_env_name='connect4',
simulation_env_config_type='play_with_bot',
# ==============================================================
model=dict(
observation_shape=(3, 6, 7),
action_space_size=7,
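Note that the mcts_ctree flag defined at the top of this config is threaded into both the env (alphazero_mcts_ctree) and the policy (mcts_ctree), so the two sides always agree on the MCTS tree backend. A minimal sketch of switching backends (assumption: the compiled ctree extension is available in the installation):

# Hypothetical toggle: True selects the compiled (ctree) MCTS implementation,
# False the pure-Python one; env and policy both read this single flag.
mcts_ctree = True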
10 changes: 7 additions & 3 deletions zoo/board_games/connect4/envs/connect4_env.py
@@ -62,6 +62,8 @@ class Connect4Env(BaseEnv):
# (str or None) The render mode. Options are None, 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'.
# If None, the game will not be rendered.
render_mode=None,
# (str) The directory in which to save the replay video.
replay_path='./video',
# (str) The type of the bot of the environment.
bot_action_type='rule',
# (bool) Whether to let a human play against the agent during evaluation. If False, the bot is used to evaluate the agent.
@@ -101,7 +103,7 @@ def __init__(self, cfg: dict = None) -> None:
# options = {None, 'state_realtime_mode', 'image_realtime_mode', 'image_savefile_mode'}
self.render_mode = cfg.render_mode
self.replay_name_suffix = "test"
self.replay_path = None
self.replay_path = cfg.replay_path
self.replay_format = 'gif'
self.screen = None
self.frames = []
@@ -473,9 +475,11 @@ def save_render_output(self, replay_name_suffix: str = '', replay_path: str = None
"""
# At the end of the episode, save the frames.
if replay_path is None:
filename = f'game_connect4_{replay_name_suffix}.{format}'
filename = f'connect4_{replay_name_suffix}.{format}'
else:
filename = f'{replay_path}.{format}'
if not os.path.exists(replay_path):
os.makedirs(replay_path)
filename = os.path.join(replay_path, f'connect4_{replay_name_suffix}.{format}')

if format == 'gif':
# Save frames as a GIF with a duration of 0.1 seconds per frame.
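For context, a minimal sketch of what the new save-path logic produces (assumption: Connect4Env exposes the usual DI-engine default_config() helper; the values follow the defaults in this diff):

from easydict import EasyDict
from zoo.board_games.connect4.envs.connect4_env import Connect4Env

cfg = EasyDict(Connect4Env.default_config())
cfg.render_mode = 'image_savefile_mode'  # buffer frames during the episode
cfg.replay_path = './video'              # directory is created on demand
env = Connect4Env(cfg)
# After an episode, save_render_output writes ./video/connect4_test.gif,
# since replay_name_suffix defaults to "test" and replay_format to 'gif'.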
31 changes: 22 additions & 9 deletions zoo/board_games/connect4/eval/connect4_alphazero_eval.py
@@ -4,18 +4,31 @@
from zoo.board_games.connect4.config.connect4_alphazero_bot_mode_config import main_config, create_config

if __name__ == '__main__':
"""
model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
"""
Entry point for the evaluation of the AlphaZero model on the Connect4 environment.
Variables:
- model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the
pretrained model. An absolute path is recommended. In LightZero, the path is usually something like
``exp_name/ckpt/ckpt_best.pth.tar``.
- returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed.
- returns_seeds (:obj:`List[float]`): List to store the returns for each seed.
- seeds (:obj:`List[int]`): List of seeds for the environment.
- num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed.
- total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of
seeds and the number of episodes per seed.
"""
# model_path = './ckpt/ckpt_best.pth.tar'
model_path = None
seeds = [0]
num_episodes_each_seed = 1
# If True, you can play with the agent.
# main_config.env.agent_vs_human = True
main_config.env.agent_vs_human = False
# main_config.env.render_mode = 'image_savefile_mode'
main_config.env.render_mode = 'state_realtime_mode'
# main_config.env.render_mode = 'image_realtime_mode'
main_config.env.render_mode = 'image_savefile_mode'
main_config.env.replay_path = './video'

main_config.policy.mcts.num_simulations = 10
main_config.env.prob_random_action_in_bot = 0.
main_config.env.bot_action_type = 'rule'
@@ -40,9 +53,9 @@
returns_seeds = np.array(returns_seeds)

print("=" * 20)
print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.')
print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}')
print('In all seeds, reward_mean:', returns_mean_seeds.mean(), end='. ')
print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).")
print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.")
print("Across all seeds, the mean reward is:", returns_mean_seeds.mean())
print(
f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}'
)
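The commented-out alternatives above cover the render modes the board-game envs accept. A rough summary (mode names come from connect4_env.py; the per-mode behavior is an assumption inferred from the names and surrounding comments):

# render_mode = None                    -> no rendering (fastest, for batch evaluation)
# render_mode = 'state_realtime_mode'   -> print the board state as text in real time
# render_mode = 'image_realtime_mode'   -> show rendered frames live (needs a display)
# render_mode = 'image_savefile_mode'   -> buffer frames, save a replay when the episode ends
main_config.env.render_mode = 'image_savefile_mode'
main_config.env.replay_path = './video'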
60 changes: 60 additions & 0 deletions zoo/board_games/connect4/eval/connect4_muzero_eval.py
@@ -0,0 +1,60 @@
from zoo.board_games.connect4.config.connect4_muzero_bot_mode_config import main_config, create_config
from lzero.entry import eval_muzero
import numpy as np

if __name__ == '__main__':
"""
Entry point for the evaluation of the MuZero model on the Connect4 environment.
Variables:
- model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the
pretrained model. An absolute path is recommended. In LightZero, the path is usually something like
``exp_name/ckpt/ckpt_best.pth.tar``.
- returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed.
- returns_seeds (:obj:`List[float]`): List to store the returns for each seed.
- seeds (:obj:`List[int]`): List of seeds for the environment.
- num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed.
- total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of
seeds and the number of episodes per seed.
"""
# model_path = './ckpt/ckpt_best.pth.tar'
model_path = None
seeds = [0]
num_episodes_each_seed = 1
# If True, you can play with the agent.
# main_config.env.agent_vs_human = True
main_config.env.agent_vs_human = False
# main_config.env.render_mode = 'image_realtime_mode'
main_config.env.render_mode = 'image_savefile_mode'
main_config.env.replay_path = './video'

main_config.env.prob_random_action_in_bot = 0.
main_config.env.bot_action_type = 'rule'
create_config.env_manager.type = 'base'
main_config.env.evaluator_env_num = 1
main_config.env.n_evaluator_episode = 1
total_test_episodes = num_episodes_each_seed * len(seeds)
returns_mean_seeds = []
returns_seeds = []
for seed in seeds:
returns_mean, returns = eval_muzero(
[main_config, create_config],
seed=seed,
num_episodes_each_seed=num_episodes_each_seed,
print_seed_details=True,
model_path=model_path
)
returns_mean_seeds.append(returns_mean)
returns_seeds.append(returns)

returns_mean_seeds = np.array(returns_mean_seeds)
returns_seeds = np.array(returns_seeds)

print("=" * 20)
print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).")
print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.")
print("Across all seeds, the mean reward is:", returns_mean_seeds.mean())
print(
f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}'
)
print("=" * 20)
39 changes: 28 additions & 11 deletions zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py
@@ -3,18 +3,35 @@
import numpy as np

if __name__ == '__main__':
"""
model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
"""
model_path = './ckpt/ckpt_best.pth.tar'
Entry point for the evaluation of the AlphaZero model on the Gomoku environment.
Variables:
- model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the
pretrained model. An absolute path is recommended. In LightZero, the path is usually something like
``exp_name/ckpt/ckpt_best.pth.tar``.
- returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed.
- returns_seeds (:obj:`List[float]`): List to store the returns for each seed.
- seeds (:obj:`List[int]`): List of seeds for the environment.
- num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed.
- total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of
seeds and the number of episodes per seed.
"""
# model_path = './ckpt/ckpt_best.pth.tar'
model_path = None

seeds = [0]
num_episodes_each_seed = 5
num_episodes_each_seed = 1
# If True, you can play with the agent.
main_config.env.agent_vs_human = True
main_config.env.render_mode = 'image_realtime_mode'
# main_config.env.agent_vs_human = True
main_config.env.agent_vs_human = False
# main_config.env.render_mode = 'image_realtime_mode'
main_config.env.render_mode = 'image_savefile_mode'
main_config.env.replay_path = './video'

create_config.env_manager.type = 'base'
main_config.env.alphazero_mcts_ctree = False
main_config.policy.mcts_ctree = False
main_config.env.evaluator_env_num = 1
main_config.env.n_evaluator_episode = 1
total_test_episodes = num_episodes_each_seed * len(seeds)
@@ -35,9 +52,9 @@
returns_seeds = np.array(returns_seeds)

print("=" * 20)
print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.')
print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}')
print('In all seeds, reward_mean:', returns_mean_seeds.mean(), end='. ')
print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).")
print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.")
print("Across all seeds, the mean reward is:", returns_mean_seeds.mean())
print(
f'win rate: {len(np.where(returns_seeds == 1.)[0]) / total_test_episodes}, draw rate: {len(np.where(returns_seeds == 0.)[0]) / total_test_episodes}, lose rate: {len(np.where(returns_seeds == -1.)[0]) / total_test_episodes}'
)
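To play against the trained Gomoku agent instead of running a headless evaluation, re-enable the two commented-out settings shown in the diff above (assumption: real-time image rendering requires a local display):

main_config.env.agent_vs_human = True                # a human takes one side of the board
main_config.env.render_mode = 'image_realtime_mode'  # render the board live while playing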
