fix(pu): fix mcts and alphabeta bot unittest, polish simulation method of ptree_az (#120)

* fix(pu): fix mcts and alphabeta bot deepcopy bug; polish the AlphaZero simulation method by using simulation env reset rather than deepcopy

* fix(pu): fix bot_action_type in test_speed_win-rate_between_bots.py and polish some minor code
puyuan1996 authored Oct 26, 2023
1 parent fa1b727 commit e9af1cd
Showing 17 changed files with 179 additions and 81 deletions.
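
The core change described above replaces the per-simulation `copy.deepcopy` of the environment in `ptree_az.py` with a reset of a single shared simulation env. Below is a minimal sketch of the two patterns, assuming a generic `simulate_fn` and a state config carrying `start_player_index` and `init_state`; the helper names are illustrative, not from the repository.

```python
import copy

def run_simulations_with_deepcopy(root, simulate_env, num_simulations, simulate_fn):
    # Old pattern (sketch): clone the entire env for every simulation.
    for _ in range(num_simulations):
        env_copy = copy.deepcopy(simulate_env)  # costly when the env carries large state
        simulate_fn(root, env_copy)

def run_simulations_with_reset(root, simulate_env, num_simulations, simulate_fn, state_config):
    # New pattern (sketch): reuse one env, resetting it to the root state each time.
    for _ in range(num_simulations):
        simulate_env.reset(
            start_player_index=state_config.start_player_index,
            init_state=state_config.init_state,
        )
        simulate_fn(root, simulate_env)
```
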
9 changes: 6 additions & 3 deletions README.md
@@ -31,11 +31,14 @@ Updated on 2023.09.21 LightZero-v0.0.2

> LightZero is a lightweight, efficient, and easy-to-understand open-source algorithm toolkit that combines Monte Carlo Tree Search (MCTS) and Deep Reinforcement Learning (RL).
English | [简体中文](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Paper](https://arxiv.org/pdf/2310.08348.pdf)
English | [简体中文(Simplified Chinese)](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Paper](https://arxiv.org/pdf/2310.08348.pdf)

## Background

The method of combining Monte Carlo Tree Search and Deep Reinforcement Learning represented by AlphaZero and MuZero has achieved superhuman level in various games such as Go and Atari, and has also made gratifying progress in scientific fields such as protein structure prediction, matrix multiplication algorithm search, etc.
The integration of Monte Carlo Tree Search and Deep Reinforcement Learning,
exemplified by AlphaZero and MuZero,
has achieved unprecedented performance levels in various games, including Go and Atari.
This advanced methodology has also made significant strides in scientific domains like protein structure prediction and the search for matrix multiplication algorithms.
The following is an overview of the historical evolution of the Monte Carlo Tree Search algorithm series:
![pipeline](assets/mcts_rl_evolution_overview.png)

@@ -484,4 +487,4 @@ Special thanks to [@PaParaZz1](https://github.com/PaParaZz1), [@karroyan](https:
## License
All code within this repository is under [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
<p align="right">(<a href="#top">back to top</a>)</p>
<p align="right">(<a href="#top">Back to top</a>)</p>
4 changes: 2 additions & 2 deletions README.zh.md
@@ -1,7 +1,7 @@
# LightZero

<div id="top"></div>

# LightZero

<div align="center">
<img width="1000px" height="auto" src="https://github.com/opendilab/LightZero/blob/main/LightZero.png"></a>
</div>
6 changes: 1 addition & 5 deletions lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
@@ -380,11 +380,7 @@ namespace tree
// After sorting, the first vector is the index, and the second vector is the probability value after perturbation sorted from large to small.
for (size_t iter = 0; iter < disturbed_probs.size(); iter++)
{
#ifdef __APPLE__
disc_action_with_probs.__emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#else
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#endif
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
}

std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp);
48 changes: 21 additions & 27 deletions lzero/mcts/ptree/ptree_az.py
@@ -12,7 +12,7 @@

import copy
import math
from typing import List, Tuple, Union, Callable, Type
from typing import List, Tuple, Union, Callable, Type, Dict, Any

import numpy as np
import torch
@@ -169,7 +169,7 @@ class MCTS(object):
Finally, by repeatedly calling ``_simulate`` through ``get_next_action``, the optimal action is obtained.
"""

def __init__(self, cfg: EasyDict) -> None:
def __init__(self, cfg: EasyDict, simulate_env: Type[BaseEnv]) -> None:
"""
Overview:
Initializes the MCTS process.
@@ -193,11 +193,12 @@ def __init__(self, cfg: EasyDict) -> None:
'root_dirichlet_alpha', 0.3
) # 0.3 # for chess, 0.03 for Go and 0.15 for shogi.
self._root_noise_weight = self._cfg.get('root_noise_weight', 0.25)
self.simulate_cnt = 0

self.simulate_env = simulate_env

def get_next_action(
self,
simulate_env: Type[BaseEnv],
state_config_for_simulate_env_reset: Dict[str, Any],
policy_forward_fn: Callable,
temperature: int = 1.0,
sample: bool = True
@@ -206,7 +207,7 @@ def get_next_action(
Overview:
Get the next action to take based on the current state of the game.
Arguments:
- simulate_env (:obj:`Class BaseGameEnv`): The class of simulate env.
- state_config_for_simulate_env_reset (:obj:`Dict`): The state config used to reset the simulation env.
- policy_forward_fn (:obj:`Function`): The Callable to compute the action probs and state value.
- temperature (:obj:`Float`): The exploration temperature.
- sample (:obj:`Bool`): Whether to sample an action from the probabilities or choose the most probable action.
@@ -217,41 +218,38 @@ def get_next_action(

# Create a new root node for the MCTS search.
root = Node()

self.simulate_env.reset(
start_player_index=state_config_for_simulate_env_reset.start_player_index,
init_state=state_config_for_simulate_env_reset.init_state,
)
# Expand the root node by adding children to it.
self._expand_leaf_node(root, simulate_env, policy_forward_fn)
self._expand_leaf_node(root, self.simulate_env, policy_forward_fn)

# Add Dirichlet noise to the root node's prior probabilities to encourage exploration.
if sample:
self._add_exploration_noise(root)

# for debugging
# print(simulate_env.board)
# print('value= {}'.format([(k, v.value) for k,v in root.children.items()]))
# print('visit_count= {}'.format([(k, v.visit_count) for k,v in root.children.items()]))
# print('legal_action= {}',format(simulate_env.legal_actions))

# Perform MCTS search for a fixed number of iterations.
for n in range(self._num_simulations):
# Initialize the simulated environment and reset it to the root node.
simulate_env_copy = copy.deepcopy(simulate_env)
self.simulate_env.reset(
start_player_index=state_config_for_simulate_env_reset.start_player_index,
init_state=state_config_for_simulate_env_reset.init_state,
)
# Set the battle mode adopted by the environment during the MCTS process.
# In ``self_play_mode``, when the environment calls the step function once, it will play one move based on the incoming action.
# In ``play_with_bot_mode``, when the step function is called, it will play one move based on the incoming action,
# and then it will play another move based on the action generated by the built-in bot in the environment, which means two moves in total.
# Therefore, in the MCTS process, except for the terminal nodes, the player corresponding to each node is the same player as the root node.
simulate_env_copy.battle_mode = simulate_env_copy.mcts_mode
simulate_env_copy.render_mode = None
self.simulate_env.battle_mode = self.simulate_env.mcts_mode
self.simulate_env.render_mode = None
# Run the simulation from the root to a leaf node and update the node values along the way.
self._simulate(root, simulate_env_copy, policy_forward_fn)

# for debugging
# print('after simulation')
# print('value= {}'.format([(k, v.value) for k,v in root.children.items()]))
# print('visit_count= {}'.format([(k, v.visit_count) for k,v in root.children.items()]))
self._simulate(root, self.simulate_env, policy_forward_fn)

# Get the visit count for each possible action at the root node.
action_visits = []
for action in range(simulate_env.action_space.n):
for action in range(self.simulate_env.action_space.n):
if action in root.children:
action_visits.append((action, root.children[action].visit_count))
else:
@@ -273,6 +271,7 @@ def get_next_action(
action = np.random.choice(actions, p=action_probs)
else:
action = actions[np.argmax(action_probs)]

# Return the selected action and the output probability of each action.
return action, action_probs

@@ -288,11 +287,6 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
"""
while not node.is_leaf():
# Traverse the tree until the leaf node.

# only for debug
# self.simulate_cnt += 1
# print('simulate_cnt: {}'.format(self.simulate_cnt))
# print(f'node:{node}, list(node.children.keys()) is: {list(node.children.keys())}. simulate_env.legal_actions is: {simulate_env.legal_actions}')
action, node = self._select_child(node, simulate_env)
# When there are no common elements in ``node.children`` and ``simulate_env.legal_actions``, action would be None, and we set the node to be a leaf node.
if action is None:
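
With these changes an `MCTS` instance owns its simulation env and `get_next_action` takes a small state config instead of an env instance. The following is a hedged usage sketch of the new interface: `choose_action`, `policy_value_fn`, and the cfg contents are assumptions for illustration, while the `MCTS` constructor and `get_next_action` calls follow the diff above.

```python
from easydict import EasyDict
from lzero.mcts.ptree.ptree_az import MCTS

def choose_action(simulate_env, mcts_cfg, policy_value_fn, board, start_player_index):
    # Sketch: one MCTS search from a given board state using the new interface.
    # ``simulate_env`` is a board-game env instance, ``mcts_cfg`` an EasyDict holding
    # keys such as ``num_simulations``, and ``policy_value_fn(env)`` returns
    # (action_prob_dict, leaf_value), as in AlphaZeroPolicy._policy_value_fn.
    mcts = MCTS(mcts_cfg, simulate_env)
    # Mirrors the state config built in _forward_collect / _forward_eval.
    state_config = EasyDict(dict(
        start_player_index=start_player_index,
        init_state=board,
    ))
    action, action_probs = mcts.get_next_action(
        state_config,
        policy_forward_fn=policy_value_fn,
        temperature=1.0,
        sample=True,
    )
    return action, action_probs
```
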
63 changes: 42 additions & 21 deletions lzero/policy/alphazero.py
@@ -9,6 +9,7 @@
from ding.torch_utils import to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate
from easydict import EasyDict

from lzero.mcts.ptree.ptree_az import MCTS
from lzero.policy import configure_optimizers
@@ -210,17 +211,17 @@ def _init_collect(self) -> None:
Overview:
Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
"""
self._collect_mcts = MCTS(self._cfg.mcts)
self._get_simulation_env()
self._collect_model = self._model
self._collect_mcts_temperature = 1
self._collect_mcts = MCTS(self._cfg.mcts, self.simulate_env)

@torch.no_grad()
def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]:
def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]:
"""
Overview:
The forward function for collecting data in collect mode. Use real env to execute MCTS search.
Arguments:
- envs (:obj:`Dict`): The dict of collector envs, the key is env_id and the value is the env instance.
- obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \
corresponding obs in this timestep.
- temperature (:obj:`float`): The temperature for MCTS search.
@@ -229,20 +230,16 @@ def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dic
the corresponding policy output in this timestep, including action, probs and so on.
"""
self._collect_mcts_temperature = temperature
ready_env_id = list(envs.keys())
ready_env_id = list(obs.keys())
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
output = {}
self._policy_model = self._collect_model
for env_id in ready_env_id:
# print('[collect] start_player_index={}'.format(start_player_index[env_id]))
# print('[collect] init_state=\n{}'.format(init_state[env_id]))
envs[env_id].reset(
start_player_index=start_player_index[env_id],
init_state=init_state[env_id],
)
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
init_state=init_state[env_id], ))
action, mcts_probs = self._collect_mcts.get_next_action(
envs[env_id],
state_config_for_simulation_env_reset,
policy_forward_fn=self._policy_value_fn,
temperature=self._collect_mcts_temperature,
sample=True
@@ -258,15 +255,18 @@ def _init_eval(self) -> None:
Overview:
Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
"""
self._eval_mcts = MCTS(self._cfg.mcts)
self._get_simulation_env()
import copy
mcts_eval_config = copy.deepcopy(self._cfg.mcts)
mcts_eval_config.num_simulations = mcts_eval_config.num_simulations * 2
self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)
self._eval_model = self._model

def _forward_eval(self, envs: Dict, obs: Dict) -> Dict[str, torch.Tensor]:
def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
"""
Overview:
The forward function for evaluating the current policy in eval mode, similar to ``self._forward_collect``.
Arguments:
- envs (:obj:`Dict`): The dict of collector envs, the key is env_id and the value is the env instance.
- obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \
corresponding obs in this timestep.
Returns:
@@ -279,21 +279,42 @@ def _forward_eval(self, envs: Dict, obs: Dict) -> Dict[str, torch.Tensor]:
output = {}
self._policy_model = self._eval_model
for env_id in ready_env_id:
# print('[eval] start_player_index={}'.format(start_player_index[env_id]))
# print('[eval] init_state=\n {}'.format(init_state[env_id]))
envs[env_id].reset(
start_player_index=start_player_index[env_id],
init_state=init_state[env_id],
)
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
init_state=init_state[env_id],))
action, mcts_probs = self._eval_mcts.get_next_action(
envs[env_id], policy_forward_fn=self._policy_value_fn, temperature=1.0, sample=False
state_config_for_simulation_env_reset, policy_forward_fn=self._policy_value_fn, temperature=1.0, sample=False
)
output[env_id] = {
'action': action,
'probs': mcts_probs,
}
return output

def _get_simulation_env(self):
if self._cfg.simulation_env_name == 'tictactoe':
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \
tictactoe_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \
tictactoe_alphazero_config
else:
raise NotImplementedError
self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)

elif self._cfg.simulation_env_name == 'gomoku':
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config
else:
raise NotImplementedError
self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
else:
raise NotImplementedError

@torch.no_grad()
def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]: # noqa
legal_actions = env.legal_actions
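
The new `_get_simulation_env` helper selects the simulation env class and its config from two policy-config keys, which the board-game configs further down now provide. A minimal sketch of that config fragment follows; the field values shown are examples.

```python
from easydict import EasyDict

# Sketch of the two new policy-config keys read by AlphaZeroPolicy._get_simulation_env.
# Other policy fields (model, mcts, optimizer, ...) are omitted here.
policy_cfg_fragment = EasyDict(dict(
    simulation_env_name='tictactoe',             # 'tictactoe' or 'gomoku'
    simulation_env_config_type='play_with_bot',  # 'play_with_bot' or 'self_play'
))
# simulation_env_name selects TicTacToeEnv or GomokuEnv; simulation_env_config_type
# selects the bot-mode or self-play config module whose ``env`` section is used to
# construct the simulation env. Unrecognized values raise NotImplementedError.
```
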
7 changes: 2 additions & 5 deletions lzero/worker/alphazero_collector.py
@@ -215,14 +215,11 @@ def collect(self,
obs_ = {env_id: obs[env_id] for env_id in ready_env_id}
# Policy forward.
self._obs_pool.update(obs_)
simulation_envs = {}
for env_id in ready_env_id:
# create the new simulation env instances from the current collect env using the same env_config.
simulation_envs[env_id] = self._env._env_fn[env_id]()

# ==============================================================
# policy forward
# ==============================================================
policy_output = self._policy.forward(simulation_envs, obs_, temperature)
policy_output = self._policy.forward(obs_, temperature)
self._policy_output_pool.update(policy_output)
# Interact with env.
actions = {env_id: output['action'] for env_id, output in policy_output.items()}
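
The collector now passes the ready observations straight to the policy instead of constructing per-env simulation instances. A sketch of the observation dict shape `_forward_collect` relies on is shown below; the values are illustrative and real observations may contain additional keys.

```python
import numpy as np

# Sketch: obs_ passed to self._policy.forward(obs_, temperature) is keyed by env_id.
obs_ = {
    0: {'board': np.zeros((3, 3), dtype=np.int64), 'current_player_index': 0},
    1: {'board': np.zeros((3, 3), dtype=np.int64), 'current_player_index': 1},
}
# For each env_id the policy uses obs_[env_id]['board'] as init_state and
# obs_[env_id]['current_player_index'] as start_player_index when resetting
# its internal simulation env before the MCTS search.
# policy_output = policy.forward(obs_, temperature)   # new call: no env instances passed
```
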
6 changes: 1 addition & 5 deletions lzero/worker/alphazero_evaluator.py
@@ -202,15 +202,11 @@ def eval(
with self._timer:
while not eval_monitor.is_finished():
obs = self._env.ready_obs
simulation_envs = {}
for env_id in list(obs.keys()):
# create the new simulation env instances from the current evaluate env using the same env_config.
simulation_envs[env_id] = self._env._env_fn[env_id]()

# ==============================================================
# policy forward
# ==============================================================
policy_output = self._policy.forward(simulation_envs, obs)
policy_output = self._policy.forward(obs)
actions = {env_id: output['action'] for env_id, output in policy_output.items()}
# ==============================================================
# Interact with env.
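
The evaluator is simplified in the same way. Per the policy diff above, evaluation additionally uses a copied MCTS config with `num_simulations` doubled and greedy (non-sampled) action selection; the base simulation count in this sketch is chosen arbitrarily.

```python
import copy
from easydict import EasyDict

mcts_collect_cfg = EasyDict(dict(num_simulations=25))  # illustrative value

# As in AlphaZeroPolicy._init_eval: evaluation uses a copied config with twice the simulations.
mcts_eval_cfg = copy.deepcopy(mcts_collect_cfg)
mcts_eval_cfg.num_simulations = mcts_eval_cfg.num_simulations * 2  # 50 at eval time

# At evaluation the search is greedy:
#   action, probs = eval_mcts.get_next_action(state_config, policy_forward_fn=fn,
#                                              temperature=1.0, sample=False)
```
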
2 changes: 1 addition & 1 deletion zoo/board_games/alphabeta_pruning_bot.py
@@ -17,7 +17,7 @@ def __init__(self, board, legal_actions, start_player_index=0, parent=None, prev
super().__init__()
self.env = env
self.board = board
self.legal_actions = legal_actions
self.legal_actions = copy.deepcopy(legal_actions)
self.children = []
self.parent = parent
self.prev_action = prev_action
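
The one-line change above is a defensive copy: the node previously stored a reference to the caller's `legal_actions` list, so any later in-place mutation of that list would silently change the node's view. A standalone illustration of the aliasing issue (not code from the repository):

```python
import copy

shared_legal_actions = [0, 1, 2, 3]

node_actions_aliased = shared_legal_actions                 # old behaviour: shared reference
node_actions_copied = copy.deepcopy(shared_legal_actions)   # fixed behaviour: private snapshot

shared_legal_actions.remove(2)  # e.g. the env or another node updates the list in place

print(node_actions_aliased)  # [0, 1, 3]     -- the aliased view changed unexpectedly
print(node_actions_copied)   # [0, 1, 2, 3]  -- the snapshot is unaffected
```
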
13 changes: 13 additions & 0 deletions zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py
@@ -28,8 +28,21 @@
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
# ==============================================================
# for the creation of simulation env
agent_vs_human=False,
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
check_action_to_connect4_in_bot_v0=False,
# ==============================================================
),
policy=dict(
# ==============================================================
# for the creation of simulation env
simulation_env_name='gomoku',
simulation_env_config_type='play_with_bot',
# ==============================================================
torch_compile=False,
tensor_float_32=False,
model=dict(
13 changes: 13 additions & 0 deletions zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py
@@ -28,8 +28,21 @@
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
# ==============================================================
# for the creation of simulation env
agent_vs_human=False,
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
check_action_to_connect4_in_bot_v0=False,
# ==============================================================
),
policy=dict(
# ==============================================================
# for the creation of simulation env
simulation_env_name='gomoku',
simulation_env_config_type='self_play',
# ==============================================================
torch_compile=False,
tensor_float_32=False,
model=dict(