fix(pu): polish all mlp model and related configs #26

Merged: 16 commits, May 7, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -1433,6 +1433,7 @@ events.*
/test_*
# LightZero special key
/lzero/mcts/**/*.cpp
/zoo/**/*.c
/lzero/mcts/**/*.so
/lzero/mcts/**/*.h
!/lzero/mcts/**/lib
4 changes: 3 additions & 1 deletion lzero/mcts/buffer/game_buffer_muzero.py
@@ -398,7 +398,9 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
if self._cfg.use_root_value:
# use the root values from MCTS, as in EfficientZero;
# the root values give limited improvement but require many more GPU actors;
_, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero')
_, reward_pool, policy_logits_pool, latent_state_roots = concat_output(
network_output, data_type='muzero'
)
reward_pool = reward_pool.squeeze().tolist()
policy_logits_pool = policy_logits_pool.tolist()
noises = [
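For context on the hunk above: when use_root_value is enabled, the buffer concatenates the batched network outputs, converts the reward and policy pools to plain Python lists, and prepares per-root exploration noise before re-running MCTS for the target values. The sketch below illustrates that pattern only; the function name build_root_inputs and its arguments are illustrative assumptions, not part of the diff.

import numpy as np

def build_root_inputs(reward_pools, policy_logits_pools, action_space_size, alpha=0.3):
    # Concatenate the per-minibatch outputs into flat pools (as concat_output
    # does in the hunk above), then convert them to plain Python lists.
    reward_pool = np.concatenate(reward_pools).squeeze().tolist()
    policy_logits_pool = np.concatenate(policy_logits_pools).tolist()
    # One Dirichlet exploration-noise vector per root, sized to the action space.
    noises = [
        np.random.dirichlet([alpha] * action_space_size).astype(np.float32).tolist()
        for _ in range(len(policy_logits_pool))
    ]
    return reward_pool, policy_logits_pool, noises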
11 changes: 6 additions & 5 deletions lzero/mcts/ptree/ptree_ez.py
@@ -35,8 -35,7 @@ def __init__(self, prior: float, legal_actions: List = None, action_space_size:
self.parent_value_prefix = 0 # only used in update_tree_q method

def expand(
self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float,
policy_logits: List[float]
self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float, policy_logits: List[float]
) -> None:
"""
Overview:
@@ -286,6 +285,7 @@ def __init__(self, num: int) -> None:
self.last_actions = []
self.search_lens = []


def select_child(
root: Node, min_max_stats: MinMaxStats, pb_c_base: float, pb_c_int: float, discount_factor: float,
mean_q: float, players: int
@@ -431,7 +431,6 @@ def batch_traverse(
is_root = 1
search_len = 0
results.search_paths[i].append(node)

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
Expand Down Expand Up @@ -515,7 +514,7 @@ def backpropagate(
path_len = len(search_path)
for i in range(path_len - 1, -1, -1):
node = search_path[i]

node.value_sum += bootstrap_value if node.to_play == to_play else -bootstrap_value

node.visit_count += 1
@@ -536,7 +535,9 @@ def backpropagate(
min_max_stats.update(true_reward + discount_factor * -node.value)

# true_reward is in the perspective of current player of node
bootstrap_value = (-true_reward if node.to_play == to_play else true_reward) + discount_factor * bootstrap_value
bootstrap_value = (
-true_reward if node.to_play == to_play else true_reward
) + discount_factor * bootstrap_value


def batch_backpropagate(
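The backpropagate hunks above only re-wrap a long line, but the sign convention they touch is worth spelling out: in a two-player setting the bootstrap value is negated whenever the node's player differs from to_play. Below is a minimal, self-contained sketch of that convention; value-prefix handling and min-max statistics are omitted, and this Node class is a simplified stand-in, not the one defined in ptree_ez.py.

from typing import List

class Node:
    def __init__(self, to_play: int, reward: float = 0.0):
        self.to_play = to_play
        self.reward = reward
        self.value_sum = 0.0
        self.visit_count = 0

def backpropagate(search_path: List[Node], to_play: int, bootstrap_value: float, discount_factor: float) -> None:
    # Walk the search path from leaf to root, accumulating discounted returns.
    for node in reversed(search_path):
        # Credit the value from the perspective of the node's own player.
        node.value_sum += bootstrap_value if node.to_play == to_play else -bootstrap_value
        node.visit_count += 1
        # The reward is stored from the current node's perspective, so flip it
        # before folding it into the bootstrap value passed up to the parent.
        true_reward = node.reward
        bootstrap_value = (
            -true_reward if node.to_play == to_play else true_reward
        ) + discount_factor * bootstrap_value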
8 changes: 4 additions & 4 deletions lzero/mcts/ptree/ptree_sez.py
@@ -46,8 +46,7 @@ def __init__(
self.batch_index = 0

def expand(
self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float,
policy_logits: List[float]
self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float, policy_logits: List[float]
) -> None:
"""
Overview:
@@ -614,7 +613,6 @@ def batch_traverse(
is_root = 1
search_len = 0
results.search_paths[i].append(node)

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
Expand Down Expand Up @@ -726,7 +724,9 @@ def backpropagate(
min_max_stats.update(true_reward + discount_factor * -node.value)

# true_reward is in the perspective of current player of node
bootstrap_value = (-true_reward if node.to_play == to_play else true_reward) + discount_factor * bootstrap_value
bootstrap_value = (
-true_reward if node.to_play == to_play else true_reward
) + discount_factor * bootstrap_value


def batch_backpropagate(
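The same backpropagation hunk in ptree_sez.py also calls min_max_stats.update(...) with each freshly computed Q value so that UCB scores can later be normalized across the tree. Below is a minimal sketch of such a running min/max normalizer; it follows the standard MuZero pseudocode and is an assumption, not the exact class used by LightZero.

class MinMaxStats:
    """Running min/max tracker used to normalize Q values to [0, 1]."""

    def __init__(self) -> None:
        self.minimum = float("inf")
        self.maximum = float("-inf")

    def update(self, value: float) -> None:
        # Called during backpropagation with every newly computed Q value.
        self.minimum = min(self.minimum, value)
        self.maximum = max(self.maximum, value)

    def normalize(self, value: float) -> float:
        # Only rescale once a non-trivial range has been observed.
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value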
23 changes: 12 additions & 11 deletions lzero/mcts/tree_search/mcts_ctree.py
@@ -13,7 +13,6 @@
from lzero.mcts.ctree.ctree_efficientzero import ez_tree as ez_ctree
from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree


# ==============================================================
# EfficientZero
# ==============================================================
@@ -93,7 +92,7 @@ def search(
# prepare some constants
batch_size = roots.num
pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor

# the data storage of latent states: storing the latent state of all the nodes in one search.
latent_state_batch_in_search_path = [latent_state_roots]
# the data storage of value prefix hidden states in LSTM
@@ -118,7 +117,6 @@
# latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.
# The value prefix hidden state of the leaf node is indexed in the same manner.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -143,7 +141,6 @@
).unsqueeze(0)
# .long() is only for discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()

"""
MCTS stage 2: Expansion
At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
@@ -156,7 +153,10 @@
)
if not model.training:
# if not in training, obtain the scalars of the value/value_prefix
[network_output.latent_state, network_output.policy_logits, network_output.value, network_output.value_prefix] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.value_prefix
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
Expand Down Expand Up @@ -187,7 +187,7 @@ def search(
reward_hidden_state_c_batch.append(reward_latent_state_batch[0])
reward_hidden_state_h_batch.append(reward_latent_state_batch[1])

# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

@@ -260,7 +260,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m

def search(
self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int,
List[Any]]
List[Any]]
) -> None:
"""
Overview:
@@ -296,7 +296,6 @@ def search(
# latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.
# The value prefix hidden state of the leaf node is indexed in the same manner.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -324,8 +323,10 @@

if not model.training:
# if not in training, obtain the scalars of the value/reward
[network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.reward] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.reward
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
Expand All @@ -340,7 +341,7 @@ def search(
value_batch = network_output.value.reshape(-1).tolist()
policy_logits_batch = network_output.policy_logits.tolist()

# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

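Several of the hunks above re-wrap the same ``if not model.training:`` block, which swaps the network-output tensors for detached CPU numpy arrays before they are handed to the C++ tree. The helper name below mirrors to_detach_cpu_numpy from the diff, but its body is an illustrative sketch, not the library's actual implementation.

from typing import List
import numpy as np
import torch

def to_detach_cpu_numpy(tensors: List[torch.Tensor]) -> List[np.ndarray]:
    # Detach each tensor from the autograd graph, move it to the CPU and
    # convert it to numpy so the ctree / ptree code can consume plain arrays.
    return [t.detach().cpu().numpy() for t in tensors]

Used as in the diff: the latent_state, policy_logits, value and reward/value_prefix fields of network_output are unpacked, converted, and assigned back in a single multi-line list assignment.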
9 changes: 5 additions & 4 deletions lzero/mcts/tree_search/mcts_ctree_sampled.py
@@ -125,7 +125,6 @@ def search(
# latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.
# The value prefix hidden state of the leaf node is indexed in the same manner.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -153,7 +152,6 @@
else:
# discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(device).long()

"""
MCTS stage 2: Expansion
At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
@@ -166,7 +164,10 @@
)
if not model.training:
# if not in training, obtain the scalars of the value/value_prefix
[network_output.latent_state, network_output.policy_logits, network_output.value, network_output.value_prefix] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.value_prefix
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
Expand Down Expand Up @@ -196,7 +197,7 @@ def search(
reward_hidden_state_c_pool.append(reward_latent_state_batch[0])
reward_hidden_state_h_pool.append(reward_latent_state_batch[1])

# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

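The sampled-tree hunk above also keeps the small discrete/continuous distinction when converting the last actions to tensors: discrete actions become long integer indices, while continuous action vectors stay float. A minimal sketch of that conversion follows; the helper name actions_to_tensor is an assumption for illustration.

import numpy as np
import torch

def actions_to_tensor(last_actions, continuous_action_space: bool, device: str = "cpu") -> torch.Tensor:
    arr = np.asarray(last_actions)
    if continuous_action_space:
        # Continuous action vectors are fed to the dynamics network as floats.
        return torch.from_numpy(arr).to(device).float()
    # .long() is only needed for discrete action indices.
    return torch.from_numpy(arr).to(device).long()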
31 changes: 16 additions & 15 deletions lzero/mcts/tree_search/mcts_ptree.py
@@ -122,7 +122,6 @@ def search(
# latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.
# The value prefix hidden state of the leaf node is indexed in the same manner.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
@@ -140,10 +139,10 @@
hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy])

latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)
).to(self._cfg.device).unsqueeze(0)
hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)
).to(self._cfg.device).unsqueeze(0)
hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(self._cfg.device
).unsqueeze(0)
hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(self._cfg.device
).unsqueeze(0)
# .long() is only for discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()
"""
@@ -159,8 +158,10 @@

if not model.training:
# if not in training, obtain the scalars of the value/reward
[network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.value_prefix] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.value_prefix
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
Expand Down Expand Up @@ -190,7 +191,7 @@ def search(
reward_hidden_state_c_batch.append(reward_latent_state_batch[0])
reward_hidden_state_h_batch.append(reward_latent_state_batch[1])

# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

@@ -297,10 +298,9 @@ def search(
# prepare a result wrapper to transport results between python and c++ parts
results = tree_muzero.SearchResults(num=batch_size)

# latent_state_index_in_search_path: The first index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, that is, the search depth.
# latent_state_index_in_batch: The second index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# latent_state_index_in_search_path: The first index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, that is, the search depth.
# latent_state_index_in_batch: The second index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``.
# e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index.

"""
MCTS stage 1: Selection
Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l.
Expand All @@ -315,7 +315,6 @@ def search(
latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float()
# only for discrete action
last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()

"""
MCTS stage 2: Expansion
At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
Expand All @@ -327,8 +326,10 @@ def search(

if not model.training:
# if not in training, obtain the scalars of the value/reward
[network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.reward] = to_detach_cpu_numpy(
[
network_output.latent_state, network_output.policy_logits, network_output.value,
network_output.reward
] = to_detach_cpu_numpy(
[
network_output.latent_state,
network_output.policy_logits,
Expand All @@ -343,7 +344,7 @@ def search(
reward_batch = network_output.reward.reshape(-1).tolist()
policy_logits_batch = network_output.policy_logits.tolist()

# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# In ``batch_backpropagate()``, we first expand the leaf node using the ``policy_logits`` and
# ``reward`` predicted by the model, then perform backpropagation along the search path to update the
# statistics.

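Finally, the EfficientZero branch of mcts_ptree.py re-wraps the lines that batch the LSTM reward hidden states gathered along the search paths: the per-node c/h vectors are stacked and given a leading layer dimension via unsqueeze(0) before the recurrent inference step. A minimal sketch of that batching step (an assumption, not the exact LightZero code) is:

import numpy as np
import torch

def stack_reward_hidden_states(hidden_states_c, hidden_states_h, device: str = "cpu"):
    # hidden_states_c / hidden_states_h: lists of per-node LSTM state vectors
    # collected while traversing the tree; stack them and add the leading
    # num_layers dimension expected by torch.nn.LSTM via unsqueeze(0).
    c = torch.from_numpy(np.asarray(hidden_states_c)).to(device).unsqueeze(0)
    h = torch.from_numpy(np.asarray(hidden_states_h)).to(device).unsqueeze(0)
    return c, h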