Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(pu/zt): add 2048 env and Stochastic MuZero #64

Merged
merged 32 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
27d9b01
add stochastic mz ptree
timothijoe Jun 10, 2023
3799b46
add stochastic mz ctree
timothijoe Jun 10, 2023
14e3822
add box2d, classic control, and 2048 config
Jun 13, 2023
06c0558
made corrections to the comments and naming issues
Jun 16, 2023
fa88aab
made corrections to the comments and naming issues
Jun 16, 2023
b7a3fba
ok
Jul 10, 2023
9168a1b
ok
Jul 10, 2023
648f747
ok
timothijoe Jul 10, 2023
4693272
ok
timothijoe Jul 11, 2023
7000eed
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 Aug 2, 2023
11b4b7b
polish(pu): polish game_2048_env
puyuan1996 Aug 2, 2023
6bd0310
polish(pu): polish chance encoder
puyuan1996 Aug 2, 2023
a77a5bf
Merge branch 'explicit_chance_branch' of https://github.com/timothijo…
puyuan1996 Aug 3, 2023
f85aec3
fix(pu): fix chance encoder related loss
puyuan1996 Aug 3, 2023
b22dea7
sync code
puyuan1996 Aug 4, 2023
860fda1
polish(pu): polish 2048 env, add env save_render_gif method, add 2048…
puyuan1996 Aug 8, 2023
6e13727
feature(pu): add stochastic muzero eval config
puyuan1996 Aug 9, 2023
6f519b6
polish(pu): polish 2048 save_replay method
puyuan1996 Aug 9, 2023
e5f6b08
feature(pu): add num_of_possible_chance_tile option in 2048 env
puyuan1996 Aug 9, 2023
7d6f4f1
polish(pu): delete collector filed in create config, move eval_config…
puyuan1996 Aug 18, 2023
1f82928
sync code
puyuan1996 Aug 19, 2023
3119258
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 Aug 23, 2023
3a00a35
polish(pu): polish 2048 rule_bot move method, polish 2048 env, polish…
puyuan1996 Sep 5, 2023
89ab22a
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 Sep 5, 2023
02046f4
feature(pu): add stochastic_muzero_model_mlp
puyuan1996 Sep 5, 2023
1cbce65
polish(pu): polish stochastic muzero configs
puyuan1996 Sep 5, 2023
b6e9006
feature(pu): add analyze utils for chance distribution
puyuan1996 Sep 5, 2023
f4556ce
polish(pu): delete model_path personal info
puyuan1996 Sep 5, 2023
3b7bcb0
polish(pu): add TestVisualizationFunctions, polish stochastic muzero …
puyuan1996 Sep 10, 2023
1258be5
fix(pu): fix test_game_segment.py
puyuan1996 Sep 10, 2023
9e5b3d8
polish(pu): polish comments, abstract a get_target_obs_index_in_step_…
puyuan1996 Sep 12, 2023
7ea632f
polish(pu): use _get_target_obs_index_in_step_k in all policy, rename…
puyuan1996 Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lzero/entry/eval_muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def eval_muzero(
- policy (:obj:`Policy`): Converged policy.
"""
cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero'], \
"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'"
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \
"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'"

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
Expand Down
4 changes: 3 additions & 1 deletion lzero/entry/train_muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def train_muzero(
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'], \
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'"

if create_cfg.policy.type == 'muzero':
Expand All @@ -58,6 +58,8 @@ def train_muzero(
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gumbel_muzero':
from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'stochastic_muzero':
from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
Expand Down
1 change: 1 addition & 0 deletions lzero/mcts/buffer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .game_buffer_efficientzero import EfficientZeroGameBuffer
from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer
711 changes: 711 additions & 0 deletions lzero/mcts/buffer/game_buffer_stochastic_muzero.py

Large diffs are not rendered by default.

17 changes: 15 additions & 2 deletions lzero/mcts/buffer/game_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea

if self.config.sampled_algo:
self.root_sampled_actions = []
if self.config.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []


def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
"""
Expand Down Expand Up @@ -128,7 +131,8 @@ def append(
obs: np.ndarray,
reward: np.ndarray,
action_mask: np.ndarray = None,
to_play: int = -1
to_play: int = -1,
chance: np.ndarray = 0,
puyuan1996 marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
Overview:
Expand All @@ -140,10 +144,12 @@ def append(

self.action_mask_segment.append(action_mask)
self.to_play_segment.append(to_play)
if self.config.use_ture_chance_label_in_chance_encoder:
self.chance_segment.append(chance)

def pad_over(
self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List,
next_segment_child_visits: List, next_segment_improved_policy: List = None
next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None,
) -> None:
"""
Overview:
Expand Down Expand Up @@ -184,6 +190,9 @@ def pad_over(
if self.config.gumbel_algo:
for improved_policy in next_segment_improved_policy:
self.improved_policy_probs.append(improved_policy)
if self.config.use_ture_chance_label_in_chance_encoder:
for chances in next_chances:
self.chance_segment.append(chances)

def get_targets(self, timestep: int) -> Tuple:
"""
Expand Down Expand Up @@ -253,6 +262,8 @@ def game_segment_to_array(self) -> None:

self.action_mask_segment = np.array(self.action_mask_segment)
self.to_play_segment = np.array(self.to_play_segment)
if self.config.use_ture_chance_label_in_chance_encoder:
self.chance_segment = np.array(self.chance_segment)

def reset(self, init_observations: np.ndarray) -> None:
"""
Expand All @@ -271,6 +282,8 @@ def reset(self, init_observations: np.ndarray) -> None:

self.action_mask_segment = []
self.to_play_segment = []
if self.config.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []

assert len(init_observations) == self.frame_stack_num

Expand Down
Empty file.
Loading
Loading