diff --git a/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
index e9a6cc628..29bb30a21 100644
--- a/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
+++ b/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
@@ -380,12 +380,12 @@ namespace tree
             // After sorting, the first vector is the index, and the second vector is the probability value after perturbation sorted from large to small.
             for (size_t iter = 0; iter < disturbed_probs.size(); iter++)
             {
-                #ifdef __APPLE__
-                disc_action_with_probs.__emplace_back(std::make_pair(iter, disturbed_probs[iter]));
-                #else
-                disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
-                #endif
-                // disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
+//                #ifdef __APPLE__
+//                disc_action_with_probs.__emplace_back(std::make_pair(iter, disturbed_probs[iter]));
+//                #else
+//                disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
+//                #endif
+                disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
             }
 
             std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp);
diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py
index a9753b91f..3967eb92f 100644
--- a/lzero/policy/alphazero.py
+++ b/lzero/policy/alphazero.py
@@ -246,18 +246,20 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
         self._collect_mcts_temperature = temperature
         ready_env_id = list(obs.keys())
         init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
+        # If 'katago_game_state' is present in the observation of the given environment ID, its value is used.
+        # If it is not present, None is used instead (dict.get returns the default rather than raising a KeyError).
+        # This approach is taken to maintain compatibility with the handling of the 'katago'-related parts of 'alphazero_mcts_ctree' in Go.
+        katago_game_state = {env_id: obs[env_id].get('katago_game_state', None) for env_id in ready_env_id}
         start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
         output = {}
         self._policy_model = self._collect_model
         for env_id in ready_env_id:
             state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
-                                                                  init_state=init_state[env_id], ))
-            action, mcts_probs = self._collect_mcts.get_next_action(
-                state_config_for_simulation_env_reset,
-                policy_forward_fn=self._policy_value_fn,
-                temperature=self._collect_mcts_temperature,
-                sample=True
-            )
+                                                                  init_state=init_state[env_id],
+                                                                  katago_policy_init=False,
+                                                                  katago_game_state=katago_game_state[env_id]))
+            action, mcts_probs = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self._collect_mcts_temperature, True)
+
             output[env_id] = {
                 'action': action,
                 'probs': mcts_probs,
@@ -305,15 +307,20 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
         """
         ready_env_id = list(obs.keys())
         init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
+        # If 'katago_game_state' is present in the observation of the given environment ID, its value is used.
+        # If it is not present, None is used instead (dict.get returns the default rather than raising a KeyError).
+        # This approach is taken to maintain compatibility with the handling of the 'katago'-related parts of 'alphazero_mcts_ctree' in Go.
+ katago_game_state = {env_id: obs[env_id].get('katago_game_state', None) for env_id in ready_env_id} start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id} output = {} self._policy_model = self._eval_model for env_id in ready_env_id: state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id], - init_state=init_state[env_id], )) + init_state=init_state[env_id], + katago_policy_init=False, + katago_game_state=katago_game_state[env_id])) action, mcts_probs = self._eval_mcts.get_next_action( - state_config_for_simulation_env_reset, policy_forward_fn=self._policy_value_fn, temperature=1.0, - sample=False + state_config_for_simulation_env_reset, self._policy_value_fn, 1.0, False ) output[env_id] = { 'action': action, diff --git a/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py index 6d65124b1..2cd6d710f 100644 --- a/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py @@ -18,7 +18,7 @@ # ============================================================== gomoku_alphazero_config = dict( exp_name= - f'data_az_ptree/gomoku_alphazero_bot-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0', + f'data_az_ctree/gomoku_alphazero_bot-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0', env=dict( board_size=board_size, battle_mode='play_with_bot_mode', @@ -35,10 +35,9 @@ prob_random_agent=0, prob_expert_agent=0, scale=True, - check_action_to_connect4_in_bot_v0=False, - mcts_ctree=mcts_ctree, screen_scaling=9, render_mode=None, + alphazero_mcts_ctree=mcts_ctree, # ============================================================== ), policy=dict( diff --git a/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py b/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py index 22eb87564..43cc766e0 100644 --- a/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py @@ -12,18 +12,19 @@ batch_size = 256 max_env_step = int(5e5) prob_random_action_in_bot = 0.5 +mcts_ctree = True # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== gomoku_alphazero_config = dict( exp_name= - f'data_az_ptree/gomoku_alphazero_sp-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0', + f'data_az_ctree/gomoku_alphazero_sp-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0', env=dict( board_size=board_size, battle_mode='self_play_mode', - bot_action_type='v0', + bot_action_type='v1', prob_random_action_in_bot=prob_random_action_in_bot, - channel_last=False, # NOTE + channel_last=False, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -34,10 +35,13 @@ prob_random_agent=0, prob_expert_agent=0, scale=True, - check_action_to_connect4_in_bot_v0=False, + screen_scaling=9, + render_mode=None, + alphazero_mcts_ctree=mcts_ctree, # ============================================================== ), policy=dict( + mcts_ctree=mcts_ctree, # ============================================================== # for the creation of simulation env simulation_env_name='gomoku', 
diff --git a/zoo/board_games/gomoku/envs/gomoku_env.py b/zoo/board_games/gomoku/envs/gomoku_env.py
index aa4d4bd29..a20f16654 100644
--- a/zoo/board_games/gomoku/envs/gomoku_env.py
+++ b/zoo/board_games/gomoku/envs/gomoku_env.py
@@ -68,10 +68,10 @@ class GomokuEnv(BaseEnv):
         prob_random_agent=0,
         # (float) The probability that a random action will be taken when calling the bot.
         prob_random_action_in_bot=0.,
-        # (bool) Whether to check the action to connect 4 in the bot v0.
-        check_action_to_connect4_in_bot_v0=False,
         # (float) The stop value when training the agent. If the evalue return reach the stop value, then the training will stop.
         stop_value=2,
+        # (bool) Whether to use the MCTS ctree in AlphaZero. If True, then the AlphaZero MCTS ctree will be used.
+        alphazero_mcts_ctree=False,
     )
 
     @classmethod
@@ -117,7 +117,6 @@ def __init__(self, cfg: dict = None):
         self.prob_random_action_in_bot = cfg.prob_random_action_in_bot
         self.channel_last = cfg.channel_last
         self.scale = cfg.scale
-        self.check_action_to_connect4_in_bot_v0 = cfg.check_action_to_connect4_in_bot_v0
         self.agent_vs_human = cfg.agent_vs_human
         self.bot_action_type = cfg.bot_action_type
 
@@ -142,11 +141,30 @@ def __init__(self, cfg: dict = None):
             self.alpha_beta_pruning_player = AlphaBetaPruningBot(self, cfg, 'alpha_beta_pruning_player')
         elif self.bot_action_type == 'v0':
             self.rule_bot = GomokuRuleBotV0(self, self._current_player)
+        self.alphazero_mcts_ctree = cfg.alphazero_mcts_ctree
+        if not self.alphazero_mcts_ctree:
+            # plt does not work in the mcts_ctree mode
+            self.fig, self.ax = plt.subplots(figsize=(self.board_size, self.board_size))
+            plt.ion()
 
-        self.fig, self.ax = plt.subplots(figsize=(self.board_size, self.board_size))
-        plt.ion()
+    def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, katago_game_state=None):
+        """
+        Overview:
+            This method resets the environment and optionally starts with a custom state specified by 'init_state'.
+        Arguments:
+            - start_player_index (:obj:`int`, optional): Specifies the starting player. The players are [1,2] and
+                their corresponding indices are [0,1]. Defaults to 0.
+            - init_state (:obj:`Any`, optional): The custom starting state. If provided, the game starts from this state.
+                Defaults to None.
+            - katago_policy_init (:obj:`bool`, optional): This parameter is used to maintain compatibility with the
+                handling of 'katago' related parts in 'alphazero_mcts_ctree' in Go. Defaults to False.
+            - katago_game_state (:obj:`Any`, optional): This parameter is similar to 'katago_policy_init' and is used to
+                maintain compatibility with 'katago' in 'alphazero_mcts_ctree'. Defaults to None.
+ """ + if self.alphazero_mcts_ctree and init_state is not None: + # Convert byte string to np.ndarray + init_state = np.frombuffer(init_state, dtype=np.int32) - def reset(self, start_player_index=0, init_state=None): self._observation_space = gym.spaces.Box( low=0, high=2, shape=(self.board_size, self.board_size, 3), dtype=np.int32 ) @@ -156,6 +174,8 @@ def reset(self, start_player_index=0, init_state=None): self._current_player = self.players[self.start_player_index] if init_state is not None: self.board = np.array(copy.deepcopy(init_state), dtype="int32") + if self.alphazero_mcts_ctree: + self.board = self.board.reshape((self.board_size, self.board_size)) else: self.board = np.zeros((self.board_size, self.board_size), dtype="int32") action_mask = np.zeros(self.total_num_actions, 'int8') diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py index 3fd9c2bae..42c9cb026 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py @@ -16,12 +16,12 @@ # ============================================================== tictactoe_alphazero_config = dict( - exp_name=f'data_az_ptree/tictactoe_alphazero_bot-mode_ns{num_simulations}_upc{update_per_collect}_seed0', + exp_name=f'data_az_ctree/tictactoe_alphazero_bot-mode_ns{num_simulations}_upc{update_per_collect}_seed0', env=dict( board_size=3, battle_mode='play_with_bot_mode', bot_action_type='v0', # {'v0', 'alpha_beta_pruning'} - channel_last=False, # NOTE + channel_last=False, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -32,16 +32,16 @@ prob_random_agent=0, prob_expert_agent=0, scale=True, - mcts_ctree=mcts_ctree, + alphazero_mcts_ctree=mcts_ctree, # ============================================================== ), policy=dict( + mcts_ctree=mcts_ctree, # ============================================================== # for the creation of simulation env simulation_env_name='tictactoe', simulation_env_config_type='play_with_bot', # ============================================================== - mcts_ctree=mcts_ctree, model=dict( observation_shape=(3, 3, 3), action_space_size=int(1 * 3 * 3), diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py index 07a04943d..5f9ed1345 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py @@ -10,16 +10,17 @@ update_per_collect = 50 batch_size = 256 max_env_step = int(2e5) +mcts_ctree = True # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== tictactoe_alphazero_config = dict( - exp_name='data_az_ptree/tictactoe_sp-mode_alphazero_seed0', + exp_name='data_az_ctree/tictactoe_sp-mode_alphazero_seed0', env=dict( board_size=3, battle_mode='self_play_mode', bot_action_type='v0', # {'v0', 'alpha_beta_pruning'} - channel_last=False, # NOTE + channel_last=False, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -30,9 +31,11 @@ prob_random_agent=0, prob_expert_agent=0, scale=True, + alphazero_mcts_ctree=mcts_ctree, # 
==============================================================
     ),
     policy=dict(
+        mcts_ctree=mcts_ctree,
         # ==============================================================
         # for the creation of simulation env
         simulation_env_name='tictactoe',
diff --git a/zoo/board_games/tictactoe/envs/tictactoe_env.py b/zoo/board_games/tictactoe/envs/tictactoe_env.py
index 9c0ec3633..df30e92dd 100644
--- a/zoo/board_games/tictactoe/envs/tictactoe_env.py
+++ b/zoo/board_games/tictactoe/envs/tictactoe_env.py
@@ -46,6 +46,7 @@ class TicTacToeEnv(BaseEnv):
         channel_last=True,
         scale=True,
         stop_value=1,
+        alphazero_mcts_ctree=False,
     )
 
     @classmethod
@@ -76,6 +77,7 @@ def __init__(self, cfg=None):
         self.bot_action_type = cfg.bot_action_type
         if 'alpha_beta_pruning' in self.bot_action_type:
             self.alpha_beta_pruning_player = AlphaBetaPruningBot(self, cfg, 'alpha_beta_pruning_player')
+        self.alphazero_mcts_ctree = cfg.alphazero_mcts_ctree
 
     @property
     def legal_actions(self):
@@ -94,17 +96,37 @@ def legal_actions_cython_lru(self):
         return _legal_actions_func_lru(tuple(map(tuple, self.board)))
 
     def get_done_winner(self):
+        """
+        Overview:
+            Check if the game is over and who the winner is. Return 'done' and 'winner'.
+        Returns:
+            - outputs (:obj:`Tuple`): Tuple containing 'done' and 'winner',
+            - if player 1 wins, 'done' = True, 'winner' = 1
+            - if player 2 wins, 'done' = True, 'winner' = 2
+            - if the game is a draw, 'done' = True, 'winner' = -1
+            - if the game is not over, 'done' = False, 'winner' = -1
+        """
         # Convert NumPy arrays to nested tuples to make them hashable.
         return _get_done_winner_func_lru(tuple(map(tuple, self.board)))
 
-    def reset(self, start_player_index=0, init_state=None):
+    def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, katago_game_state=None):
         """
         Overview:
-            Env reset and custom state start by init_state
+            This method resets the environment and optionally starts with a custom state specified by 'init_state'.
         Arguments:
-            - start_player_index: players = [1,2], player_index = [0,1]
-            - init_state: custom start state.
+            - start_player_index (:obj:`int`, optional): Specifies the starting player. The players are [1,2] and
+                their corresponding indices are [0,1]. Defaults to 0.
+            - init_state (:obj:`Any`, optional): The custom starting state. If provided, the game starts from this state.
+                Defaults to None.
+            - katago_policy_init (:obj:`bool`, optional): This parameter is used to maintain compatibility with the
+                handling of 'katago' related parts in 'alphazero_mcts_ctree' in Go. Defaults to False.
+            - katago_game_state (:obj:`Any`, optional): This parameter is similar to 'katago_policy_init' and is used to
+                maintain compatibility with 'katago' in 'alphazero_mcts_ctree'. Defaults to None.
""" + if self.alphazero_mcts_ctree and init_state is not None: + # Convert byte string to np.ndarray + init_state = np.frombuffer(init_state, dtype=np.int32) + if self.scale: self._observation_space = gym.spaces.Box( low=0, high=1, shape=(self.board_size, self.board_size, 3), dtype=np.float32 @@ -119,6 +141,8 @@ def reset(self, start_player_index=0, init_state=None): self._current_player = self.players[self.start_player_index] if init_state is not None: self.board = np.array(copy.deepcopy(init_state), dtype="int32") + if self.alphazero_mcts_ctree: + self.board = self.board.reshape((self.board_size, self.board_size)) else: self.board = np.zeros((self.board_size, self.board_size), dtype="int32") @@ -247,11 +271,11 @@ def _player_step(self, action): done, winner = self.get_done_winner() reward = np.array(float(winner == self.current_player)).astype(np.float32) - info = {'next player to play': self.to_play} + info = {'next player to play': self.next_player} """ NOTE: here exchange the player """ - self.current_player = self.to_play + self.current_player = self.next_player if done: info['eval_episode_return'] = reward @@ -277,11 +301,11 @@ def current_state(self): Returns: - current_state (:obj:`array`): the 0 dim means which positions is occupied by self.current_player, - the 1 dim indicates which positions are occupied by self.to_play, + the 1 dim indicates which positions are occupied by self.next_player, the 2 dim indicates which player is the to_play player, 1 means player 1, 2 means player 2 """ board_curr_player = np.where(self.board == self.current_player, 1, 0) - board_opponent_player = np.where(self.board == self.to_play, 1, 0) + board_opponent_player = np.where(self.board == self.next_player, 1, 0) board_to_play = np.full((self.board_size, self.board_size), self.current_player) raw_obs = np.array([board_curr_player, board_opponent_player, board_to_play], dtype=np.float32) if self.scale: @@ -413,13 +437,14 @@ def current_player(self): @property def current_player_index(self): """ - current_player_index = 0, current_player = 1 - current_player_index = 1, current_player = 2 + Overview: + current_player_index = 0, current_player = 1 + current_player_index = 1, current_player = 2 """ return 0 if self._current_player == 1 else 1 @property - def to_play(self): + def next_player(self): return self.players[0] if self.current_player == self.players[1] else self.players[1] @property