feature(pu): add sampled alphazero and polish gomoku env #141

Merged: 9 commits, Nov 13, 2023
2 changes: 1 addition & 1 deletion lzero/entry/eval_alphazero.py
@@ -87,7 +87,7 @@ def eval_alphazero(
if print_seed_details:
print("=" * 20)
print(f'In seed {seed}, returns: {returns}')
if cfg.policy.env_type == 'board_games':
if cfg.policy.simulation_env_name in ['tictactoe', 'connect4', 'gomoku', 'chess']:
print(
f'win rate: {len(np.where(returns == 1.)[0]) / num_episodes_each_seed}, draw rate: {len(np.where(returns == 0.)[0]) / num_episodes_each_seed}, lose rate: {len(np.where(returns == -1.)[0]) / num_episodes_each_seed}'
)
1 change: 1 addition & 0 deletions lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
@@ -385,6 +385,7 @@ namespace tree
#else
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#endif
// disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
}

std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp);
521 changes: 521 additions & 0 deletions lzero/mcts/ptree/ptree_az_sampled.py

Large diffs are not rendered by default.

73 changes: 66 additions & 7 deletions lzero/mcts/ptree/ptree_sez.py
@@ -123,7 +123,6 @@ def expand(

for action_index in range(self.num_of_sampled_actions):
self.children[Action(sampled_actions[action_index].detach().cpu().numpy())] = Node(
# prob[action_index], # NOTE: this is a bug
prob[sampled_actions[action_index]], #
action_space_size=self.action_space_size,
num_of_sampled_actions=self.num_of_sampled_actions,
@@ -403,6 +402,7 @@ def get_sampled_actions(self) -> List[List[Union[int, float]]]:
- python_sampled_actions: a vector of sampled_actions for each root, e.g. the size of original action space is 6, the K=3,
python_sampled_actions = [[1,3,0], [2,4,0], [5,4,1]].
"""
# TODO(pu): root_sampled_actions bug in discrete action space?
sampled_actions = []
for i in range(self.root_num):
sampled_actions.append(self.roots[i].legal_actions)
@@ -774,20 +774,79 @@ def batch_backpropagate(
)


from typing import Union
import numpy as np

class Action:
"""Class that represent an action of a game."""
"""
Class that represents an action of a game.

Attributes:
value (Union[int, np.ndarray]): The value of the action. Can be either an integer or a numpy array.
"""

def __init__(self, value: float) -> None:
def __init__(self, value: Union[int, np.ndarray]) -> None:
"""
Initializes the Action with the given value.

Args:
value (Union[int, np.ndarray]): The value of the action.
"""
self.value = value

def __hash__(self) -> hash:
return hash(self.value.tostring())
def __hash__(self) -> int:
"""
Returns a hash of the Action's value.

If the value is a numpy array, it is flattened to a tuple and then hashed.
If the value is a single integer, it is hashed directly.

Returns:
int: The hash of the Action's value.
"""
if isinstance(self.value, np.ndarray):
if self.value.ndim == 0:
return hash(self.value.item())
else:
return hash(tuple(self.value.flatten()))
else:
return hash(self.value)

def __eq__(self, other: "Action") -> bool:
return (self.value == other.value).all()
"""
Determines if this Action is equal to another Action.

If both values are numpy arrays, they are compared element-wise.
Otherwise, they are compared directly.

Args:
other (Action): The Action to compare with.

Returns:
bool: True if the two Actions are equal, False otherwise.
"""
if isinstance(self.value, np.ndarray) and isinstance(other.value, np.ndarray):
return np.array_equal(self.value, other.value)
else:
return self.value == other.value

def __gt__(self, other: "Action") -> bool:
return self.value[0] > other.value[0]
"""
Determines if this Action's value is greater than another Action's value.

Args:
other (Action): The Action to compare with.

Returns:
bool: True if this Action's value is greater, False otherwise.
"""
return self.value > other.value

def __repr__(self) -> str:
"""
Returns a string representation of this Action.

Returns:
str: A string representation of the Action's value.
"""
return str(self.value)
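
Note on the rewritten Action class above: because __hash__ flattens numpy arrays to tuples and __eq__ uses np.array_equal, numpy-valued actions (e.g. sampled continuous actions) can serve as dictionary keys for tree children just like plain integers. A minimal usage sketch with hypothetical values, assuming LightZero is importable:

import numpy as np
from lzero.mcts.ptree.ptree_sez import Action  # class shown in the diff above

children = {}
children[Action(np.array([0.1, -0.3]))] = 'child_for_continuous_action'
children[Action(3)] = 'child_for_discrete_action'

# An equal-valued array hashes and compares equal, so the lookup finds the
# existing child instead of silently creating a duplicate key.
assert Action(np.array([0.1, -0.3])) in children
assert Action(3) == Action(3)
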
107 changes: 88 additions & 19 deletions lzero/model/alphazero_model.py
@@ -8,6 +8,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.model import ReparameterizationHead
from ding.torch_utils import MLP, ResBlock
from ding.utils import MODEL_REGISTRY, SequenceType

@@ -34,6 +35,16 @@ def __init__(
fc_value_layers: SequenceType = [32],
fc_policy_layers: SequenceType = [32],
value_support_size: int = 601,
# ==============================================================
# specific sampled related config
# ==============================================================
continuous_action_space: bool = False,
num_of_sampled_actions: int = 6,
sigma_type='conditioned',
fixed_sigma_value: float = 0.3,
bound_type: str = None,
norm_type: str = 'BN',
discrete_action_encoding_type: str = 'one_hot',
):
"""
Overview:
@@ -70,7 +81,26 @@ def __init__(
self.last_linear_layer_init_zero = last_linear_layer_init_zero
self.representation_network = representation_network

self.continuous_action_space = continuous_action_space
self.action_space_size = action_space_size
# The dimension of a single action: 1 for a discrete action space,
# and action_space_size for a continuous action space.
self.action_space_dim = action_space_size if self.continuous_action_space else 1
assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type
self.discrete_action_encoding_type = discrete_action_encoding_type
if self.continuous_action_space:
self.action_encoding_dim = action_space_size
else:
if self.discrete_action_encoding_type == 'one_hot':
self.action_encoding_dim = action_space_size
elif self.discrete_action_encoding_type == 'not_one_hot':
self.action_encoding_dim = 1
self.sigma_type = sigma_type
self.fixed_sigma_value = fixed_sigma_value
self.bound_type = bound_type
self.norm_type = norm_type
self.num_of_sampled_actions = num_of_sampled_actions

# TODO use more adaptive way to get the flatten output size
flatten_output_size_for_value_head = (
(
@@ -88,6 +118,7 @@ def __init__(

self.prediction_network = PredictionNetwork(
action_space_size,
self.continuous_action_space,
num_res_blocks,
num_channels,
value_head_channels,
@@ -99,6 +130,10 @@ def __init__(
flatten_output_size_for_policy_head,
last_linear_layer_init_zero=self.last_linear_layer_init_zero,
activation=activation,
sigma_type=self.sigma_type,
fixed_sigma_value=self.fixed_sigma_value,
bound_type=self.bound_type,
norm_type=self.norm_type,
)

if self.representation_network is None:
@@ -131,7 +166,7 @@ def forward(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
logit, value = self.prediction_network(encoded_state)
return logit, value

def compute_prob_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def compute_policy_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Overview:
The computation graph of AlphaZero model to calculate action selection probability and value.
@@ -147,9 +182,7 @@ def compute_prob_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, t
- value (:obj:`torch.Tensor`): :math:`(B, 1)`, where B is batch size.
"""
logit, value = self.forward(state_batch)
# construct categorical distribution to calculate probability
dist = torch.distributions.Categorical(logits=logit)
prob = dist.probs
prob = torch.nn.functional.softmax(logit, dim=-1)
return prob, value

def compute_logp_value(self, state_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -178,6 +211,7 @@ class PredictionNetwork(nn.Module):
def __init__(
self,
action_space_size: int,
continuous_action_space: bool,
num_res_blocks: int,
num_channels: int,
value_head_channels: int,
@@ -189,6 +223,13 @@ def __init__(
flatten_output_size_for_policy_head: int,
last_linear_layer_init_zero: bool = True,
activation: Optional[nn.Module] = nn.ReLU(inplace=True),
# ==============================================================
# specific sampled related config
# ==============================================================
sigma_type='conditioned',
fixed_sigma_value: float = 0.3,
bound_type: str = None,
norm_type: str = 'BN',
) -> None:
"""
Overview:
@@ -213,6 +254,15 @@ def __init__(
operation to speedup, e.g. ReLU(inplace=True).
"""
super().__init__()
self.continuous_action_space = continuous_action_space
self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
self.norm_type = norm_type
self.sigma_type = sigma_type
self.fixed_sigma_value = fixed_sigma_value
self.bound_type = bound_type
self.activation = activation

self.resblocks = nn.ModuleList(
[
ResBlock(in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False)
@@ -226,7 +276,7 @@ def __init__(
self.norm_policy = nn.BatchNorm2d(policy_head_channels)
self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
self.fc_value = MLP(
self.fc_value_head = MLP(
in_channels=self.flatten_output_size_for_value_head,
hidden_channels=fc_value_layers[0],
out_channels=output_support_size,
@@ -237,17 +287,31 @@ def __init__(
output_norm=False,
last_linear_layer_init_zero=last_linear_layer_init_zero
)
self.fc_policy = MLP(
in_channels=self.flatten_output_size_for_policy_head,
hidden_channels=fc_policy_layers[0],
out_channels=action_space_size,
layer_num=len(fc_policy_layers) + 1,
activation=activation,
norm_type='LN',
output_activation=False,
output_norm=False,
last_linear_layer_init_zero=last_linear_layer_init_zero
)

# sampled related core code
if self.continuous_action_space:
self.fc_policy_head = ReparameterizationHead(
input_size=self.flatten_output_size_for_policy_head,
output_size=action_space_size,
layer_num=len(fc_policy_layers) + 1,
sigma_type=self.sigma_type,
fixed_sigma_value=self.fixed_sigma_value,
activation=nn.ReLU(),
norm_type=None,
bound_type=self.bound_type
)
else:
self.fc_policy_head = MLP(
in_channels=self.flatten_output_size_for_policy_head,
hidden_channels=fc_policy_layers[0],
out_channels=action_space_size,
layer_num=len(fc_policy_layers) + 1,
activation=activation,
norm_type='LN',
output_activation=False,
output_norm=False,
last_linear_layer_init_zero=last_linear_layer_init_zero
)

self.activation = activation

@@ -279,6 +343,11 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
value = value.reshape(-1, self.flatten_output_size_for_value_head)
policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)

value = self.fc_value(value)
logit = self.fc_policy(policy)
return logit, value
value = self.fc_value_head(value)

# sampled related core code
policy = self.fc_policy_head(policy)
if self.continuous_action_space:
policy = torch.cat([policy['mu'], policy['sigma']], dim=-1)

return policy, value
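
For the continuous-action branch above, forward() returns the policy as the concatenated Gaussian parameters [mu, sigma]. In Sampled AlphaZero, the search then draws K = num_of_sampled_actions candidate actions from that Gaussian and uses their probabilities to build priors for the sampled children. A minimal PyTorch sketch of that sampling step with hypothetical shapes (not the ding ReparameterizationHead itself):

import torch
from torch.distributions import Independent, Normal

# Hypothetical sizes: batch B=2, continuous action dim D=4, K=6 candidates
# (matching num_of_sampled_actions=6 in the default config above).
B, D, K = 2, 4, 6
mu = torch.randn(B, D)            # stands in for policy[:, :D] from forward()
sigma = torch.full((B, D), 0.3)   # stands in for policy[:, D:], e.g. fixed_sigma_value=0.3

dist = Independent(Normal(mu, sigma), 1)      # one diagonal Gaussian per batch element
sampled_actions = dist.sample((K,))           # shape (K, B, D): K candidate actions per state
log_probs = dist.log_prob(sampled_actions)    # shape (K, B): used for the K sampled children
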
6 changes: 4 additions & 2 deletions lzero/model/tests/test_alphazero_model.py
@@ -53,8 +53,9 @@ def output_check(self, model, outputs):
prediction_network_args
)
def test_prediction_network(
self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, policy_head_channels,
fc_value_layers, fc_policy_layers, output_support_size
self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels,
policy_head_channels,
fc_value_layers, fc_policy_layers, output_support_size
):
obs = torch.rand(batch_size, num_channels, 3, 3)
flatten_output_size_for_value_head = value_head_channels * observation_shape[1] * observation_shape[2]
@@ -64,6 +65,7 @@ def test_prediction_network(
# print('='*20)
prediction_network = PredictionNetwork(
action_space_size=action_space_size,
continuous_action_space=False,
num_res_blocks=num_res_blocks,
num_channels=num_channels,
value_head_channels=value_head_channels,
8 changes: 5 additions & 3 deletions lzero/policy/alphazero.py
@@ -162,7 +162,7 @@ def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
mcts_probs = mcts_probs.to(device=self._device, dtype=torch.float)
reward = reward.to(device=self._device, dtype=torch.float)

action_probs, values = self._learn_model.compute_prob_value(state_batch)
action_probs, values = self._learn_model.compute_policy_value(state_batch)
log_probs = torch.log(action_probs)

# calculate policy entropy, for monitoring only
@@ -258,7 +258,9 @@ def _init_eval(self) -> None:
self._get_simulation_env()
import copy
mcts_eval_config = copy.deepcopy(self._cfg.mcts)
mcts_eval_config.num_simulations = mcts_eval_config.num_simulations * 2
# TODO(pu): how to set proper num_simulations for evaluation
# mcts_eval_config.num_simulations = mcts_eval_config.num_simulations
mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4)
self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)
self._eval_model = self._model

@@ -323,7 +325,7 @@ def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]:
device=self._device, dtype=torch.float
).unsqueeze(0)
with torch.no_grad():
action_probs, value = self._policy_model.compute_prob_value(current_state_scale)
action_probs, value = self._policy_model.compute_policy_value(current_state_scale)
action_probs_dict = dict(zip(legal_actions, action_probs.squeeze(0)[legal_actions].detach().cpu().numpy()))
return action_probs_dict, value.item()

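
Usage note on the renamed compute_policy_value and on _policy_value_fn above: the model returns softmax probabilities over the full action space, and _policy_value_fn keeps only the entries for legal actions before handing them to MCTS. A small self-contained sketch with hypothetical values:

import torch

# Hypothetical 9-action board position (e.g. a 3x3 board) with three legal moves.
logit = torch.randn(1, 9)          # raw policy logits, as returned by forward()
legal_actions = [2, 4, 7]

prob = torch.nn.functional.softmax(logit, dim=-1)   # same op as compute_policy_value
action_probs_dict = dict(
    zip(legal_actions, prob.squeeze(0)[legal_actions].detach().cpu().numpy())
)
# MCTS only receives priors for legal moves; illegal moves are never expanded.
print(action_probs_dict)
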