fix(pu): fix tictactoe and gomoku env to be compatible with alphazero ctree (#148)
puyuan1996 authored Nov 26, 2023
1 parent cfe3c3c commit e2531ce
Showing 8 changed files with 104 additions and 46 deletions.
12 changes: 6 additions & 6 deletions lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
@@ -380,12 +380,12 @@ namespace tree
// After sorting, each pair holds (action index, perturbed probability), ordered from largest to smallest probability.
for (size_t iter = 0; iter < disturbed_probs.size(); iter++)
{
#ifdef __APPLE__
disc_action_with_probs.__emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#else
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#endif
// disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
// #ifdef __APPLE__
// disc_action_with_probs.__emplace_back(std::make_pair(iter, disturbed_probs[iter]));
// #else
// disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
// #endif
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
}

std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp);
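The removed #ifdef __APPLE__ branch called the nonstandard __emplace_back, which the plain, portable emplace_back now replaces on every platform. In effect, the loop plus the std::sort(..., cmp) below pair each action index with its perturbed probability and order the pairs from largest to smallest probability. A rough Python sketch of that logic (names are illustrative, not from this commit):

def sort_actions_by_disturbed_probs(disturbed_probs):
    # Pair each action index with its perturbed probability, then sort the
    # pairs by probability from largest to smallest, mirroring the C++
    # emplace_back loop followed by std::sort(..., cmp).
    pairs = list(enumerate(disturbed_probs))
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs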
27 changes: 17 additions & 10 deletions lzero/policy/alphazero.py
@@ -246,18 +246,20 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
self._collect_mcts_temperature = temperature
ready_env_id = list(obs.keys())
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
# If 'katago_game_state' is present in the observation of the given environment ID, its value is used.
# If it is not present, None is used instead (dict.get avoids the KeyError that direct indexing would raise).
# This maintains compatibility with the handling of the 'katago'-related parts of 'alphazero_mcts_ctree' in Go.
katago_game_state = {env_id: obs[env_id].get('katago_game_state', None) for env_id in ready_env_id}
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
output = {}
self._policy_model = self._collect_model
for env_id in ready_env_id:
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
init_state=init_state[env_id], ))
action, mcts_probs = self._collect_mcts.get_next_action(
state_config_for_simulation_env_reset,
policy_forward_fn=self._policy_value_fn,
temperature=self._collect_mcts_temperature,
sample=True
)
init_state=init_state[env_id],
katago_policy_init=False,
katago_game_state=katago_game_state[env_id]))
action, mcts_probs = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self._collect_mcts_temperature, True)

output[env_id] = {
'action': action,
'probs': mcts_probs,
@@ -305,15 +307,20 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
"""
ready_env_id = list(obs.keys())
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
# If 'katago_game_state' is present in the observation of the given environment ID, its value is used.
# If it is not present, None is used instead (dict.get avoids the KeyError that direct indexing would raise).
# This maintains compatibility with the handling of the 'katago'-related parts of 'alphazero_mcts_ctree' in Go.
katago_game_state = {env_id: obs[env_id].get('katago_game_state', None) for env_id in ready_env_id}
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
output = {}
self._policy_model = self._eval_model
for env_id in ready_env_id:
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
init_state=init_state[env_id], ))
init_state=init_state[env_id],
katago_policy_init=False,
katago_game_state=katago_game_state[env_id]))
action, mcts_probs = self._eval_mcts.get_next_action(
state_config_for_simulation_env_reset, policy_forward_fn=self._policy_value_fn, temperature=1.0,
sample=False
state_config_for_simulation_env_reset, self._policy_value_fn, 1.0, False
)
output[env_id] = {
'action': action,
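The dict.get lookup above is the heart of the compatibility fix: TicTacToe and Gomoku observations simply lack a 'katago_game_state' key, and get silently yields the default instead of raising. A minimal sketch of the behavior (the observation contents are illustrative):

obs_for_env = {'board': [[0, 0, 0], [0, 0, 0], [0, 0, 0]], 'current_player_index': 0}

# dict.get never raises KeyError; a missing key yields the supplied default,
# so envs without KataGo state still satisfy the ctree reset interface.
katago_game_state = obs_for_env.get('katago_game_state', None)
assert katago_game_state is None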
zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py
@@ -18,7 +18,7 @@
# ==============================================================
gomoku_alphazero_config = dict(
exp_name=
f'data_az_ptree/gomoku_alphazero_bot-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0',
f'data_az_ctree/gomoku_alphazero_bot-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0',
env=dict(
board_size=board_size,
battle_mode='play_with_bot_mode',
@@ -35,10 +35,9 @@
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
check_action_to_connect4_in_bot_v0=False,
mcts_ctree=mcts_ctree,
screen_scaling=9,
render_mode=None,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
policy=dict(
12 changes: 8 additions & 4 deletions zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py
@@ -12,18 +12,19 @@
batch_size = 256
max_env_step = int(5e5)
prob_random_action_in_bot = 0.5
mcts_ctree = True
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
gomoku_alphazero_config = dict(
exp_name=
f'data_az_ptree/gomoku_alphazero_sp-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0',
f'data_az_ctree/gomoku_alphazero_sp-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_seed0',
env=dict(
board_size=board_size,
battle_mode='self_play_mode',
bot_action_type='v0',
bot_action_type='v1',
prob_random_action_in_bot=prob_random_action_in_bot,
channel_last=False, # NOTE
channel_last=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
@@ -34,10 +35,13 @@
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
check_action_to_connect4_in_bot_v0=False,
screen_scaling=9,
render_mode=None,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
policy=dict(
mcts_ctree=mcts_ctree,
# ==============================================================
# for the creation of simulation env
simulation_env_name='gomoku',
32 changes: 26 additions & 6 deletions zoo/board_games/gomoku/envs/gomoku_env.py
Expand Up @@ -68,10 +68,10 @@ class GomokuEnv(BaseEnv):
prob_random_agent=0,
# (float) The probability that a random action will be taken when calling the bot.
prob_random_action_in_bot=0.,
# (bool) Whether to check the action to connect 4 in the bot v0.
check_action_to_connect4_in_bot_v0=False,
# (float) The stop value when training the agent. If the evaluation return reaches the stop value, the training will stop.
stop_value=2,
# (bool) Whether to use the MCTS ctree in AlphaZero. If True, then the AlphaZero MCTS ctree will be used.
alphazero_mcts_ctree=False,
)

@classmethod
@@ -117,7 +117,6 @@ def __init__(self, cfg: dict = None):
self.prob_random_action_in_bot = cfg.prob_random_action_in_bot
self.channel_last = cfg.channel_last
self.scale = cfg.scale
self.check_action_to_connect4_in_bot_v0 = cfg.check_action_to_connect4_in_bot_v0
self.agent_vs_human = cfg.agent_vs_human
self.bot_action_type = cfg.bot_action_type

@@ -142,11 +141,30 @@ def __init__(self, cfg: dict = None):
self.alpha_beta_pruning_player = AlphaBetaPruningBot(self, cfg, 'alpha_beta_pruning_player')
elif self.bot_action_type == 'v0':
self.rule_bot = GomokuRuleBotV0(self, self._current_player)
self.alphazero_mcts_ctree = cfg.alphazero_mcts_ctree
if not self.alphazero_mcts_ctree:
# plt does not work in mcts_ctree mode
self.fig, self.ax = plt.subplots(figsize=(self.board_size, self.board_size))
plt.ion()

self.fig, self.ax = plt.subplots(figsize=(self.board_size, self.board_size))
plt.ion()
def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, katago_game_state=None):
"""
Overview:
This method resets the environment and optionally starts with a custom state specified by 'init_state'.
Arguments:
- start_player_index (:obj:`int`, optional): Specifies the starting player. The players are [1,2] and
their corresponding indices are [0,1]. Defaults to 0.
- init_state (:obj:`Any`, optional): The custom starting state. If provided, the game starts from this state.
Defaults to None.
- katago_policy_init (:obj:`bool`, optional): This parameter is used to maintain compatibility with the
handling of 'katago' related parts in 'alphazero_mcts_ctree' in Go. Defaults to False.
- katago_game_state (:obj:`Any`, optional): This parameter is similar to 'katago_policy_init' and is used to
maintain compatibility with 'katago' in 'alphazero_mcts_ctree'. Defaults to None.
"""
if self.alphazero_mcts_ctree and init_state is not None:
# Convert byte string to np.ndarray
init_state = np.frombuffer(init_state, dtype=np.int32)

def reset(self, start_player_index=0, init_state=None):
self._observation_space = gym.spaces.Box(
low=0, high=2, shape=(self.board_size, self.board_size, 3), dtype=np.int32
)
@@ -156,6 +174,8 @@ def reset(self, start_player_index=0, init_state=None):
self._current_player = self.players[self.start_player_index]
if init_state is not None:
self.board = np.array(copy.deepcopy(init_state), dtype="int32")
if self.alphazero_mcts_ctree:
self.board = self.board.reshape((self.board_size, self.board_size))
else:
self.board = np.zeros((self.board_size, self.board_size), dtype="int32")
action_mask = np.zeros(self.total_num_actions, 'int8')
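When the C++ tree drives the simulation env, the board arrives at reset() as a flat byte string, which is why the new code calls np.frombuffer and then reshapes to (board_size, board_size). A sketch of the assumed round trip (the serializing side is inferred, not shown in this diff):

import numpy as np

board_size = 15
board = np.zeros((board_size, board_size), dtype=np.int32)
board[7, 7] = 1  # a single stone for illustration

# Flatten to raw bytes (what the ctree side is assumed to hand over), then
# rebuild the 2-D int32 board exactly as the new reset() code does.
init_state = board.tobytes()
restored = np.frombuffer(init_state, dtype=np.int32).reshape((board_size, board_size))
assert (restored == board).all()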
zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py
@@ -16,12 +16,12 @@
# ==============================================================

tictactoe_alphazero_config = dict(
exp_name=f'data_az_ptree/tictactoe_alphazero_bot-mode_ns{num_simulations}_upc{update_per_collect}_seed0',
exp_name=f'data_az_ctree/tictactoe_alphazero_bot-mode_ns{num_simulations}_upc{update_per_collect}_seed0',
env=dict(
board_size=3,
battle_mode='play_with_bot_mode',
bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
channel_last=False, # NOTE
channel_last=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
@@ -32,16 +32,16 @@
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
mcts_ctree=mcts_ctree,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
policy=dict(
mcts_ctree=mcts_ctree,
# ==============================================================
# for the creation of simulation env
simulation_env_name='tictactoe',
simulation_env_config_type='play_with_bot',
# ==============================================================
mcts_ctree=mcts_ctree,
model=dict(
observation_shape=(3, 3, 3),
action_space_size=int(1 * 3 * 3),
zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py
@@ -10,16 +10,17 @@
update_per_collect = 50
batch_size = 256
max_env_step = int(2e5)
mcts_ctree = True
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
tictactoe_alphazero_config = dict(
exp_name='data_az_ptree/tictactoe_sp-mode_alphazero_seed0',
exp_name='data_az_ctree/tictactoe_sp-mode_alphazero_seed0',
env=dict(
board_size=3,
battle_mode='self_play_mode',
bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
channel_last=False, # NOTE
channel_last=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
@@ -30,9 +31,11 @@
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
policy=dict(
mcts_ctree=mcts_ctree,
# ==============================================================
# for the creation of simulation env
simulation_env_name='tictactoe',
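Across all four configs, the single mcts_ctree toggle is now threaded into two places: the env (as alphazero_mcts_ctree, so the env adapts its reset and rendering behavior to the C++ tree) and the policy (as mcts_ctree, which selects the C++ MCTS implementation). Schematically:

mcts_ctree = True

config = dict(
    env=dict(
        # The env adjusts reset()/render() behavior for the C++ tree.
        alphazero_mcts_ctree=mcts_ctree,
    ),
    policy=dict(
        # The policy picks the C++ (ctree) MCTS implementation.
        mcts_ctree=mcts_ctree,
    ),
)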
47 changes: 36 additions & 11 deletions zoo/board_games/tictactoe/envs/tictactoe_env.py
@@ -46,6 +46,7 @@ class TicTacToeEnv(BaseEnv):
channel_last=True,
scale=True,
stop_value=1,
alphazero_mcts_ctree=False,
)

@classmethod
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(self, cfg=None):
self.bot_action_type = cfg.bot_action_type
if 'alpha_beta_pruning' in self.bot_action_type:
self.alpha_beta_pruning_player = AlphaBetaPruningBot(self, cfg, 'alpha_beta_pruning_player')
self.alphazero_mcts_ctree = cfg.alphazero_mcts_ctree

@property
def legal_actions(self):
@@ -94,17 +96,37 @@ def legal_actions_cython_lru(self):
return _legal_actions_func_lru(tuple(map(tuple, self.board)))

def get_done_winner(self):
"""
Overview:
Check if the game is over and who the winner is. Returns 'done' and 'winner'.
Returns:
- outputs (:obj:`Tuple`): Tuple containing 'done' and 'winner':
- if player 1 wins, 'done' = True, 'winner' = 1
- if player 2 wins, 'done' = True, 'winner' = 2
- if the game is a draw, 'done' = True, 'winner' = -1
- if the game is not over, 'done' = False, 'winner' = -1
"""
# Convert NumPy arrays to nested tuples to make them hashable.
return _get_done_winner_func_lru(tuple(map(tuple, self.board)))

def reset(self, start_player_index=0, init_state=None):
def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, katago_game_state=None):
"""
Overview:
Env reset and custom state start by init_state
This method resets the environment and optionally starts with a custom state specified by 'init_state'.
Arguments:
start_player_index: players = [1,2], player_index = [0,1]
init_state: custom start state.
- start_player_index (:obj:`int`, optional): Specifies the starting player. The players are [1,2] and
their corresponding indices are [0,1]. Defaults to 0.
- init_state (:obj:`Any`, optional): The custom starting state. If provided, the game starts from this state.
Defaults to None.
- katago_policy_init (:obj:`bool`, optional): This parameter is used to maintain compatibility with the
handling of 'katago' related parts in 'alphazero_mcts_ctree' in Go. Defaults to False.
- katago_game_state (:obj:`Any`, optional): This parameter is similar to 'katago_policy_init' and is used to
maintain compatibility with 'katago' in 'alphazero_mcts_ctree'. Defaults to None.
"""
if self.alphazero_mcts_ctree and init_state is not None:
# Convert byte string to np.ndarray
init_state = np.frombuffer(init_state, dtype=np.int32)

if self.scale:
self._observation_space = gym.spaces.Box(
low=0, high=1, shape=(self.board_size, self.board_size, 3), dtype=np.float32
@@ -119,6 +141,8 @@ def reset(self, start_player_index=0, init_state=None):
self._current_player = self.players[self.start_player_index]
if init_state is not None:
self.board = np.array(copy.deepcopy(init_state), dtype="int32")
if self.alphazero_mcts_ctree:
self.board = self.board.reshape((self.board_size, self.board_size))
else:
self.board = np.zeros((self.board_size, self.board_size), dtype="int32")

@@ -247,11 +271,11 @@ def _player_step(self, action):
done, winner = self.get_done_winner()

reward = np.array(float(winner == self.current_player)).astype(np.float32)
info = {'next player to play': self.to_play}
info = {'next player to play': self.next_player}
"""
NOTE: here exchange the player
"""
self.current_player = self.to_play
self.current_player = self.next_player

if done:
info['eval_episode_return'] = reward
@@ -277,11 +301,11 @@ def current_state(self):
Returns:
- current_state (:obj:`array`):
the 0 dim means which positions is occupied by self.current_player,
the 1 dim indicates which positions are occupied by self.to_play,
the 1 dim indicates which positions are occupied by self.next_player,
the 2 dim indicates which player is the to_play player, 1 means player 1, 2 means player 2
"""
board_curr_player = np.where(self.board == self.current_player, 1, 0)
board_opponent_player = np.where(self.board == self.to_play, 1, 0)
board_opponent_player = np.where(self.board == self.next_player, 1, 0)
board_to_play = np.full((self.board_size, self.board_size), self.current_player)
raw_obs = np.array([board_curr_player, board_opponent_player, board_to_play], dtype=np.float32)
if self.scale:
@@ -413,13 +437,14 @@ def current_player(self):
@property
def current_player_index(self):
"""
current_player_index = 0, current_player = 1
current_player_index = 1, current_player = 2
Overview:
current_player_index = 0, current_player = 1
current_player_index = 1, current_player = 2
"""
return 0 if self._current_player == 1 else 1

@property
def to_play(self):
def next_player(self):
return self.players[0] if self.current_player == self.players[1] else self.players[1]

@property
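The current_state() docstring above describes a three-plane observation built from the renamed next_player property; a small self-contained illustration of that encoding (board values are made up):

import numpy as np

board = np.array([[1, 2, 0],
                  [0, 1, 0],
                  [0, 0, 2]], dtype=np.int32)
current_player, next_player = 1, 2

plane_curr = np.where(board == current_player, 1, 0)  # dim 0: stones of the player to move
plane_next = np.where(board == next_player, 1, 0)     # dim 1: stones of the opponent
plane_to_play = np.full(board.shape, current_player)  # dim 2: constant plane marking the to-play player
raw_obs = np.array([plane_curr, plane_next, plane_to_play], dtype=np.float32)
assert raw_obs.shape == (3, 3, 3)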
