feature(pu/zt): add 2048 env and Stochastic MuZero (#64)
* add stochastic mz ptree

* add stochastic mz ctree

* add box2d, classic control, and 2048 config

* made corrections to the comments and naming issues

* made corrections to the comments and naming issues

* polish(pu): polish game_2048_env

* polish(pu): polish chance encoder

* fix(pu): fix chance encoder related loss

* sync code

* polish(pu): polish 2048 env, add env save_render_gif method, add 2048 env unittest, add stochastic muzero model unittest

* feature(pu): add stochastic muzero eval config

* polish(pu): polish 2048 save_replay method

* feature(pu): add num_of_possible_chance_tile option in 2048 env

* polish(pu): delete collector field in create config, move eval_config to entry directory

* sync code

* polish(pu): polish 2048 rule_bot move method, polish 2048 env, polish stochastic muzero game buffer

* feature(pu): add stochastic_muzero_model_mlp

* polish(pu): polish stochastic muzero configs

* feature(pu): add analysis utils for chance distribution

* polish(pu): delete model_path personal info

* polish(pu): add TestVisualizationFunctions, polish stochastic muzero model, rename xxx_eval_config.py to xxx_eval.py

* fix(pu): fix test_game_segment.py

* polish(pu): polish comments, abstract a get_target_obs_index_in_step_k method and add its unittest

* polish(pu): use _get_target_obs_index_in_step_k in all policy, rename step_i to step_k

---------

Co-authored-by: timothijoe <zt1301112@gmail.com>
Co-authored-by: 蒲源 <PJLAB\puyuan@pjnl104220214l.pjlab.org>
3 people authored Sep 12, 2023
1 parent eccda94 commit 9c42878
Showing 89 changed files with 6,841 additions and 345 deletions.
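Before the per-file diffs, a minimal usage sketch of what this commit enables: training Stochastic MuZero on the new 2048 env through the updated train entry. This is illustrative only; the config module path and keyword arguments below are assumptions, not taken from the diff.

# Hedged sketch: launch Stochastic MuZero training on 2048 via the updated entry.
# The config path is a hypothetical stand-in for the 2048 config added in this PR.
from zoo.game_2048.config.stochastic_muzero_2048_config import main_config, create_config  # hypothetical path
from lzero.entry import train_muzero

if __name__ == "__main__":
    # create_cfg.policy.type == 'stochastic_muzero' is now accepted by the entry's assert.
    train_muzero([main_config, create_config], seed=0, max_env_step=int(1e6))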
4 changes: 2 additions & 2 deletions lzero/entry/eval_muzero.py
@@ -38,8 +38,8 @@ def eval_muzero(
- policy (:obj:`Policy`): Converged policy.
"""
cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero'], \
"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'"
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \
"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'"

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
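A matching evaluation sketch for the assert extension shown above, assuming eval_muzero keeps its usual keyword arguments and that a converged checkpoint exists; both the checkpoint path and the config module are hypothetical.

# Hedged sketch: evaluate a trained Stochastic MuZero checkpoint with the updated eval entry.
from zoo.game_2048.config.stochastic_muzero_2048_config import main_config, create_config  # hypothetical path
from lzero.entry import eval_muzero

if __name__ == "__main__":
    model_path = "./ckpt/ckpt_best.pth.tar"  # hypothetical checkpoint location
    eval_muzero([main_config, create_config], seed=0, model_path=model_path)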
4 changes: 3 additions & 1 deletion lzero/entry/train_muzero.py
@@ -47,7 +47,7 @@ def train_muzero(
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'], \
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'"

if create_cfg.policy.type == 'muzero':
@@ -58,6 +58,8 @@
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gumbel_muzero':
from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'stochastic_muzero':
from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
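The if/elif chain above simply maps ``create_cfg.policy.type`` to a replay buffer class. The same dispatch written as a lookup table, purely for illustration (a sketch, not the committed code):

from lzero.mcts import (
    MuZeroGameBuffer,
    EfficientZeroGameBuffer,
    SampledEfficientZeroGameBuffer,
    GumbelMuZeroGameBuffer,
    StochasticMuZeroGameBuffer,
)

# Policy type -> replay buffer class, including the entry added by this commit.
GAME_BUFFER_BY_POLICY = {
    'muzero': MuZeroGameBuffer,
    'efficientzero': EfficientZeroGameBuffer,
    'sampled_efficientzero': SampledEfficientZeroGameBuffer,
    'gumbel_muzero': GumbelMuZeroGameBuffer,
    'stochastic_muzero': StochasticMuZeroGameBuffer,
}

def select_game_buffer(policy_type: str):
    """Return the GameBuffer class for a given create_cfg.policy.type."""
    return GAME_BUFFER_BY_POLICY[policy_type]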
1 change: 1 addition & 0 deletions lzero/mcts/buffer/__init__.py
@@ -2,3 +2,4 @@
from .game_buffer_efficientzero import EfficientZeroGameBuffer
from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer
171 changes: 171 additions & 0 deletions lzero/mcts/buffer/game_buffer_stochastic_muzero.py
@@ -0,0 +1,171 @@
from typing import Any, Tuple, List

import numpy as np
from ding.utils import BUFFER_REGISTRY

from lzero.mcts.utils import prepare_observation
from .game_buffer_muzero import MuZeroGameBuffer


@BUFFER_REGISTRY.register('game_buffer_stochastic_muzero')
class StochasticMuZeroGameBuffer(MuZeroGameBuffer):
"""
Overview:
The specific game buffer for Stochastic MuZero policy.
"""

def __init__(self, cfg: dict):
super().__init__(cfg)
"""
Overview:
Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
in the default configuration, the user-provided value will override the default configuration. Otherwise,
the default configuration will be used.
"""
default_config = self.default_config()
default_config.update(cfg)
self._cfg = default_config
assert self._cfg.env_type in ['not_board_games', 'board_games']
self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
self._beta = self._cfg.priority_prob_beta

self.keep_ratio = 1
self.model_update_interval = 10
self.num_of_collected_episodes = 0
self.base_idx = 0
self.clear_time = 0

self.game_segment_buffer = []
self.game_pos_priorities = []
self.game_segment_game_pos_look_up = []

def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
"""
Overview:
first sample orig_data through ``_sample_orig_data()``,
then prepare the context of a batch:
reward_value_context: the context of reanalyzed value targets
policy_re_context: the context of reanalyzed policy targets
policy_non_re_context: the context of non-reanalyzed policy targets
current_batch: the inputs of batch
Arguments:
- batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
- reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
Returns:
- context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
"""
# obtain the batch context from replay buffer
orig_data = self._sample_orig_data(batch_size)
game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
batch_size = len(batch_index_list)
obs_list, action_list, mask_list = [], [], []
if self._cfg.use_ture_chance_label_in_chance_encoder:
chance_list = []
# prepare the inputs of a batch
for i in range(batch_size):
game = game_segment_list[i]
pos_in_game_segment = pos_in_game_segment_list[i]

actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()
if self._cfg.use_ture_chance_label_in_chance_encoder:
chances_tmp = game.chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()
# add mask for invalid actions (out of trajectory)
mask_tmp = [1. for i in range(len(actions_tmp))]
mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps - len(mask_tmp))]

# pad random action
actions_tmp += [
np.random.randint(0, game.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]
if self._cfg.use_ture_chance_label_in_chance_encoder:
chances_tmp += [
np.random.randint(0, game.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(chances_tmp))
]
# obtain the input observations
# pad if length of obs in game_segment is less than stack+num_unroll_steps
# e.g. stack+num_unroll_steps 4+5
obs_list.append(
game_segment_list[i].get_unroll_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
)
)
action_list.append(actions_tmp)
mask_list.append(mask_tmp)
if self._cfg.use_ture_chance_label_in_chance_encoder:
chance_list.append(chances_tmp)

# formalize the input observations
obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

# formalize the inputs of a batch
if self._cfg.use_ture_chance_label_in_chance_encoder:
current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list,
chance_list]
else:
current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list]
for i in range(len(current_batch)):
current_batch[i] = np.asarray(current_batch[i])

total_transitions = self.get_num_of_transitions()

# obtain the context of value targets
reward_value_context = self._prepare_reward_value_context(
batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
)
"""
only reanalyze recent reanalyze_ratio (e.g. 50%) data
if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps
0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy
"""
reanalyze_num = int(batch_size * reanalyze_ratio)
# reanalyzed policy
if reanalyze_num > 0:
# obtain the context of reanalyzed policy targets
policy_re_context = self._prepare_policy_reanalyzed_context(
batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
pos_in_game_segment_list[:reanalyze_num]
)
else:
policy_re_context = None

# non reanalyzed policy
if reanalyze_num < batch_size:
# obtain the context of non-reanalyzed policy targets
policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
pos_in_game_segment_list[reanalyze_num:]
)
else:
policy_non_re_context = None

context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
return context

def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
"""
Overview:
Update the priority of training data.
Arguments:
- train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): the training data whose priorities will be updated.
- batch_priorities (:obj:`Any`): the new priorities to assign.
NOTE:
train_data = [current_batch, target_batch]
if self._cfg.use_ture_chance_label_in_chance_encoder:
obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch = current_batch
else:
obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch
"""
indices = train_data[0][3]
metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities}
# only update the priorities for data still in replay buffer
for i in range(len(indices)):
if metas['make_time'][i] > self.clear_time:
idx, prio = indices[i], metas['batch_priorities'][i]
self.game_pos_priorities[idx] = prio
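To make the batch layout concrete, a small sketch of how a consumer could unpack ``current_batch`` depending on the chance-label flag, mirroring the layout noted in update_priority above (illustrative only, not code from the repository):

from typing import Tuple

def unpack_current_batch(current_batch, use_true_chance_label: bool) -> Tuple:
    """Unpack the list built by _make_batch(), with or without chance labels."""
    if use_true_chance_label:
        (obs_batch_orig, action_batch, mask_batch,
         indices, weights, make_time, chance_batch) = current_batch
    else:
        (obs_batch_orig, action_batch, mask_batch,
         indices, weights, make_time) = current_batch
        chance_batch = None  # chance labels are only collected when the flag is enabled
    return obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch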
57 changes: 37 additions & 20 deletions lzero/mcts/buffer/game_segment.py
@@ -41,11 +41,17 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
"""
self.action_space = action_space
self.game_segment_length = game_segment_length
self.config = config

self.num_unroll_steps = config.num_unroll_steps
self.td_steps = config.td_steps
self.frame_stack_num = config.model.frame_stack_num
self.discount_factor = config.discount_factor
self.action_space_size = config.model.action_space_size
self.gray_scale = config.gray_scale
self.transform2string = config.transform2string
self.sampled_algo = config.sampled_algo
self.gumbel_algo = config.gumbel_algo
self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder

if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
# for vector obs input, e.g. classical control and box2d environments
self.zero_obs_shape = config.model.observation_shape
@@ -71,8 +77,11 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea

self.improved_policy_probs = []

if self.config.sampled_algo:
if self.sampled_algo:
self.root_sampled_actions = []
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []


def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
"""
@@ -89,8 +98,8 @@ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool
if pad_len > 0:
pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)])
stacked_obs = np.concatenate((stacked_obs, pad_frames))
if self.config.transform2string:
stacked_obs = [jpeg_data_decompressor(obs, self.config.gray_scale) for obs in stacked_obs]
if self.transform2string:
stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
return stacked_obs

def zero_obs(self) -> List:
@@ -114,12 +123,10 @@ def get_obs(self) -> List:
assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format(
timestep_obs, timestep_reward
)
# TODO:
timestep = timestep_obs
timestep = timestep_reward
stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num]
if self.config.transform2string:
stacked_obs = [jpeg_data_decompressor(obs, self.config.gray_scale) for obs in stacked_obs]
if self.transform2string:
stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
return stacked_obs

def append(
@@ -128,7 +135,8 @@ def append(
obs: np.ndarray,
reward: np.ndarray,
action_mask: np.ndarray = None,
to_play: int = -1
to_play: int = -1,
chance: int = 0,
) -> None:
"""
Overview:
@@ -140,10 +148,12 @@

self.action_mask_segment.append(action_mask)
self.to_play_segment.append(to_play)
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment.append(chance)

def pad_over(
self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List,
next_segment_child_visits: List, next_segment_improved_policy: List = None
next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None,
) -> None:
"""
Overview:
@@ -158,15 +168,15 @@
- next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next game_segment
- next_segment_improved_policy (:obj:`list`): root children select policy of MCTS from the next game_segment (Only used in Gumbel MuZero)
"""
assert len(next_segment_observations) <= self.config.num_unroll_steps
assert len(next_segment_child_visits) <= self.config.num_unroll_steps
assert len(next_segment_root_values) <= self.config.num_unroll_steps + self.config.td_steps
assert len(next_segment_rewards) <= self.config.num_unroll_steps + self.config.td_steps - 1
assert len(next_segment_observations) <= self.num_unroll_steps
assert len(next_segment_child_visits) <= self.num_unroll_steps
assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps
assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1
# ==============================================================
# The core difference between GumbelMuZero and MuZero
# ==============================================================
if self.config.gumbel_algo:
assert len(next_segment_improved_policy) <= self.config.num_unroll_steps + self.config.td_steps
if self.gumbel_algo:
assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps

# NOTE: next block observation should start from (stacked_observation - 1) in next trajectory
for observation in next_segment_observations:
@@ -181,9 +191,12 @@
for child_visits in next_segment_child_visits:
self.child_visit_segment.append(child_visits)

if self.config.gumbel_algo:
if self.gumbel_algo:
for improved_policy in next_segment_improved_policy:
self.improved_policy_probs.append(improved_policy)
if self.use_ture_chance_label_in_chance_encoder:
for chances in next_chances:
self.chance_segment.append(chances)

def get_targets(self, timestep: int) -> Tuple:
"""
@@ -203,10 +216,10 @@ def store_search_stats(
if idx is None:
self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts])
self.root_value_segment.append(root_value)
if self.config.sampled_algo:
if self.sampled_algo:
self.root_sampled_actions.append(root_sampled_actions)
# store the improved policy in Gumbel Muzero: \pi'=softmax(logits + \sigma(CompletedQ))
if self.config.gumbel_algo:
if self.gumbel_algo:
self.improved_policy_probs.append(improved_policy)
else:
self.child_visit_segment[idx] = [visit_count / sum_visits for visit_count in visit_counts]
@@ -261,6 +274,8 @@ def game_segment_to_array(self) -> None:

self.action_mask_segment = np.array(self.action_mask_segment)
self.to_play_segment = np.array(self.to_play_segment)
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = np.array(self.chance_segment)

def reset(self, init_observations: np.ndarray) -> None:
"""
Expand All @@ -279,6 +294,8 @@ def reset(self, init_observations: np.ndarray) -> None:

self.action_mask_segment = []
self.to_play_segment = []
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []

assert len(init_observations) == self.frame_stack_num

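As a usage note for the new ``chance`` argument: a hedged sketch of how a collector might record the chance outcome alongside each transition when ``use_ture_chance_label_in_chance_encoder`` is enabled. Variable names here are illustrative, not taken from the collector code.

def collect_step(segment, action, obs, reward, action_mask, env_chance):
    """Append one transition, including its chance outcome, to a GameSegment.

    ``env_chance`` stands for whatever chance event the 2048 env reports
    (e.g. an index encoding which tile value spawned and in which cell).
    """
    segment.append(
        action,
        obs,
        reward,
        action_mask=action_mask,
        to_play=-1,            # single-player 2048
        chance=env_chance,     # stored in segment.chance_segment when the flag is on
    )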
