feature(pu/zt): add 2048 env and Stochastic MuZero (#64)
* add stochastic mz ptree
* add stochastic mz ctree
* add box2d, classic control, and 2048 config
* made corrections to the comments and naming issues
* made corrections to the comments and naming issues
* polish(pu): polish game_2048_env
* polish(pu): polish chance encoder
* fix(pu): fix chance-encoder-related loss
* sync code
* polish(pu): polish 2048 env, add env save_render_gif method, add 2048 env unittest, add stochastic muzero model unittest
* feature(pu): add stochastic muzero eval config
* polish(pu): polish 2048 save_replay method
* feature(pu): add num_of_possible_chance_tile option in 2048 env
* polish(pu): delete collector field in create config, move eval_config to entry directory
* sync code
* polish(pu): polish 2048 rule_bot move method, polish 2048 env, polish stochastic muzero game buffer
* feature(pu): add stochastic_muzero_model_mlp
* polish(pu): polish stochastic muzero configs
* feature(pu): add analyze utils for chance distribution
* polish(pu): delete model_path personal info
* polish(pu): add TestVisualizationFunctions, polish stochastic muzero model, rename xxx_eval_config.py to xxx_eval.py
* fix(pu): fix test_game_segment.py
* polish(pu): polish comments, abstract a get_target_obs_index_in_step_k method and add its unittest
* polish(pu): use _get_target_obs_index_in_step_k in all policies, rename step_i to step_k

---------

Co-authored-by: timothijoe <zt1301112@gmail.com>
Co-authored-by: 蒲源 <PJLAB\puyuan@pjnl104220214l.pjlab.org>
1 parent eccda94 · commit 9c42878 · showing 89 changed files with 6,841 additions and 345 deletions.
@@ -0,0 +1,171 @@
from typing import Any, Tuple, List

import numpy as np
from ding.utils import BUFFER_REGISTRY

from lzero.mcts.utils import prepare_observation
from .game_buffer_muzero import MuZeroGameBuffer


@BUFFER_REGISTRY.register('game_buffer_stochastic_muzero')
class StochasticMuZeroGameBuffer(MuZeroGameBuffer):
    """
    Overview:
        The specific game buffer for Stochastic MuZero policy.
    """

    def __init__(self, cfg: dict):
        """
        Overview:
            Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing
            key in the default configuration, the user-provided value will override the default configuration.
            Otherwise, the default configuration will be used.
        """
        super().__init__(cfg)
        default_config = self.default_config()
        default_config.update(cfg)
        self._cfg = default_config
        assert self._cfg.env_type in ['not_board_games', 'board_games']
        self.replay_buffer_size = self._cfg.replay_buffer_size
        self.batch_size = self._cfg.batch_size
        self._alpha = self._cfg.priority_prob_alpha
        self._beta = self._cfg.priority_prob_beta
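        # Note: priority_prob_alpha and priority_prob_beta are the usual prioritized-replay
        # exponents: alpha shapes how strongly priorities skew sampling, and beta scales the
        # importance-sampling weights that correct for that skew.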

        self.keep_ratio = 1
        self.model_update_interval = 10
        self.num_of_collected_episodes = 0
        self.base_idx = 0
        self.clear_time = 0

        # game_segment_buffer stores whole game segments; game_pos_priorities stores one
        # priority per transition; game_segment_game_pos_look_up maps a flat transition
        # index back to its (game_segment, position) pair.
        self.game_segment_buffer = []
        self.game_pos_priorities = []
        self.game_segment_game_pos_look_up = []

    def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any, ...]:
        """
        Overview:
            First sample orig_data through ``_sample_orig_data()``,
            then prepare the context of a batch:
                reward_value_context:     the context of reanalyzed value targets
                policy_re_context:        the context of reanalyzed policy targets
                policy_non_re_context:    the context of non-reanalyzed policy targets
                current_batch:            the inputs of the batch
        Arguments:
            - batch_size (:obj:`int`): the batch size of orig_data sampled from the replay buffer.
            - reanalyze_ratio (:obj:`float`): the ratio of reanalyzed policy targets (value targets are 100% reanalyzed).
        Returns:
            - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
        """
        # obtain the batch context from the replay buffer
        orig_data = self._sample_orig_data(batch_size)
        game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
        batch_size = len(batch_index_list)
        obs_list, action_list, mask_list = [], [], []
        if self._cfg.use_ture_chance_label_in_chance_encoder:
            chance_list = []
        # prepare the inputs of a batch
        for i in range(batch_size):
            game = game_segment_list[i]
            pos_in_game_segment = pos_in_game_segment_list[i]

            actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
                                              self._cfg.num_unroll_steps].tolist()
            if self._cfg.use_ture_chance_label_in_chance_encoder:
                # the chance outcome at step t follows the action at step t, hence the +1 offset
                chances_tmp = game.chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment +
                                                  self._cfg.num_unroll_steps].tolist()
            # add mask for invalid actions (out of trajectory)
            mask_tmp = [1. for _ in range(len(actions_tmp))]
            mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps - len(mask_tmp))]
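            # e.g. with num_unroll_steps = 5 and only 3 valid steps left in the segment:
            # mask_tmp == [1., 1., 1., 0., 0.], and the two masked positions get random padding below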

            # pad random actions
            actions_tmp += [
                np.random.randint(0, game.action_space_size)
                for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
            ]
            if self._cfg.use_ture_chance_label_in_chance_encoder:
                chances_tmp += [
                    np.random.randint(0, game.action_space_size)
                    for _ in range(self._cfg.num_unroll_steps - len(chances_tmp))
                ]
            # obtain the input observations
            # pad if the length of obs in game_segment is less than stack_num + num_unroll_steps
            # e.g. stack_num + num_unroll_steps = 4 + 5
            obs_list.append(
                game_segment_list[i].get_unroll_obs(
                    pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
                )
            )
            action_list.append(actions_tmp)
            mask_list.append(mask_tmp)
            if self._cfg.use_ture_chance_label_in_chance_encoder:
                chance_list.append(chances_tmp)

        # format the input observations
        obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

        # format the inputs of a batch
        if self._cfg.use_ture_chance_label_in_chance_encoder:
            current_batch = [
                obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list, chance_list
            ]
        else:
            current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list]
        for i in range(len(current_batch)):
            current_batch[i] = np.asarray(current_batch[i])

        total_transitions = self.get_num_of_transitions()

        # obtain the context of the value targets
        reward_value_context = self._prepare_reward_value_context(
            batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
        )

        # Only reanalyze the recent reanalyze_ratio (e.g. 50%) of the data.
        # If self._cfg.reanalyze_outdated is True, batch_index_list is sorted by the env_steps at
        # which the data was generated: [0:reanalyze_num] -> reanalyzed policy,
        # [reanalyze_num:] -> non-reanalyzed policy.
        reanalyze_num = int(batch_size * reanalyze_ratio)
        # reanalyzed policy
        if reanalyze_num > 0:
            # obtain the context of the reanalyzed policy targets
            policy_re_context = self._prepare_policy_reanalyzed_context(
                batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
                pos_in_game_segment_list[:reanalyze_num]
            )
        else:
            policy_re_context = None

        # non-reanalyzed policy
        if reanalyze_num < batch_size:
            # obtain the context of the non-reanalyzed policy targets
            policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
                batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
                pos_in_game_segment_list[reanalyze_num:]
            )
        else:
            policy_non_re_context = None
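        # Example: with batch_size = 256 and reanalyze_ratio = 0.5, the first 128 samples
        # (after sorting) receive policy targets recomputed by fresh MCTS with the current
        # model, while the remaining 128 reuse the policy targets stored at collection time.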

        context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
        return context

    def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
        """
        Overview:
            Update the priorities of the training data.
        Arguments:
            - train_data (:obj:`List[np.ndarray]`): the training data whose priorities are to be updated.
            - batch_priorities (:obj:`Any`): the new priorities.
        NOTE:
            train_data = [current_batch, target_batch]
            if self._cfg.use_ture_chance_label_in_chance_encoder:
                obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch = current_batch
            else:
                obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch
        """
        indices = train_data[0][3]
        metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities}
        # only update the priorities of data that is still in the replay buffer
        for i in range(len(indices)):
            if metas['make_time'][i] > self.clear_time:
                idx, prio = indices[i], metas['batch_priorities'][i]
                self.game_pos_priorities[idx] = prio
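To make the buffer's role concrete, here is a minimal usage sketch. It is not code from this commit: the config keys mirror those read in __init__ above, but any keys filled in by default_config(), and the exact method for adding collected game segments, may differ in the real LightZero API.

from easydict import EasyDict

# Hypothetical configuration; keys mirror those read in __init__ above.
cfg = EasyDict(
    env_type='not_board_games',
    replay_buffer_size=int(1e6),
    batch_size=256,
    priority_prob_alpha=0.6,
    priority_prob_beta=0.4,
    num_unroll_steps=5,
    reanalyze_outdated=True,
    use_ture_chance_label_in_chance_encoder=True,
    model=EasyDict(model_type='conv'),
)

buffer = StochasticMuZeroGameBuffer(cfg)
# After game segments have been collected and pushed into the buffer,
# a training batch context could be drawn with, e.g.:
# context = buffer._make_batch(batch_size=256, reanalyze_ratio=0.5)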