diff --git a/README.md b/README.md
index 1b7558d5d..6372e2483 100644
--- a/README.md
+++ b/README.md
@@ -122,23 +122,25 @@ LightZero is a library with a [PyTorch](https://pytorch.org/) implementation of
The environments and algorithms currently supported by LightZero are shown in the table below:
-| Env./Algo. | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero |
-|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------|
-| TicTacToe | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |✔|
-| Gomoku | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |✔|
-| Connect4 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |✔|
-| 2048 | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ |✔|
-| Chess | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |🔒|
-| Go | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |🔒|
-| CartPole | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|
-| Pendulum | --- | ✔ | ✔ | ✔ | ✔ | ✔ |🔒|
-| LunarLander | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|
-| BipedalWalker | --- | ✔ | ✔ | ✔ | ✔ | 🔒 |🔒|
-| Atari | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|
-| MuJoCo | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |🔒|
-| MiniGrid | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|
-| Bsuite | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|
-| Memory | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|
+
+| Env./Algo. | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero |ReZero |
+|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------|----------------|
+| TicTacToe | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |✔|🔒 |
+| Gomoku | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |✔|✔ |
+| Connect4 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |✔|✔ |
+| 2048 | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ |✔|🔒 |
+| Chess | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |🔒|🔒 |
+| Go | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |🔒|🔒 |
+| CartPole | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|✔ |
+| Pendulum | --- | ✔ | ✔ | ✔ | ✔ | ✔ |🔒|🔒 |
+| LunarLander | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|🔒 |
+| BipedalWalker | --- | ✔ | ✔ | ✔ | ✔ | 🔒 |🔒|🔒 |
+| Atari | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|✔ |
+| MuJoCo | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |🔒|🔒 |
+| MiniGrid | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|🔒 |
+| Bsuite | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|🔒 |
+| Memory | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|🔒 |
+
(1): "โ" means that the corresponding item is finished and well-tested.
diff --git a/README.zh.md b/README.zh.md
index f0770c7a0..6dc96a479 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -110,23 +110,23 @@ LightZero is an MCTS algorithm library implemented with [PyTorch](https://pytorch.org/),
The environments and algorithms currently supported by LightZero are shown in the table below:
-| Env./Algo. | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero |
-|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------|
-| TicTacToe | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |✔|
-| Gomoku | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |✔|
-| Connect4 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |✔|
-| 2048 | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ |✔|
-| Chess | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |🔒|
-| Go | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |🔒|
-| CartPole | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|
-| Pendulum | --- | ✔ | ✔ | ✔ | ✔ | ✔ |🔒|
-| LunarLander | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|
-| BipedalWalker | --- | ✔ | ✔ | ✔ | ✔ | 🔒 |🔒|
-| Atari | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|
-| MuJoCo | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |🔒|
-| MiniGrid | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|
-| Bsuite | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|
-| Memory | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|
+| Env./Algo. | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero |ReZero |
+|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------|----------------|
+| TicTacToe | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |✔|🔒 |
+| Gomoku | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |✔|✔ |
+| Connect4 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |✔|✔ |
+| 2048 | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ |✔|🔒 |
+| Chess | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |🔒|🔒 |
+| Go | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |🔒|🔒 |
+| CartPole | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|✔ |
+| Pendulum | --- | ✔ | ✔ | ✔ | ✔ | ✔ |🔒|🔒 |
+| LunarLander | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|🔒 |
+| BipedalWalker | --- | ✔ | ✔ | ✔ | ✔ | 🔒 |🔒|🔒 |
+| Atari | --- | ✔ | ✔ | ✔ | ✔ | ✔ |✔|✔ |
+| MuJoCo | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |🔒|🔒 |
+| MiniGrid | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|🔒 |
+| Bsuite | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|🔒 |
+| Memory | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |✔|🔒 |
(1): "โ" ่กจ็คบๅฏนๅบ็้กน็ฎๅทฒ็ปๅฎๆๅนถ็ป่ฟ่ฏๅฅฝ็ๆต่ฏใ
diff --git a/lzero/agent/alphazero.py b/lzero/agent/alphazero.py
index d31e17bb9..3eb265bde 100644
--- a/lzero/agent/alphazero.py
+++ b/lzero/agent/alphazero.py
@@ -198,9 +198,9 @@ def train(
new_data = sum(new_data, [])
if self.cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = len(new_data)
- update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
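+ # For example (hypothetical numbers): 1000 collected transitions with replay_ratio=0.25 give
+ # update_per_collect = 250 gradient updates for this collect cycle.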
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Learn policy from collected data
diff --git a/lzero/agent/efficientzero.py b/lzero/agent/efficientzero.py
index 421cea881..b2362d9cc 100644
--- a/lzero/agent/efficientzero.py
+++ b/lzero/agent/efficientzero.py
@@ -228,9 +228,9 @@ def train(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if self.cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
- update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
diff --git a/lzero/agent/gumbel_muzero.py b/lzero/agent/gumbel_muzero.py
index 0df583ab1..0ad0ff0fb 100644
--- a/lzero/agent/gumbel_muzero.py
+++ b/lzero/agent/gumbel_muzero.py
@@ -228,9 +228,9 @@ def train(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if self.cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
- update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
diff --git a/lzero/agent/muzero.py b/lzero/agent/muzero.py
index 55dda5d00..7a77996b5 100644
--- a/lzero/agent/muzero.py
+++ b/lzero/agent/muzero.py
@@ -228,9 +228,9 @@ def train(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if self.cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
- update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
diff --git a/lzero/agent/sampled_alphazero.py b/lzero/agent/sampled_alphazero.py
index dc76c16e5..2f762c352 100644
--- a/lzero/agent/sampled_alphazero.py
+++ b/lzero/agent/sampled_alphazero.py
@@ -198,9 +198,9 @@ def train(
new_data = sum(new_data, [])
if self.cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = len(new_data)
- update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Learn policy from collected data
diff --git a/lzero/agent/sampled_efficientzero.py b/lzero/agent/sampled_efficientzero.py
index 079bdd11d..19542120e 100644
--- a/lzero/agent/sampled_efficientzero.py
+++ b/lzero/agent/sampled_efficientzero.py
@@ -228,9 +228,9 @@ def train(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if self.cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
- update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py
index 9b50ed6ec..37928cd10 100644
--- a/lzero/entry/__init__.py
+++ b/lzero/entry/__init__.py
@@ -4,4 +4,5 @@
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .eval_muzero import eval_muzero
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
-from .train_muzero_with_gym_env import train_muzero_with_gym_env
\ No newline at end of file
+from .train_muzero_with_gym_env import train_muzero_with_gym_env
+from .train_rezero import train_rezero
diff --git a/lzero/entry/train_alphazero.py b/lzero/entry/train_alphazero.py
index 3b455adb1..b77d69a75 100644
--- a/lzero/entry/train_alphazero.py
+++ b/lzero/entry/train_alphazero.py
@@ -119,9 +119,9 @@ def train_alphazero(
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
new_data = sum(new_data, [])
if cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = len(new_data)
- update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
# Learn policy from collected data
diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py
index 55ff5cb70..f8f70f80d 100644
--- a/lzero/entry/train_muzero.py
+++ b/lzero/entry/train_muzero.py
@@ -8,12 +8,12 @@
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
-from ding.utils import set_pkg_seed, get_rank
from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import set_pkg_seed, get_rank
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter
-from lzero.entry.utils import log_buffer_memory_usage
+from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
@@ -69,7 +69,6 @@ def train_muzero(
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
-
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
@@ -138,6 +137,7 @@ def train_muzero(
while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+ log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# set temperature for visit count distributions according to the train_iter,
# please refer to Appendix D in MuZero paper for details.
@@ -172,9 +172,9 @@ def train_muzero(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
- update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
diff --git a/lzero/entry/train_muzero_with_gym_env.py b/lzero/entry/train_muzero_with_gym_env.py
index 3aa3906b9..f2ec552d8 100644
--- a/lzero/entry/train_muzero_with_gym_env.py
+++ b/lzero/entry/train_muzero_with_gym_env.py
@@ -136,9 +136,9 @@ def train_muzero_with_gym_env(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
- update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
diff --git a/lzero/entry/train_muzero_with_reward_model.py b/lzero/entry/train_muzero_with_reward_model.py
index 2ae409601..20c9d05b5 100644
--- a/lzero/entry/train_muzero_with_reward_model.py
+++ b/lzero/entry/train_muzero_with_reward_model.py
@@ -171,9 +171,9 @@ def train_muzero_with_reward_model(
reward_model.clear_old_data()
if cfg.policy.update_per_collect is None:
- # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
- update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+ update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
diff --git a/lzero/entry/train_rezero.py b/lzero/entry/train_rezero.py
new file mode 100644
index 000000000..565b5b444
--- /dev/null
+++ b/lzero/entry/train_rezero.py
@@ -0,0 +1,232 @@
+import logging
+import os
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+from ding.config import compile_config
+from ding.envs import create_env_manager, get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import set_pkg_seed, get_rank
+from ding.worker import BaseLearner
+from tensorboardX import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time
+from lzero.policy import visit_count_temperature
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.worker import MuZeroCollector as Collector
+from lzero.worker import MuZeroEvaluator as Evaluator
+from .utils import random_collect
+
+
+def train_rezero(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy':
+ """
+ Train entry for ReZero algorithms (ReZero-MuZero, ReZero-EfficientZero).
+
+ Args:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Configuration dictionaries (user_config, create_cfg).
+ - seed (:obj:`int`): Random seed for reproducibility.
+ - model (:obj:`Optional[torch.nn.Module]`): Pre-initialized model instance.
+ - model_path (:obj:`Optional[str]`): Path to pretrained model checkpoint.
+ - max_train_iter (:obj:`Optional[int]`): Maximum number of training iterations.
+ - max_env_step (:obj:`Optional[int]`): Maximum number of environment steps.
+
+ Returns:
+ - Policy: Trained policy object.
+ """
+ cfg, create_cfg = input_cfg
+ assert create_cfg.policy.type in ['efficientzero', 'muzero'], \
+ "train_rezero entry only supports 'efficientzero' and 'muzero' algorithms"
+
+ # Import appropriate GameBuffer based on policy type
+ if create_cfg.policy.type == 'muzero':
+ from lzero.mcts import ReZeroMZGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'efficientzero':
+ from lzero.mcts import ReZeroEZGameBuffer as GameBuffer
+
+ # Set device (CUDA if available and enabled, otherwise CPU)
+ cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu'
+
+ # Compile and finalize configuration
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create environment, policy, and core components
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ # Set seeds for reproducibility
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ # Adjust checkpoint saving frequency for offline evaluation
+ if cfg.policy.eval_offline:
+ cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq
+
+ # Create and initialize policy
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+ if model_path:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ # Initialize worker components
+ tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ replay_buffer = GameBuffer(cfg.policy)
+ collector = Collector(
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=cfg.policy
+ )
+ evaluator = Evaluator(
+ eval_freq=cfg.policy.eval_freq,
+ n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=cfg.policy
+ )
+
+ # Main training loop
+ learner.call_hook('before_run')
+ update_per_collect = cfg.policy.update_per_collect
+
+ # Perform initial random data collection if specified
+ if cfg.policy.random_collect_episode_num > 0:
+ random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+
+ # Initialize offline evaluation tracking if enabled
+ if cfg.policy.eval_offline:
+ eval_train_iter_list, eval_train_envstep_list = [], []
+
+ # Evaluate initial random agent
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+
+ buffer_reanalyze_count = 0
+ train_epoch = 0
+ while True:
+ # Log buffer metrics
+ log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+ log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger)
+
+ # Prepare collection parameters
+ collect_kwargs = {
+ 'temperature': visit_count_temperature(
+ cfg.policy.manual_temperature_decay,
+ cfg.policy.fixed_temperature_value,
+ cfg.policy.threshold_training_steps_for_final_temperature,
+ trained_steps=learner.train_iter
+ ),
+ 'epsilon': get_epsilon_greedy_fn(
+ cfg.policy.eps.start, cfg.policy.eps.end,
+ cfg.policy.eps.decay, cfg.policy.eps.type
+ )(collector.envstep) if cfg.policy.eps.eps_greedy_exploration_in_collect else 0.0
+ }
+
+ # Periodic evaluation
+ if evaluator.should_eval(learner.train_iter):
+ if cfg.policy.eval_offline:
+ eval_train_iter_list.append(learner.train_iter)
+ eval_train_envstep_list.append(collector.envstep)
+ else:
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ # Collect new data
+ new_data = collector.collect(
+ train_iter=learner.train_iter,
+ policy_kwargs=collect_kwargs,
+ collect_with_pure_policy=cfg.policy.collect_with_pure_policy
+ )
+
+ # Update collection frequency if not specified
+ if update_per_collect is None:
+ collected_transitions = sum(len(segment) for segment in new_data[0])
+ update_per_collect = int(collected_transitions * cfg.policy.replay_ratio)
+
+ # Update replay buffer
+ replay_buffer.push_game_segments(new_data)
+ replay_buffer.remove_oldest_data_to_fit()
+
+ # Periodically reanalyze buffer
+ if cfg.policy.buffer_reanalyze_freq >= 1:
+ # Reanalyze the buffer <buffer_reanalyze_freq> times within one training epoch.
+ reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
+ else:
+ # Reanalyze the buffer once every <1/buffer_reanalyze_freq> training epochs.
+ if train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions() > 2000:
+ # When reanalyzing the buffer, the samples in the entire buffer are processed in mini-batches with a batch size of 2000.
+ # This is an empirically selected value for optimal efficiency.
+ replay_buffer.reanalyze_buffer(2000, policy)
+ buffer_reanalyze_count += 1
+ logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
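+ # For example (hypothetical values): with update_per_collect=100, buffer_reanalyze_freq=2 triggers a
+ # reanalyze every 50 updates inside the training loop below, whereas buffer_reanalyze_freq=0.5 triggers
+ # one reanalyze every 2 train epochs here; in both regimes the buffer must hold more than 2000
+ # transitions before reanalysis runs.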
+
+ # Training loop
+ for i in range(update_per_collect):
+ if cfg.policy.buffer_reanalyze_freq >= 1:
+ # Reanalyze the buffer <buffer_reanalyze_freq> times within one training epoch.
+ if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() > 2000:
+ # When reanalyzing the buffer, the samples in the entire buffer are processed in mini-batches with a batch size of 2000.
+ # This is an empirically selected value for optimal efficiency.
+ replay_buffer.reanalyze_buffer(2000, policy)
+ buffer_reanalyze_count += 1
+ logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
+
+ # Sample and train on mini-batch
+ if replay_buffer.get_num_of_transitions() > cfg.policy.batch_size:
+ train_data = replay_buffer.sample(cfg.policy.batch_size, policy)
+ log_vars = learner.train(train_data, collector.envstep)
+
+ if cfg.policy.use_priority:
+ replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+ else:
+ logging.warning('Insufficient data in replay buffer for sampling. Continuing collection...')
+ break
+
+ train_epoch += 1
+ # Check termination conditions
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ if cfg.policy.eval_offline:
+ perform_offline_evaluation(cfg, learner, policy, evaluator, eval_train_iter_list,
+ eval_train_envstep_list)
+ break
+
+ learner.call_hook('after_run')
+ return policy
+
+
+def perform_offline_evaluation(cfg, learner, policy, evaluator, eval_train_iter_list, eval_train_envstep_list):
+ """
+ Perform offline evaluation of the trained model.
+
+ Args:
+ cfg (dict): Configuration dictionary.
+ learner (BaseLearner): Learner object.
+ policy (Policy): Policy object.
+ evaluator (Evaluator): Evaluator object.
+ eval_train_iter_list (list): List of training iterations for evaluation.
+ eval_train_envstep_list (list): List of environment steps for evaluation.
+ """
+ logging.info('Starting offline evaluation...')
+ ckpt_dirname = f'./{learner.exp_name}/ckpt'
+
+ for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list):
+ ckpt_path = os.path.join(ckpt_dirname, f'iteration_{train_iter}.pth.tar')
+ policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device))
+ stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep)
+ logging.info(f'Offline eval at iter: {train_iter}, steps: {collector_envstep}, reward: {reward}')
+
+ logging.info('Offline evaluation completed')
\ No newline at end of file
diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py
index dfcbbf559..04d80f264 100644
--- a/lzero/entry/utils.py
+++ b/lzero/entry/utils.py
@@ -1,9 +1,9 @@
import os
+from typing import Optional, Callable
import psutil
from pympler.asizeof import asizeof
from tensorboardX import SummaryWriter
-from typing import Optional, Callable
def random_collect(
@@ -26,7 +26,8 @@ def random_collect(
collect_kwargs = {'temperature': 1, 'epsilon': 0.0}
# Collect data by default config n_sample/n_episode.
- new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0, policy_kwargs=collect_kwargs)
+ new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0,
+ policy_kwargs=collect_kwargs)
if postprocess_data_fn is not None:
new_data = postprocess_data_fn(new_data)
@@ -75,3 +76,41 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa
# Record the memory usage of the process to TensorBoard.
writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter)
+
+
+def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
+ """
+ Overview:
+ Log the average runtime metrics of the buffer to TensorBoard.
+ Arguments:
+ - train_iter (:obj:`int`): The current training iteration.
+ - buffer (:obj:`GameBuffer`): The game buffer containing runtime metrics.
+ - writer (:obj:`SummaryWriter`): The TensorBoard writer for logging metrics.
+
+ .. note::
+ "writer is None" indicates that the function is being called in a slave process in the DDP setup.
+ """
+ if writer is not None:
+ sample_times = buffer.sample_times
+
+ if sample_times == 0:
+ return
+
+ # Calculate and log average reanalyze time.
+ average_reanalyze_time = buffer.compute_target_re_time / sample_times
+ writer.add_scalar('Buffer/average_reanalyze_time', average_reanalyze_time, train_iter)
+
+ # Calculate and log average origin search time.
+ average_origin_search_time = buffer.origin_search_time / sample_times
+ writer.add_scalar('Buffer/average_origin_search_time', average_origin_search_time, train_iter)
+
+ # Calculate and log average reuse search time.
+ average_reuse_search_time = buffer.reuse_search_time / sample_times
+ writer.add_scalar('Buffer/average_reuse_search_time', average_reuse_search_time, train_iter)
+
+ # Calculate and log average active root number.
+ average_active_root_num = buffer.active_root_num / sample_times
+ writer.add_scalar('Buffer/average_active_root_num', average_active_root_num, train_iter)
+
+ # Reset the time records in the buffer.
+ buffer.reset_runtime_metrics()
diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py
index 31680a75e..9e3c910a0 100644
--- a/lzero/mcts/buffer/__init__.py
+++ b/lzero/mcts/buffer/__init__.py
@@ -3,3 +3,5 @@
from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer
+from .game_buffer_rezero_mz import ReZeroMZGameBuffer
+from .game_buffer_rezero_ez import ReZeroEZGameBuffer
diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py
index ec594c98e..81bf7b88c 100644
--- a/lzero/mcts/buffer/game_buffer.py
+++ b/lzero/mcts/buffer/game_buffer.py
@@ -30,7 +30,7 @@ def default_config(cls: type) -> EasyDict:
# (int) The size/capacity of the replay buffer in terms of transitions.
replay_buffer_size=int(1e6),
# (float) The ratio of experiences required for the reanalyzing part in a minibatch.
- reanalyze_ratio=0.3,
+ reanalyze_ratio=0,
# (bool) Whether to consider outdated experiences for reanalyzing. If True, we first sort the data in the minibatch by the time it was produced
# and only reanalyze the oldest ``reanalyze_ratio`` fraction.
reanalyze_outdated=True,
@@ -124,6 +124,8 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
# sample according to transition index
# TODO(pu): replace=True
+ # print(f"num transitions is {num_of_transitions}")
+ # print(f"length of probs is {len(probs)}")
batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False)
if self._cfg.reanalyze_outdated is True:
@@ -148,9 +150,50 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time)
return orig_data
+
+ def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple:
+ """
+ Overview:
+ Sample orig_data for the buffer reanalyze phase. The returned tuple contains:
+ game_segment_list: a list of game segments
+ pos_in_game_segment_list: transition index in game (relative index)
+ batch_index_list: the index of start transition of sampled minibatch in replay buffer
+ weights_list: the weight concerning the priority
+ make_time: the time the batch is made (for correctly updating replay buffer when data is deleted)
+ Arguments:
+ - batch_size (:obj:`int`): batch size
+ """
+ segment_length = (self.get_num_of_transitions()//2000)
+ assert self._beta > 0
+ num_of_transitions = self.get_num_of_transitions()
+ sample_points = num_of_transitions // segment_length
+
+ batch_index_list = np.random.choice(2000, batch_size, replace=False)
+
+ if self._cfg.reanalyze_outdated is True:
+ # NOTE: used in reanalyze part
+ batch_index_list.sort()
+
+ # TODO(xcy): use weighted sample
+ game_segment_list = []
+ pos_in_game_segment_list = []
+
+ for idx in batch_index_list:
+ game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx*segment_length]
+ game_segment_idx -= self.base_idx
+ game_segment = self.game_segment_buffer[game_segment_idx]
+
+ game_segment_list.append(game_segment)
+ pos_in_game_segment_list.append(pos_in_game_segment)
+
+ make_time = [time.time() for _ in range(len(batch_index_list))]
+
+ orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time)
+ return orig_data
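+ # For example (hypothetical numbers): with 100000 transitions in the buffer, segment_length = 50, so a
+ # sampled index k in [0, 2000) maps to the transition at position k * segment_length; drawing batch_size
+ # such start points spreads the reanalyzed batch roughly uniformly over the whole buffer.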
def _preprocess_to_play_and_action_mask(
- self, game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+ self, game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps=None
):
"""
Overview:
@@ -158,15 +201,17 @@ def _preprocess_to_play_and_action_mask(
- to_play: {list: game_segment_batch_size * (num_unroll_steps+1)}
- action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)}
"""
+ unroll_steps = unroll_steps if unroll_steps is not None else self._cfg.num_unroll_steps
+
to_play = []
for bs in range(game_segment_batch_size):
to_play_tmp = list(
to_play_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] +
- self._cfg.num_unroll_steps + 1]
+ unroll_steps + 1]
)
- if len(to_play_tmp) < self._cfg.num_unroll_steps + 1:
+ if len(to_play_tmp) < unroll_steps + 1:
# NOTE: the effective to play index is {1,2}, for null padding data, we set to_play=-1
- to_play_tmp += [-1 for _ in range(self._cfg.num_unroll_steps + 1 - len(to_play_tmp))]
+ to_play_tmp += [-1 for _ in range(unroll_steps + 1 - len(to_play_tmp))]
to_play.append(to_play_tmp)
to_play = sum(to_play, [])
@@ -178,12 +223,12 @@ def _preprocess_to_play_and_action_mask(
for bs in range(game_segment_batch_size):
action_mask_tmp = list(
action_mask_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] +
- self._cfg.num_unroll_steps + 1]
+ unroll_steps + 1]
)
- if len(action_mask_tmp) < self._cfg.num_unroll_steps + 1:
+ if len(action_mask_tmp) < unroll_steps + 1:
action_mask_tmp += [
list(np.ones(self._cfg.model.action_space_size, dtype=np.int8))
- for _ in range(self._cfg.num_unroll_steps + 1 - len(action_mask_tmp))
+ for _ in range(unroll_steps + 1 - len(action_mask_tmp))
]
action_mask.append(action_mask_tmp)
action_mask = to_list(action_mask)
@@ -334,6 +379,7 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None:
valid_len = len(data)
else:
valid_len = len(data) - meta['unroll_plus_td_steps']
+ # print(f'valid_len is {valid_len}')
if meta['priorities'] is None:
max_prio = self.game_pos_priorities.max() if self.game_segment_buffer else 1
@@ -354,6 +400,8 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None:
self.game_segment_game_pos_look_up += [
(self.base_idx + len(self.game_segment_buffer) - 1, step_pos) for step_pos in range(len(data))
]
+ # print(f'potioritys is {self.game_pos_priorities}')
+ # print(f'num of transitions is {len(self.game_segment_game_pos_look_up)}')
def remove_oldest_data_to_fit(self) -> None:
"""
diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py
index 873fdbdf8..420eb5922 100644
--- a/lzero/mcts/buffer/game_buffer_muzero.py
+++ b/lzero/mcts/buffer/game_buffer_muzero.py
@@ -2,7 +2,8 @@
import numpy as np
import torch
-from ding.utils import BUFFER_REGISTRY
+from ding.utils import BUFFER_REGISTRY, EasyTimer
+# from line_profiler import line_profiler
from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree
@@ -49,6 +50,28 @@ def __init__(self, cfg: dict):
self.game_pos_priorities = []
self.game_segment_game_pos_look_up = []
+ self._compute_target_timer = EasyTimer()
+ self._reuse_search_timer = EasyTimer()
+ self._origin_search_timer = EasyTimer()
+ self.buffer_reanalyze = False
+
+ self.compute_target_re_time = 0
+ self.reuse_search_time = 0
+ self.origin_search_time = 0
+ self.sample_times = 0
+ self.active_root_num = 0
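+ # NOTE: these counters are consumed (averaged over ``sample_times``) and then cleared via
+ # ``reset_runtime_metrics`` by ``log_buffer_run_time`` in lzero/entry/utils.py.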
+
+ def reset_runtime_metrics(self):
+ """
+ Overview:
+ Reset the runtime metrics of the buffer.
+ """
+ self.compute_target_re_time = 0
+ self.reuse_search_time = 0
+ self.origin_search_time = 0
+ self.sample_times = 0
+ self.active_root_num = 0
+
def sample(
self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
) -> List[Any]:
@@ -72,8 +95,10 @@ def sample(
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model
)
- # target policy
- batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model)
+ with self._compute_target_timer:
+ batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model)
+ self.compute_target_re_time += self._compute_target_timer.value
+
batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self._cfg.model.action_space_size
)
@@ -90,6 +115,8 @@ def sample(
# a batch contains the current_batch and the target_batch
train_data = [current_batch, target_batch]
+ if not self.buffer_reanalyze:
+ self.sample_times += 1
return train_data
def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
@@ -488,6 +515,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
batch_target_values = np.asarray(batch_target_values, dtype=object)
return batch_rewards, batch_target_values
+ # @profile
def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray:
"""
Overview:
@@ -549,7 +577,6 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
_, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero')
reward_pool = reward_pool.squeeze().tolist()
policy_logits_pool = policy_logits_pool.tolist()
- # noises are not necessary for reanalyze
noises = [
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
).astype(np.float32).tolist() for _ in range(transition_batch_size)
@@ -562,7 +589,9 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
else:
roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
- MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+ with self._origin_search_timer:
+ MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+ self.origin_search_time += self._origin_search_timer.value
else:
# python mcts_tree
roots = MCTSPtree.roots(transition_batch_size, legal_actions)
diff --git a/lzero/mcts/buffer/game_buffer_rezero_ez.py b/lzero/mcts/buffer/game_buffer_rezero_ez.py
new file mode 100644
index 000000000..fdfae46df
--- /dev/null
+++ b/lzero/mcts/buffer/game_buffer_rezero_ez.py
@@ -0,0 +1,329 @@
+from typing import Any, List, Union, TYPE_CHECKING
+
+import numpy as np
+import torch
+from ding.utils import BUFFER_REGISTRY, EasyTimer
+
+from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
+from lzero.mcts.utils import prepare_observation
+from lzero.policy import to_detach_cpu_numpy, concat_output, inverse_scalar_transform
+from .game_buffer_efficientzero import EfficientZeroGameBuffer
+from .game_buffer_rezero_mz import ReZeroMZGameBuffer, compute_all_filters
+
+# from line_profiler import line_profiler
+
+if TYPE_CHECKING:
+ from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy
+
+
+@BUFFER_REGISTRY.register('game_buffer_rezero_ez')
+class ReZeroEZGameBuffer(EfficientZeroGameBuffer, ReZeroMZGameBuffer):
+ """
+ Overview:
+ The specific game buffer for ReZero-EfficientZero policy.
+ """
+
+ def __init__(self, cfg: dict):
+ """
+ Overview:
+ Initialize the ReZeroEZGameBuffer with the given configuration. If a user passes in a cfg with a key that matches an existing key
+ in the default configuration, the user-provided value will override the default configuration. Otherwise,
+ the default configuration will be used.
+ """
+ super().__init__(cfg)
+
+ # Update the default configuration with the provided configuration.
+ default_config = self.default_config()
+ default_config.update(cfg)
+ self._cfg = default_config
+
+ # Ensure the configuration values are valid.
+ assert self._cfg.env_type in ['not_board_games', 'board_games']
+ assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
+
+ # Initialize various parameters from the configuration.
+ self.replay_buffer_size = self._cfg.replay_buffer_size
+ self.batch_size = self._cfg.batch_size
+ self._alpha = self._cfg.priority_prob_alpha
+ self._beta = self._cfg.priority_prob_beta
+
+ self.keep_ratio = 1
+ self.model_update_interval = 10
+ self.num_of_collected_episodes = 0
+ self.base_idx = 0
+ self.clear_time = 0
+
+ self.game_segment_buffer = []
+ self.game_pos_priorities = []
+ self.game_segment_game_pos_look_up = []
+
+ # Timers for performance monitoring
+ self._compute_target_timer = EasyTimer()
+ self._reuse_search_timer = EasyTimer()
+ self._origin_search_timer = EasyTimer()
+ self.buffer_reanalyze = True
+
+ # Performance metrics
+ self.compute_target_re_time = 0
+ self.reuse_search_time = 0
+ self.origin_search_time = 0
+ self.sample_times = 0
+ self.active_root_num = 0
+ self.average_infer = 0
+
+ def sample(
+ self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
+ ) -> List[Any]:
+ """
+ Overview:
+ Sample data from the GameBuffer and prepare the current and target batch for training.
+ Arguments:
+ - batch_size (int): Batch size.
+ - policy (Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]): Policy.
+ Returns:
+ - train_data (List): List of train data, including current_batch and target_batch.
+ """
+ policy._target_model.to(self._cfg.device)
+ policy._target_model.eval()
+
+ # Obtain the current_batch and prepare target context
+ reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
+ batch_size, self._cfg.reanalyze_ratio
+ )
+
+ # Compute target values and policies
+ batch_value_prefixs, batch_target_values = self._compute_target_reward_value(
+ reward_value_context, policy._target_model
+ )
+
+ with self._compute_target_timer:
+ batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model)
+ self.compute_target_re_time += self._compute_target_timer.value
+
+ batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
+ policy_non_re_context, self._cfg.model.action_space_size
+ )
+
+ # Fuse reanalyzed and non-reanalyzed target policies
+ if 0 < self._cfg.reanalyze_ratio < 1:
+ batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])
+ elif self._cfg.reanalyze_ratio == 1:
+ batch_target_policies = batch_target_policies_re
+ elif self._cfg.reanalyze_ratio == 0:
+ batch_target_policies = batch_target_policies_non_re
+
+ target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies]
+
+ # A batch contains the current_batch and the target_batch
+ train_data = [current_batch, target_batch]
+ if not self.buffer_reanalyze:
+ self.sample_times += 1
+ return train_data
+
+ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, length=None) -> np.ndarray:
+ """
+ Overview:
+ Prepare policy targets from the reanalyzed context of policies.
+ Arguments:
+ - policy_re_context (List): List of policy context to be reanalyzed.
+ - model (Any): The model used for inference.
+ - length (int, optional): The length of unroll steps.
+ Returns:
+ - batch_target_policies_re (np.ndarray): The reanalyzed policy targets.
+ """
+ if policy_re_context is None:
+ return []
+
+ batch_target_policies_re = []
+
+ unroll_steps = length - 1 if length is not None else self._cfg.num_unroll_steps
+
+ policy_obs_list, true_action, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_values, game_segment_lens, action_mask_segment, to_play_segment = policy_re_context
+
+ transition_batch_size = len(policy_obs_list)
+ game_segment_batch_size = len(pos_in_game_segment_list)
+
+ to_play, action_mask = self._preprocess_to_play_and_action_mask(
+ game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, length
+ )
+
+ if self._cfg.model.continuous_action_space:
+ action_mask = [
+ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
+ ]
+ legal_actions = [
+ [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
+ ]
+ else:
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+
+ with torch.no_grad():
+ policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
+ slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
+ network_output = []
+
+ for i in range(slices):
+ beg_index = self._cfg.mini_infer_size * i
+ end_index = self._cfg.mini_infer_size * (i + 1)
+ m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device)
+ m_output = model.initial_inference(m_obs)
+
+ if not model.training:
+ [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+ [
+ m_output.latent_state,
+ inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+ m_output.policy_logits
+ ]
+ )
+ m_output.reward_hidden_state = (
+ m_output.reward_hidden_state[0].detach().cpu().numpy(),
+ m_output.reward_hidden_state[1].detach().cpu().numpy()
+ )
+
+ network_output.append(m_output)
+
+ _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output(
+ network_output, data_type='efficientzero'
+ )
+ value_prefix_pool = value_prefix_pool.squeeze().tolist()
+ policy_logits_pool = policy_logits_pool.tolist()
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size).astype(
+ np.float32).tolist()
+ for _ in range(transition_batch_size)
+ ]
+
+ if self._cfg.mcts_ctree:
+ legal_actions_by_iter = compute_all_filters(legal_actions, unroll_steps)
+ noises_by_iter = compute_all_filters(noises, unroll_steps)
+ value_prefix_pool_by_iter = compute_all_filters(value_prefix_pool, unroll_steps)
+ policy_logits_pool_by_iter = compute_all_filters(policy_logits_pool, unroll_steps)
+ to_play_by_iter = compute_all_filters(to_play, unroll_steps)
+ latent_state_roots_by_iter = compute_all_filters(latent_state_roots, unroll_steps)
+
+ batch1_core_by_iter = compute_all_filters(reward_hidden_state_roots[0][0], unroll_steps)
+ batch2_core_by_iter = compute_all_filters(reward_hidden_state_roots[1][0], unroll_steps)
+ true_action_by_iter = compute_all_filters(true_action, unroll_steps)
+
+ temp_values = []
+ temp_distributions = []
+ mcts_ctree = MCTSCtree(self._cfg)
+ temp_search_time = 0
+ temp_length = 0
+ temp_infer = 0
+
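+ # Reuse mechanism in the loop below: compute_all_filters orders the roots from the last unroll step
+ # to the first, so iteration 0 performs a plain search on the final step, and each later iteration
+ # searches the preceding time step while passing the root values just computed for its successor
+ # states (reached via the recorded true actions) as ``reuse_value_list``. The per-step results are
+ # reversed back into chronological order after the loop.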
+ if self._cfg.reuse_search:
+ for iter in range(unroll_steps + 1):
+ iter_batch_size = transition_batch_size / (unroll_steps + 1)
+ roots = MCTSCtree.roots(iter_batch_size, legal_actions_by_iter[iter])
+ if self._cfg.reanalyze_noise:
+ roots.prepare(self._cfg.root_noise_weight, noises_by_iter[iter],
+ value_prefix_pool_by_iter[iter], policy_logits_pool_by_iter[iter],
+ to_play_by_iter[iter])
+ else:
+ roots.prepare_no_noise(value_prefix_pool_by_iter[iter], policy_logits_pool_by_iter[iter],
+ to_play_by_iter[iter])
+
+ if iter == 0:
+ with self._origin_search_timer:
+ mcts_ctree.search(roots, model, latent_state_roots_by_iter[iter],
+ [[batch1_core_by_iter[iter]], [batch2_core_by_iter[iter]]],
+ to_play_by_iter[iter])
+ self.origin_search_time += self._origin_search_timer.value
+ else:
+ with self._reuse_search_timer:
+ # ===================== Core implementation of ReZero: search_with_reuse =====================
+ length, average_infer = mcts_ctree.search_with_reuse(roots, model,
+ latent_state_roots_by_iter[iter],
+ [[batch1_core_by_iter[iter]],
+ [batch2_core_by_iter[iter]]],
+ to_play_by_iter[iter],
+ true_action_list=
+ true_action_by_iter[iter],
+ reuse_value_list=iter_values)
+ temp_search_time += self._reuse_search_timer.value
+ temp_length += length
+ temp_infer += average_infer
+
+ iter_values = roots.get_values()
+ iter_distributions = roots.get_distributions()
+ temp_values.append(iter_values)
+ temp_distributions.append(iter_distributions)
+
+ else:
+ for iter in range(unroll_steps + 1):
+ iter_batch_size = transition_batch_size / (unroll_steps + 1)
+ roots = MCTSCtree.roots(iter_batch_size, legal_actions_by_iter[iter])
+ if self._cfg.reanalyze_noise:
+ roots.prepare(self._cfg.root_noise_weight, noises_by_iter[iter],
+ value_prefix_pool_by_iter[iter], policy_logits_pool_by_iter[iter],
+ to_play_by_iter[iter])
+ else:
+ roots.prepare_no_noise(value_prefix_pool_by_iter[iter], policy_logits_pool_by_iter[iter],
+ to_play_by_iter[iter])
+
+ with self._origin_search_timer:
+ mcts_ctree.search(roots, model, latent_state_roots_by_iter[iter],
+ [[batch1_core_by_iter[iter]], [batch2_core_by_iter[iter]]],
+ to_play_by_iter[iter])
+ self.origin_search_time += self._origin_search_timer.value
+
+ iter_values = roots.get_values()
+ iter_distributions = roots.get_distributions()
+ temp_values.append(iter_values)
+ temp_distributions.append(iter_distributions)
+
+ self.origin_search_time = self.origin_search_time / (unroll_steps + 1)
+
+ if unroll_steps == 0:
+ self.reuse_search_time = 0
+ self.active_root_num = 0
+ else:
+ self.reuse_search_time += (temp_search_time / unroll_steps)
+ self.active_root_num += (temp_length / unroll_steps)
+ self.average_infer += (temp_infer / unroll_steps)
+
+ roots_legal_actions_list = legal_actions
+ temp_values.reverse()
+ temp_distributions.reverse()
+ roots_values = []
+ roots_distributions = []
+ [roots_values.extend(column) for column in zip(*temp_values)]
+ [roots_distributions.extend(column) for column in zip(*temp_distributions)]
+
+ policy_index = 0
+ for state_index, child_visit, root_value in zip(pos_in_game_segment_list, child_visits, root_values):
+ target_policies = []
+
+ for current_index in range(state_index, state_index + unroll_steps + 1):
+ distributions = roots_distributions[policy_index]
+ searched_value = roots_values[policy_index]
+
+ if policy_mask[policy_index] == 0:
+ target_policies.append([0 for _ in range(self._cfg.model.action_space_size)])
+ else:
+ if distributions is None:
+ target_policies.append(list(
+ np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size))
+ else:
+ sim_num = sum(distributions)
+ child_visit[current_index] = [visit_count / sim_num for visit_count in distributions]
+ root_value[current_index] = searched_value
+ if self._cfg.action_type == 'fixed_action_space':
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ target_policies.append(policy)
+ else:
+ policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)]
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ for index, legal_action in enumerate(roots_legal_actions_list[policy_index]):
+ policy_tmp[legal_action] = policy[index]
+ target_policies.append(policy_tmp)
+
+ policy_index += 1
+
+ batch_target_policies_re.append(target_policies)
+
+ return np.array(batch_target_policies_re)
diff --git a/lzero/mcts/buffer/game_buffer_rezero_mz.py b/lzero/mcts/buffer/game_buffer_rezero_mz.py
new file mode 100644
index 000000000..9e864ac5e
--- /dev/null
+++ b/lzero/mcts/buffer/game_buffer_rezero_mz.py
@@ -0,0 +1,401 @@
+from typing import Any, List, Tuple, Union, TYPE_CHECKING
+
+import numpy as np
+import torch
+from ding.torch_utils.data_helper import to_list
+from ding.utils import BUFFER_REGISTRY, EasyTimer
+
+from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
+from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree
+from lzero.mcts.utils import prepare_observation
+from lzero.policy import to_detach_cpu_numpy, concat_output, inverse_scalar_transform
+from .game_buffer_muzero import MuZeroGameBuffer
+
+# from line_profiler import line_profiler
+if TYPE_CHECKING:
+ from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy
+
+
+def compute_all_filters(data, num_unroll_steps):
+ """
+ Overview:
+ Regroup a flat, segment-major list of per-transition data (``num_unroll_steps + 1`` consecutive
+ steps per game segment) into ``num_unroll_steps + 1`` sub-lists, one per unroll step, ordered
+ from the last unroll step to the first.
+ """
+ data_by_iter = []
+ for iter in range(num_unroll_steps + 1):
+ iter_data = [x for i, x in enumerate(data)
+ if (i + 1) % (num_unroll_steps + 1) ==
+ ((num_unroll_steps + 1 - iter) % (num_unroll_steps + 1))]
+ data_by_iter.append(iter_data)
+ return data_by_iter
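+ # For example (illustrative values): compute_all_filters(['a0', 'a1', 'a2', 'b0', 'b1', 'b2'], 2)
+ # returns [['a2', 'b2'], ['a1', 'b1'], ['a0', 'b0']]: the data of two segments, regrouped by unroll
+ # step and ordered from the last step to the first.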
+
+
+@BUFFER_REGISTRY.register('game_buffer_rezero_mz')
+class ReZeroMZGameBuffer(MuZeroGameBuffer):
+ """
+ Overview:
+ The specific game buffer for ReZero-MuZero policy.
+ """
+
+ def __init__(self, cfg: dict):
+ """
+ Overview:
+ Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
+ in the default configuration, the user-provided value will override the default configuration. Otherwise,
+ the default configuration will be used.
+ """
+ super().__init__(cfg)
+ default_config = self.default_config()
+ default_config.update(cfg)
+ self._cfg = default_config
+
+ # Ensure valid configuration values
+ assert self._cfg.env_type in ['not_board_games', 'board_games']
+ assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
+
+ # Initialize buffer parameters
+ self.replay_buffer_size = self._cfg.replay_buffer_size
+ self.batch_size = self._cfg.batch_size
+ self._alpha = self._cfg.priority_prob_alpha
+ self._beta = self._cfg.priority_prob_beta
+
+ self.keep_ratio = 1
+ self.model_update_interval = 10
+ self.num_of_collected_episodes = 0
+ self.base_idx = 0
+ self.clear_time = 0
+
+ self.game_segment_buffer = []
+ self.game_pos_priorities = []
+ self.game_segment_game_pos_look_up = []
+
+ self._compute_target_timer = EasyTimer()
+ self._reuse_search_timer = EasyTimer()
+ self._origin_search_timer = EasyTimer()
+ self.buffer_reanalyze = True
+ self.compute_target_re_time = 0
+ self.reuse_search_time = 0
+ self.origin_search_time = 0
+ self.sample_times = 0
+ self.active_root_num = 0
+ self.average_infer = 0
+
+ def reanalyze_buffer(
+ self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
+ ) -> List[Any]:
+ """
+ Overview:
+ Reanalyze a mini-batch sampled from the ``GameBuffer`` with the latest target model, refreshing the
+ stored child-visit distributions and root values in place. Unlike ``sample``, this method does not
+ return training data.
+ Arguments:
+ - batch_size (:obj:`int`): Batch size.
+ - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): Policy.
+ """
+ assert self._cfg.mcts_ctree is True, "ReZero-MuZero only supports cpp mcts_ctree now!"
+ policy._target_model.to(self._cfg.device)
+ policy._target_model.eval()
+
+ # Prepare the reanalyze context for the policy targets
+ policy_re_context = self._make_reanalyze_batch(batch_size)
+
+ with self._compute_target_timer:
+ segment_length = self.get_num_of_transitions() // 2000
+ batch_target_policies_re = self._compute_target_policy_reanalyzed(
+ policy_re_context, policy._target_model, segment_length
+ )
+
+ self.compute_target_re_time += self._compute_target_timer.value
+
+ if self.buffer_reanalyze:
+ self.sample_times += 1
+
+ def _make_reanalyze_batch(self, batch_size: int) -> Tuple[Any]:
+ """
+ Overview:
+ First sample orig_data through ``_sample_orig_reanalyze_data()``, then prepare the reanalyze context
+ of the policy targets:
+ policy_re_context: the context of reanalyzed policy targets.
+ Arguments:
+ - batch_size (:obj:`int`): The batch size of orig_data sampled from the replay buffer.
+ Returns:
+ - policy_re_context (:obj:`List`): The context used to recompute the policy targets with the latest model.
+ """
+ # Obtain the batch context from replay buffer
+ orig_data = self._sample_orig_reanalyze_data(batch_size)
+ game_segment_list, pos_in_game_segment_list, batch_index_list, _, make_time_list = orig_data
+ segment_length = self.get_num_of_transitions() // 2000
+ policy_re_context = self._prepare_policy_reanalyzed_context(
+ [], game_segment_list, pos_in_game_segment_list, segment_length
+ )
+ return policy_re_context
+
+ def _prepare_policy_reanalyzed_context(
+ self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str],
+ length=None
+ ) -> List[Any]:
+ """
+ Overview:
+ Prepare the context of policies for calculating policy target in reanalyzing part.
+ Arguments:
+ - batch_index_list (:obj:`list`): Start transition index in the replay buffer.
+ - game_segment_list (:obj:`list`): List of game segments.
+ - pos_in_game_segment_list (:obj:`list`): Position of transition index in one game history.
+ - length (:obj:`int`, optional): Length of segments.
+ Returns:
+ - policy_re_context (:obj:`list`): policy_obs_list, true_action, policy_mask, pos_in_game_segment_list,
+ batch_index_list, child_visits, root_values, game_segment_lens, action_mask_segment, to_play_segment
+ """
+ zero_obs = game_segment_list[0].zero_obs()
+ with torch.no_grad():
+ policy_obs_list = []
+ true_action = []
+ policy_mask = []
+
+ unroll_steps = length - 1 if length is not None else self._cfg.num_unroll_steps
+ rewards, child_visits, game_segment_lens, root_values = [], [], [], []
+ action_mask_segment, to_play_segment = [], []
+
+ for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
+ game_segment_len = len(game_segment)
+ game_segment_lens.append(game_segment_len)
+ rewards.append(game_segment.reward_segment)
+ action_mask_segment.append(game_segment.action_mask_segment)
+ to_play_segment.append(game_segment.to_play_segment)
+ child_visits.append(game_segment.child_visit_segment)
+ root_values.append(game_segment.root_value_segment)
+
+ # Prepare the corresponding observations
+ game_obs = game_segment.get_unroll_obs(state_index, unroll_steps)
+ for current_index in range(state_index, state_index + unroll_steps + 1):
+ if current_index < game_segment_len:
+ policy_mask.append(1)
+ beg_index = current_index - state_index
+ end_index = beg_index + self._cfg.model.frame_stack_num
+ obs = game_obs[beg_index:end_index]
+ action = game_segment.action_segment[current_index]
+ if current_index == game_segment_len - 1:
+ action = -2 # use the illegal negative action to represent the game termination
+ else:
+ policy_mask.append(0)
+ obs = zero_obs
+ action = -2 # use the illegal negative action to represent the padding action
+ policy_obs_list.append(obs)
+ true_action.append(action)
+
+ policy_re_context = [
+ policy_obs_list, true_action, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits,
+ root_values, game_segment_lens, action_mask_segment, to_play_segment
+ ]
+ return policy_re_context
+
+ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, length=None) -> np.ndarray:
+ """
+ Overview:
+ Prepare policy targets from the reanalyzed context of policies.
+ Arguments:
+ - policy_re_context (:obj:`List`): List of policy context to be reanalyzed.
+ - model (:obj:`Any`): The model used for inference.
+ - length (:obj:`int`, optional): The length of the unroll steps.
+ Returns:
+ - batch_target_policies_re (:obj:`np.ndarray`): The reanalyzed batch target policies.
+ """
+ if policy_re_context is None:
+ return []
+
+ batch_target_policies_re = []
+
+ unroll_steps = length - 1 if length is not None else self._cfg.num_unroll_steps
+
+ # Unpack the policy reanalyze context
+ (
+ policy_obs_list, true_action, policy_mask, pos_in_game_segment_list, batch_index_list,
+ child_visits, root_values, game_segment_lens, action_mask_segment, to_play_segment
+ ) = policy_re_context
+
+ transition_batch_size = len(policy_obs_list)
+ game_segment_batch_size = len(pos_in_game_segment_list)
+
+ to_play, action_mask = self._preprocess_to_play_and_action_mask(
+ game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps
+ )
+
+ if self._cfg.model.continuous_action_space:
+ action_mask = [
+ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
+ ]
+ legal_actions = [
+ [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
+ ]
+ else:
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+
+ with torch.no_grad():
+ policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
+ slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
+ network_output = []
+
+ for i in range(slices):
+ beg_index = self._cfg.mini_infer_size * i
+ end_index = self._cfg.mini_infer_size * (i + 1)
+ m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device)
+ m_output = model.initial_inference(m_obs)
+
+ if not model.training:
+ m_output.latent_state, m_output.value, m_output.policy_logits = to_detach_cpu_numpy(
+ [
+ m_output.latent_state,
+ inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+ m_output.policy_logits
+ ]
+ )
+
+ network_output.append(m_output)
+
+ _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero')
+ reward_pool = reward_pool.squeeze().tolist()
+ policy_logits_pool = policy_logits_pool.tolist()
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size).astype(
+ np.float32).tolist()
+ for _ in range(transition_batch_size)
+ ]
+
+ if self._cfg.mcts_ctree:
+ legal_actions_by_iter = compute_all_filters(legal_actions, unroll_steps)
+ noises_by_iter = compute_all_filters(noises, unroll_steps)
+ reward_pool_by_iter = compute_all_filters(reward_pool, unroll_steps)
+ policy_logits_pool_by_iter = compute_all_filters(policy_logits_pool, unroll_steps)
+ to_play_by_iter = compute_all_filters(to_play, unroll_steps)
+ latent_state_roots_by_iter = compute_all_filters(latent_state_roots, unroll_steps)
+ true_action_by_iter = compute_all_filters(true_action, unroll_steps)
+
+ temp_values, temp_distributions = [], []
+ mcts_ctree = MCTSCtree(self._cfg)
+ temp_search_time, temp_length, temp_infer = 0, 0, 0
+
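+ # ReZero backward reanalysis: the transitions are grouped by unroll step via compute_all_filters and the
+ # steps are searched in reverse temporal order. The root values obtained in one iteration are passed as
+ # reuse_value_list to the next, so searching step t can reuse the value already searched for step t+1;
+ # temp_values.reverse() below restores chronological order.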
+ if self._cfg.reuse_search:
+ for iter in range(unroll_steps + 1):
+ iter_batch_size = transition_batch_size // (unroll_steps + 1)
+ roots = MCTSCtree.roots(iter_batch_size, legal_actions_by_iter[iter])
+
+ if self._cfg.reanalyze_noise:
+ roots.prepare(
+ self._cfg.root_noise_weight,
+ noises_by_iter[iter],
+ reward_pool_by_iter[iter],
+ policy_logits_pool_by_iter[iter],
+ to_play_by_iter[iter]
+ )
+ else:
+ roots.prepare_no_noise(
+ reward_pool_by_iter[iter],
+ policy_logits_pool_by_iter[iter],
+ to_play_by_iter[iter]
+ )
+
+ if iter == 0:
+ with self._origin_search_timer:
+ mcts_ctree.search(roots, model, latent_state_roots_by_iter[iter], to_play_by_iter[iter])
+ self.origin_search_time += self._origin_search_timer.value
+ else:
+ with self._reuse_search_timer:
+ # ===================== Core implementation of ReZero: search_with_reuse =====================
+ length, average_infer = mcts_ctree.search_with_reuse(
+ roots, model, latent_state_roots_by_iter[iter], to_play_by_iter[iter],
+ true_action_list=true_action_by_iter[iter], reuse_value_list=iter_values
+ )
+ temp_search_time += self._reuse_search_timer.value
+ temp_length += length
+ temp_infer += average_infer
+
+ iter_values = roots.get_values()
+ iter_distributions = roots.get_distributions()
+ temp_values.append(iter_values)
+ temp_distributions.append(iter_distributions)
+ else:
+ for iter in range(unroll_steps + 1):
+ iter_batch_size = transition_batch_size // (unroll_steps + 1)
+ roots = MCTSCtree.roots(iter_batch_size, legal_actions_by_iter[iter])
+
+ if self._cfg.reanalyze_noise:
+ roots.prepare(
+ self._cfg.root_noise_weight,
+ noises_by_iter[iter],
+ reward_pool_by_iter[iter],
+ policy_logits_pool_by_iter[iter],
+ to_play_by_iter[iter]
+ )
+ else:
+ roots.prepare_no_noise(
+ reward_pool_by_iter[iter],
+ policy_logits_pool_by_iter[iter],
+ to_play_by_iter[iter]
+ )
+
+ with self._origin_search_timer:
+ mcts_ctree.search(roots, model, latent_state_roots_by_iter[iter], to_play_by_iter[iter])
+ self.origin_search_time += self._origin_search_timer.value
+
+ iter_values = roots.get_values()
+ iter_distributions = roots.get_distributions()
+ temp_values.append(iter_values)
+ temp_distributions.append(iter_distributions)
+
+ self.origin_search_time /= (unroll_steps + 1)
+
+ if unroll_steps == 0:
+ self.reuse_search_time, self.active_root_num = 0, 0
+ else:
+ self.reuse_search_time += (temp_search_time / unroll_steps)
+ self.active_root_num += (temp_length / unroll_steps)
+ self.average_infer += (temp_infer / unroll_steps)
+
+ roots_legal_actions_list = legal_actions
+ temp_values.reverse()
+ temp_distributions.reverse()
+ roots_values, roots_distributions = [], []
+ for column in zip(*temp_values):
+ roots_values.extend(column)
+ for column in zip(*temp_distributions):
+ roots_distributions.extend(column)
+
+ policy_index = 0
+ for state_index, child_visit, root_value in zip(pos_in_game_segment_list, child_visits, root_values):
+ target_policies = []
+
+ for current_index in range(state_index, state_index + unroll_steps + 1):
+ distributions = roots_distributions[policy_index]
+ searched_value = roots_values[policy_index]
+
+ if policy_mask[policy_index] == 0:
+ target_policies.append([0 for _ in range(self._cfg.model.action_space_size)])
+ else:
+ if distributions is None:
+ target_policies.append(
+ list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
+ )
+ else:
+ # ===================== Update the data in buffer =====================
+ # After the reanalysis search, new target policies and root values are obtained.
+ # These target policies and root values are stored in the game segment, specifically in the ``child_visit_segment`` and ``root_value_segment``.
+ # We replace the data at the corresponding locations with the latest search results to maintain the most up-to-date targets.
+ sim_num = sum(distributions)
+ child_visit[current_index] = [visit_count / sim_num for visit_count in distributions]
+ root_value[current_index] = searched_value
+
+ if self._cfg.action_type == 'fixed_action_space':
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ target_policies.append(policy)
+ else:
+ policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)]
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ for index, legal_action in enumerate(roots_legal_actions_list[policy_index]):
+ policy_tmp[legal_action] = policy[index]
+ target_policies.append(policy_tmp)
+
+ policy_index += 1
+
+ batch_target_policies_re.append(target_policies)
+
+ return np.array(batch_target_policies_re)
+
diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py
index da889d98f..c50d93e3d 100644
--- a/lzero/mcts/buffer/game_segment.py
+++ b/lzero/mcts/buffer/game_segment.py
@@ -210,12 +210,15 @@ def store_search_stats(
store the visit count distributions and value of the root node after MCTS.
"""
sum_visits = sum(visit_counts)
+ if sum_visits == 0:
+ # if the sum of visit counts is 0, set it to a small value to avoid division by zero
+ sum_visits = 1e-6
if idx is None:
self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts])
self.root_value_segment.append(root_value)
if self.sampled_algo:
self.root_sampled_actions.append(root_sampled_actions)
- # store the improved policy in Gumbel Muzero: \pi'=softmax(logits + \sigma(CompletedQ))
+ # store the improved policy in Gumbel MuZero: \pi'=softmax(logits + \sigma(CompletedQ))
if self.gumbel_algo:
self.improved_policy_probs.append(improved_policy)
else:
diff --git a/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd b/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd
index 915199349..1023150aa 100644
--- a/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd
+++ b/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd
@@ -79,9 +79,17 @@ cdef extern from "lib/cnode.h" namespace "tree":
vector[float] values, vector[vector[float]] policies,
CMinMaxStatsList *min_max_stats_lst, CSearchResults & results,
vector[int] is_reset_list, vector[int] & to_play_batch)
+ void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, vector[float] value_prefixs,
+ vector[float] values, vector[vector[float]] policies,
+ CMinMaxStatsList *min_max_stats_lst, CSearchResults &results,
+ vector[int] is_reset_list, vector[int] &to_play_batch, vector[int] &no_inference_lst,
+ vector[int] &reuse_lst, vector[float] &reuse_value_lst)
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor,
CMinMaxStatsList *min_max_stats_lst, CSearchResults & results,
vector[int] & virtual_to_play_batch)
+ void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor,
+ CMinMaxStatsList *min_max_stats_lst, CSearchResults &results,
+ vector[int] &virtual_to_play_batch, vector[int] &true_action, vector[float] &reuse_value)
cdef class MinMaxStatsList:
cdef CMinMaxStatsList *cmin_max_stats_lst
diff --git a/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx b/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx
index 8149f8569..b7d28c1fd 100644
--- a/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx
+++ b/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx
@@ -91,6 +91,19 @@ def batch_backpropagate(int current_latent_state_index, float discount_factor, l
cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies,
min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch)
+@cython.binding
+def batch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies,
+ MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list is_reset_list,
+ list to_play_batch, list no_inference_lst, list reuse_lst, list reuse_value_lst):
+ cdef int i
+ cdef vector[float] cvalue_prefixs = value_prefixs
+ cdef vector[float] cvalues = values
+ cdef vector[vector[float]] cpolicies = policies
+ cdef vector[float] creuse_value_lst = reuse_value_lst
+
+ cbatch_backpropagate_with_reuse(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies,
+ min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch, no_inference_lst, reuse_lst, creuse_value_lst)
+
@cython.binding
def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst,
ResultsWrapper results, list virtual_to_play_batch):
@@ -98,3 +111,11 @@ def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount_f
results.cresults, virtual_to_play_batch)
return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs
+
+@cython.binding
+def batch_traverse_with_reuse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst,
+ ResultsWrapper results, list virtual_to_play_batch, list true_action, list reuse_value):
+ cbatch_traverse_with_reuse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults,
+ virtual_to_play_batch, true_action, reuse_value)
+
+ return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs
diff --git a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp
index 606744e0a..8a94a7ca9 100644
--- a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp
+++ b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp
@@ -600,6 +600,54 @@ namespace tree
}
}
+ void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch, std::vector<int> &no_inference_lst, std::vector<int> &reuse_lst, std::vector<float> &reuse_value_lst)
+ {
+ /*
+ Overview:
+ Expand the nodes along the search path and update the infos.
+ Arguments:
+ - current_latent_state_index: The index of latent state of the leaf node in the search path.
+ - discount_factor: the discount factor of reward.
+ - value_prefixs: the value prefixs of nodes along the search path.
+ - values: the values to propagate along the search path.
+ - policies: the policy logits of nodes along the search path.
+ - min_max_stats: a tool used to min-max normalize the q value.
+ - results: the search results.
+ - to_play_batch: the batch of which player is playing on this node.
+ - no_inference_lst: the batch indices of the roots whose leaf node does not need to be expanded (no model inference was run for them).
+ - reuse_lst: the batch indices of the roots that should backpropagate the reused value instead of the newly predicted one.
+ - reuse_value_lst: the list of reused values.
+ */
+ int count_a = 0;
+ int count_b = 0;
+ int count_c = 0;
+ float value_propagate = 0;
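+ // count_a walks through no_inference_lst (roots whose traversal stopped at the root and thus produced no new
+ // leaf to expand), count_b indexes the model outputs (value_prefixs / values / policies) for the expanded leaves,
+ // and count_c walks through reuse_lst (roots whose leaf value is replaced by the reused value). Both index lists
+ // end with a -1 sentinel appended on the Python side, so the comparisons below never run past the last real entry.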
+ for (int i = 0; i < results.num; ++i)
+ {
+ if (i == no_inference_lst[count_a])
+ {
+ count_a = count_a + 1;
+ value_propagate = reuse_value_lst[i];
+ }
+ else
+ {
+ results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, count_b, value_prefixs[count_b], policies[count_b]);
+ if (i == reuse_lst[count_c])
+ {
+ value_propagate = reuse_value_lst[i];
+ count_c = count_c + 1;
+ }
+ else
+ {
+ value_propagate = values[count_b];
+ }
+ count_b = count_b + 1;
+ }
+ results.nodes[i]->is_reset = is_reset_list[i];
+ cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], value_propagate, discount_factor);
+ }
+ }
+
int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
{
/*
@@ -646,6 +694,65 @@ namespace tree
return action;
}
+ int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value)
+ {
+ /*
+ Overview:
+ Select a child of the root node according to the ucb score; the child that matches the true action taken in the trajectory is scored with the reused value via carm_score.
+ Arguments:
+ - root: the roots to select the child node.
+ - min_max_stats: a tool used to min-max normalize the score.
+ - pb_c_base: constants c2 in muzero.
+ - pb_c_init: constants c1 in muzero.
+ - discount_factor: the discount factor of reward.
+ - mean_q: the mean q value of the parent node.
+ - players: the number of players.
+ - true_action: the action chosen in the trajectory.
+ - reuse_value: the value obtained from the search of the next state in the trajectory.
+ Returns:
+ - action: the action to select.
+ */
+
+ float max_score = FLOAT_MIN;
+ const float epsilon = 0.000001;
+ std::vector<int> max_index_lst;
+ for (auto a : root->legal_actions)
+ {
+
+ CNode *child = root->get_child(a);
+ float temp_score = 0.0;
+ if (a == true_action)
+ {
+ temp_score = carm_score(child, min_max_stats, mean_q, root->is_reset, reuse_value, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players);
+ }
+ else
+ {
+ temp_score = cucb_score(child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players);
+ }
+
+ if (max_score < temp_score)
+ {
+ max_score = temp_score;
+
+ max_index_lst.clear();
+ max_index_lst.push_back(a);
+ }
+ else if (temp_score >= max_score - epsilon)
+ {
+ max_index_lst.push_back(a);
+ }
+ }
+
+ int action = 0;
+ if (max_index_lst.size() > 0)
+ {
+ int rand_index = rand() % max_index_lst.size();
+ action = max_index_lst[rand_index];
+ }
+ // printf("select root child ends");
+ return action;
+ }
+
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players)
{
/*
@@ -706,6 +813,76 @@ namespace tree
return prior_score + value_score; // ucb_value
}
+ float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float reuse_value, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players)
+ {
+ /*
+ Overview:
+ Compute the score of the child that corresponds to the true action taken in the trajectory, using the
+ reused value from the search of the next state instead of the child's own estimated value.
+ Arguments:
+ - child: the child node to compute the score for.
+ - min_max_stats: a tool used to min-max normalize the score.
+ - parent_mean_q: the mean q value of the parent node.
+ - is_reset: whether the value prefix needs to be reset.
+ - reuse_value: the value obtained from the search of the next state in the trajectory.
+ - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
+ - parent_value_prefix: the value prefix of the parent node.
+ - pb_c_base: constant c2 in MuZero.
+ - pb_c_init: constant c1 in MuZero.
+ - discount_factor: the discount factor of reward.
+ - players: the number of players.
+ Returns:
+ - ucb_value: the score of the child.
+ */
+ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0;
+ pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
+ pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1));
+
+ prior_score = pb_c * child->prior;
+ if (child->visit_count == 0)
+ {
+ value_score = parent_mean_q;
+ }
+ else
+ {
+ float true_reward = child->value_prefix - parent_value_prefix;
+ if (is_reset == 1)
+ {
+ true_reward = child->value_prefix;
+ }
+
+ if (players == 1)
+ {
+ value_score = true_reward + discount_factor * reuse_value;
+ }
+ else if (players == 2)
+ {
+ value_score = true_reward + discount_factor * (-reuse_value);
+ }
+ }
+
+ value_score = min_max_stats.normalize(value_score);
+
+ if (value_score < 0)
+ {
+ value_score = 0;
+ }
+ else if (value_score > 1)
+ {
+ value_score = 1;
+ }
+
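+ // Once the true-action child has been visited, only the (normalized) reused value is returned and the
+ // prior/exploration bonus is dropped, so the search does not keep spending simulations on this known action.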
+ float ucb_value = 0.0;
+ if (child->visit_count == 0)
+ {
+ ucb_value = prior_score + value_score;
+ }
+ else
+ {
+ ucb_value = value_score;
+ }
+ // printf("carmscore ends");
+ return ucb_value;
+ }
+
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
{
/*
@@ -784,4 +961,113 @@ namespace tree
results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
}
}
+
+ void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch, std::vector<int> &true_action, std::vector<float> &reuse_value)
+ {
+ /*
+ Overview:
+ Search node path from the roots.
+ Arguments:
+ - roots: the roots that search from.
+ - pb_c_base: constants c2 in muzero.
+ - pb_c_init: constants c1 in muzero.
+ - discount_factor: the discount factor of reward.
+ - min_max_stats: a tool used to min-max normalize the score.
+ - results: the search results.
+ - virtual_to_play_batch: the batch of which player is playing on this node.
+ - true_action: the action chosen in the trajectory.
+ - reuse_value: the value obtained from the search of the next state in the trajectory.
+ */
+ // set seed
+ get_time_and_set_rand_seed();
+
+ int last_action = -1;
+ float parent_q = 0.0;
+ results.search_lens = std::vector<int>();
+
+ int players = 0;
+ int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // -1 in single-player envs, otherwise 1 or 2 in two-player board games
+ if (largest_element == -1)
+ {
+ players = 1;
+ }
+ else
+ {
+ players = 2;
+ }
+
+ for (int i = 0; i < results.num; ++i)
+ {
+ CNode *node = &(roots->roots[i]);
+ int is_root = 1;
+ int search_len = 0;
+ results.search_paths[i].push_back(node);
+
+ while (node->expanded())
+ {
+ float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
+ parent_q = mean_q;
+
+ int action = 0;
+ if (is_root)
+ {
+ action = cselect_root_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players, true_action[i], reuse_value[i]);
+ }
+ else
+ {
+ action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players);
+ }
+
+ if (players > 1)
+ {
+ assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
+ if (virtual_to_play_batch[i] == 1)
+ {
+ virtual_to_play_batch[i] = 2;
+ }
+ else
+ {
+ virtual_to_play_batch[i] = 1;
+ }
+ }
+
+ node->best_action = action;
+ // next
+ node = node->get_child(action);
+ last_action = action;
+ results.search_paths[i].push_back(node);
+ search_len += 1;
+
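+ // ReZero: if the root directly selects the action actually taken in the trajectory, stop the traversal
+ // here. The value of this child is reused from the search of the subsequent state, so no further
+ // expansion or model inference is needed for this root in this simulation.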
+ if(is_root && action == true_action[i])
+ {
+ break;
+ }
+
+ is_root = 0;
+ }
+
+ if (node->expanded())
+ {
+ results.latent_state_index_in_search_path.push_back(-1);
+ results.latent_state_index_in_batch.push_back(i);
+
+ results.last_actions.push_back(last_action);
+ results.search_lens.push_back(search_len);
+ results.nodes.push_back(node);
+ results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
+ }
+ else
+ {
+ CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];
+
+ results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
+ results.latent_state_index_in_batch.push_back(parent->batch_index);
+
+ results.last_actions.push_back(last_action);
+ results.search_lens.push_back(search_len);
+ results.nodes.push_back(node);
+ results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
+ }
+ }
+ }
}
\ No newline at end of file
diff --git a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h
index 52b6e6dfa..bca179151 100644
--- a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h
+++ b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h
@@ -83,9 +83,13 @@ namespace tree {
void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
void cbackpropagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch);
+ void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch, std::vector<int> &no_inference_lst, std::vector<int> &reuse_lst, std::vector<float> &reuse_value_lst);
int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players);
+ int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value);
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players);
+ float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float reuse_value, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players);
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
+ void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch, std::vector<int> &true_action, std::vector<float> &reuse_value);
}
#endif
\ No newline at end of file
diff --git a/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp
index 0fe9d8018..24eb3605c 100644
--- a/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp
+++ b/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp
@@ -499,6 +499,55 @@ namespace tree
}
}
+ void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch, std::vector<int> &no_inference_lst, std::vector<int> &reuse_lst, std::vector<float> &reuse_value_lst)
+ {
+ /*
+ Overview:
+ Expand the nodes along the search path and update the infos. Details are similar to cbatch_backpropagate, but with reuse value.
+ Please refer to https://arxiv.org/abs/2404.16364 for more details.
+ Arguments:
+ - current_latent_state_index: The index of latent state of the leaf node in the search path.
+ - discount_factor: the discount factor of reward.
+ - value_prefixs: the value prefixs of nodes along the search path.
+ - values: the values to propagate along the search path.
+ - policies: the policy logits of nodes along the search path.
+ - min_max_stats: a tool used to min-max normalize the q value.
+ - results: the search results.
+ - to_play_batch: the batch of which player is playing on this node.
+ - no_inference_lst: the batch indices of the roots whose leaf node does not need to be expanded (no model inference was run for them).
+ - reuse_lst: the batch indices of the roots that should backpropagate the reused value instead of the newly predicted one.
+ - reuse_value_lst: the list of reused values.
+ */
+ int count_a = 0;
+ int count_b = 0;
+ int count_c = 0;
+ float value_propagate = 0;
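+ // As in the EfficientZero version: count_a walks no_inference_lst (roots with no new leaf to expand),
+ // count_b indexes the model outputs for the expanded leaves, and count_c walks reuse_lst (roots whose
+ // leaf value is replaced by the reused value). Both index lists carry a trailing -1 sentinel from the Python side.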
+ for (int i = 0; i < results.num; ++i)
+ {
+ if (i == no_inference_lst[count_a])
+ {
+ count_a = count_a + 1;
+ value_propagate = reuse_value_lst[i];
+ }
+ else
+ {
+ results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, count_b, value_prefixs[count_b], policies[count_b]);
+ if (i == reuse_lst[count_c])
+ {
+ value_propagate = reuse_value_lst[i];
+ count_c = count_c + 1;
+ }
+ else
+ {
+ value_propagate = values[count_b];
+ }
+ count_b = count_b + 1;
+ }
+
+ cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], value_propagate, discount_factor);
+ }
+ }
+
int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
{
/*
@@ -546,6 +595,63 @@ namespace tree
return action;
}
+ int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value)
+ {
+ /*
+ Overview:
+ Select a child of the root node according to the ucb score; the child that matches the true action taken in the trajectory is scored with the reused value via carm_score.
+ Arguments:
+ - root: the roots to select the child node.
+ - min_max_stats: a tool used to min-max normalize the score.
+ - pb_c_base: constants c2 in muzero.
+ - pb_c_init: constants c1 in muzero.
+ - discount_factor: the discount factor of reward.
+ - mean_q: the mean q value of the parent node.
+ - players: the number of players.
+ - true_action: the action chosen in the trajectory.
+ - reuse_value: the value obtained from the search of the next state in the trajectory.
+ Returns:
+ - action: the action to select.
+ */
+ float max_score = FLOAT_MIN;
+ const float epsilon = 0.000001;
+ std::vector<int> max_index_lst;
+ for (auto a : root->legal_actions)
+ {
+
+ CNode *child = root->get_child(a);
+ float temp_score = 0.0;
+ if (a == true_action)
+ {
+ temp_score = carm_score(child, min_max_stats, mean_q, reuse_value, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players);
+ }
+ else
+ {
+ temp_score = cucb_score(child, min_max_stats, mean_q, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players);
+ }
+
+ if (max_score < temp_score)
+ {
+ max_score = temp_score;
+
+ max_index_lst.clear();
+ max_index_lst.push_back(a);
+ }
+ else if (temp_score >= max_score - epsilon)
+ {
+ max_index_lst.push_back(a);
+ }
+ }
+
+ int action = 0;
+ if (max_index_lst.size() > 0)
+ {
+ int rand_index = rand() % max_index_lst.size();
+ action = max_index_lst[rand_index];
+ }
+ return action;
+ }
+
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players)
{
/*
@@ -592,6 +698,60 @@ namespace tree
return ucb_value;
}
+
+ float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float reuse_value, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players)
+ {
+ /*
+ Overview:
+ Compute the score of the child that corresponds to the true action taken in the trajectory, using the
+ reused value from the search of the next state instead of the child's own estimated value.
+ Arguments:
+ - child: the child node to compute the score for.
+ - min_max_stats: a tool used to min-max normalize the score.
+ - parent_mean_q: the mean q value of the parent node.
+ - reuse_value: the value obtained from the search of the next state in the trajectory.
+ - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
+ - pb_c_base: constant c2 in MuZero.
+ - pb_c_init: constant c1 in MuZero.
+ - discount_factor: the discount factor of reward.
+ - players: the number of players.
+ Returns:
+ - ucb_value: the score of the child.
+ */
+ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0;
+ pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
+ pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1));
+
+ prior_score = pb_c * child->prior;
+ if (child->visit_count == 0)
+ {
+ value_score = parent_mean_q;
+ }
+ else
+ {
+ float true_reward = child->reward;
+ if (players == 1)
+ value_score = true_reward + discount_factor * reuse_value;
+ else if (players == 2)
+ value_score = true_reward + discount_factor * (-reuse_value);
+ }
+
+ value_score = min_max_stats.normalize(value_score);
+
+ if (value_score < 0)
+ value_score = 0;
+ if (value_score > 1)
+ value_score = 1;
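+ // Once the true-action child has been visited, only the (normalized) reused value is returned and the
+ // prior/exploration bonus is dropped, so simulations are not spent re-exploring the known trajectory action.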
+ float ucb_value = 0.0;
+ if (child->visit_count == 0)
+ {
+ ucb_value = prior_score + value_score;
+ }
+ else
+ {
+ ucb_value = value_score;
+ }
+ return ucb_value;
+ }
+
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
{
/*
@@ -663,4 +823,108 @@ namespace tree
}
}
+
+ void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch, std::vector<int> &true_action, std::vector<float> &reuse_value)
+ {
+ /*
+ Overview:
+ Search node path from the roots. Details are similar to cbatch_traverse, but with reuse value.
+ Please refer to https://arxiv.org/abs/2404.16364 for more details.
+ Arguments:
+ - roots: the roots that search from.
+ - pb_c_base: constants c2 in muzero.
+ - pb_c_init: constants c1 in muzero.
+ - discount_factor: the discount factor of reward.
+ - min_max_stats: a tool used to min-max normalize the score.
+ - results: the search results.
+ - virtual_to_play_batch: the batch of which player is playing on this node.
+ - true_action: the action chosen in the trajectory.
+ - reuse_value: the value obtained from the search of the next state in the trajectory.
+ */
+ // set seed
+ get_time_and_set_rand_seed();
+
+ int last_action = -1;
+ float parent_q = 0.0;
+ results.search_lens = std::vector<int>();
+
+ int players = 0;
+ int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // -1 in single-player envs, otherwise 1 or 2 in two-player board games
+ if (largest_element == -1)
+ players = 1;
+ else
+ players = 2;
+
+ for (int i = 0; i < results.num; ++i)
+ {
+ CNode *node = &(roots->roots[i]);
+ int is_root = 1;
+ int search_len = 0;
+ results.search_paths[i].push_back(node);
+
+ while (node->expanded())
+ {
+ float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
+ parent_q = mean_q;
+
+ int action = 0;
+ if (is_root)
+ {
+ action = cselect_root_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players, true_action[i], reuse_value[i]);
+ }
+ else
+ {
+ action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players);
+ }
+
+ if (players > 1)
+ {
+ assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
+ if (virtual_to_play_batch[i] == 1)
+ virtual_to_play_batch[i] = 2;
+ else
+ virtual_to_play_batch[i] = 1;
+ }
+
+ node->best_action = action;
+ // next
+ node = node->get_child(action);
+ last_action = action;
+ results.search_paths[i].push_back(node);
+ search_len += 1;
+
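+ // ReZero: if the root directly selects the action actually taken in the trajectory, stop the traversal
+ // here; the value of this child is reused from the search of the subsequent state, so no further
+ // expansion or model inference is needed for this root in this simulation.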
+ if(is_root && action == true_action[i])
+ {
+ break;
+ }
+
+ is_root = 0;
+
+ }
+
+ if (node->expanded())
+ {
+ results.latent_state_index_in_search_path.push_back(-1);
+ results.latent_state_index_in_batch.push_back(i);
+
+ results.last_actions.push_back(last_action);
+ results.search_lens.push_back(search_len);
+ results.nodes.push_back(node);
+ results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
+ }
+ else
+ {
+ CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];
+
+ results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
+ results.latent_state_index_in_batch.push_back(parent->batch_index);
+
+ results.last_actions.push_back(last_action);
+ results.search_lens.push_back(search_len);
+ results.nodes.push_back(node);
+ results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
+ }
+ }
+ }
+
}
\ No newline at end of file
diff --git a/lzero/mcts/ctree/ctree_muzero/lib/cnode.h b/lzero/mcts/ctree/ctree_muzero/lib/cnode.h
index 1c5fdea6e..c8f1c9091 100644
--- a/lzero/mcts/ctree/ctree_muzero/lib/cnode.h
+++ b/lzero/mcts/ctree/ctree_muzero/lib/cnode.h
@@ -83,9 +83,13 @@ namespace tree {
void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
void cbackpropagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch);
+ void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch, std::vector<int> &no_inference_lst, std::vector<int> &reuse_lst, std::vector<float> &reuse_value_lst);
int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players);
+ int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value);
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players);
+ float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float reuse_value, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players);
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
+ void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch, std::vector<int> &true_action, std::vector<float> &reuse_value);
}
#endif
\ No newline at end of file
diff --git a/lzero/mcts/ctree/ctree_muzero/mz_tree.pxd b/lzero/mcts/ctree/ctree_muzero/mz_tree.pxd
index b46f86912..6c65f4466 100644
--- a/lzero/mcts/ctree/ctree_muzero/mz_tree.pxd
+++ b/lzero/mcts/ctree/ctree_muzero/mz_tree.pxd
@@ -70,4 +70,7 @@ cdef extern from "lib/cnode.h" namespace "tree":
cdef void cbackpropagate(vector[CNode*] &search_path, CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, vector[float] values, vector[vector[float]] policies,
CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &to_play_batch)
+ void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, vector[float] values, vector[vector[float]] policies,
+ CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &to_play_batch, vector[int] &no_inference_lst, vector[int] &reuse_lst, vector[float] &reuse_value_lst)
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &virtual_to_play_batch)
+ void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &virtual_to_play_batch, vector[int] &true_action, vector[float] &reuse_value)
diff --git a/lzero/mcts/ctree/ctree_muzero/mz_tree.pyx b/lzero/mcts/ctree/ctree_muzero/mz_tree.pyx
index 80e5c4505..506e69d57 100644
--- a/lzero/mcts/ctree/ctree_muzero/mz_tree.pyx
+++ b/lzero/mcts/ctree/ctree_muzero/mz_tree.pyx
@@ -81,9 +81,27 @@ def batch_backpropagate(int current_latent_state_index, float discount_factor, l
cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies,
min_max_stats_lst.cmin_max_stats_lst, results.cresults, to_play_batch)
+def batch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies,
+ MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list to_play_batch, list no_inference_lst, list reuse_lst, list reuse_value_lst):
+ cdef int i
+ cdef vector[float] cvalue_prefixs = value_prefixs
+ cdef vector[float] cvalues = values
+ cdef vector[vector[float]] cpolicies = policies
+ cdef vector[float] creuse_value_lst = reuse_value_lst
+
+ cbatch_backpropagate_with_reuse(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies,
+ min_max_stats_lst.cmin_max_stats_lst, results.cresults, to_play_batch, no_inference_lst, reuse_lst, creuse_value_lst)
+
def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst,
ResultsWrapper results, list virtual_to_play_batch):
cbatch_traverse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults,
virtual_to_play_batch)
return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs
+
+def batch_traverse_with_reuse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst,
+ ResultsWrapper results, list virtual_to_play_batch, list true_action, list reuse_value):
+ cbatch_traverse_with_reuse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults,
+ virtual_to_play_batch, true_action, reuse_value)
+
+ return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs
diff --git a/lzero/mcts/ptree/ptree_mz.py b/lzero/mcts/ptree/ptree_mz.py
index 909d48e4f..afb6f7667 100644
--- a/lzero/mcts/ptree/ptree_mz.py
+++ b/lzero/mcts/ptree/ptree_mz.py
@@ -3,7 +3,7 @@
"""
import math
import random
-from typing import List, Dict, Any, Tuple, Union
+from typing import List, Any, Tuple, Union
import numpy as np
import torch
@@ -19,6 +19,7 @@ class Node:
``__init__``, ``expand``, ``add_exploration_noise``, ``compute_mean_q``, ``get_trajectory``, \
``get_children_distribution``, ``get_child``, ``expanded``, ``value``.
"""
+
def __init__(self, prior: float, legal_actions: List = None, action_space_size: int = 9) -> None:
"""
Overview:
@@ -412,7 +413,7 @@ def compute_ucb_score(
value_score = 0
if value_score > 1:
value_score = 1
-
+
ucb_score = prior_score + value_score
return ucb_score
diff --git a/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py b/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py
index 8734298e3..ef5d79233 100644
--- a/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py
+++ b/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py
@@ -16,7 +16,7 @@
# ==============================================================
tictactoe_muzero_config = dict(
- exp_name='data_mz_ctree/tictactoe_muzero_bot_mode_seed0',
+ exp_name='data_muzero/tictactoe_muzero_bot_mode_seed0',
env=dict(
battle_mode='play_with_bot_mode',
collector_env_num=collector_env_num,
diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py
index 76131d01a..e7543ab9d 100644
--- a/lzero/mcts/tree_search/mcts_ctree.py
+++ b/lzero/mcts/tree_search/mcts_ctree.py
@@ -6,10 +6,13 @@
from easydict import EasyDict
from lzero.mcts.ctree.ctree_efficientzero import ez_tree as tree_efficientzero
-from lzero.mcts.ctree.ctree_muzero import mz_tree as tree_muzero
from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as tree_gumbel_muzero
+from lzero.mcts.ctree.ctree_muzero import mz_tree as tree_muzero
from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy
if TYPE_CHECKING:
from lzero.mcts.ctree.ctree_efficientzero import ez_tree as ez_ctree
from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree
@@ -186,6 +189,105 @@ def search(
min_max_stats_lst, results, virtual_to_play_batch
)
+ def search_with_reuse(
+ self,
+ roots: Any,
+ model: torch.nn.Module,
+ latent_state_roots: List[Any],
+ to_play_batch: Union[int, List[Any]],
+ true_action_list=None,
+ reuse_value_list=None
+ ) -> None:
+ """
+ Overview:
+ Perform Monte Carlo Tree Search (MCTS) for the root nodes in parallel. Utilizes the cpp ctree for efficiency.
+ Please refer to https://arxiv.org/abs/2404.16364 for more details.
+ Arguments:
+ - roots (:obj:`Any`): A batch of expanded root nodes.
+ - model (:obj:`torch.nn.Module`): The neural network model.
+ - latent_state_roots (:obj:`list`): The hidden states of the root nodes.
+ - to_play_batch (:obj:`Union[int, list]`): The list or batch indicator for players in self-play mode.
+ - true_action_list (:obj:`list`, optional): A list of true actions for reuse.
+ - reuse_value_list (:obj:`list`, optional): A list of values for reuse.
+ """
+
+ with torch.no_grad():
+ model.eval()
+
+ # Initialize constants and variables
+ batch_size = roots.num
+ pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor
+ latent_state_batch_in_search_path = [latent_state_roots]
+ min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size)
+ min_max_stats_lst.set_delta(self._cfg.value_delta_max)
+ infer_sum = 0
+
+ for simulation_index in range(self._cfg.num_simulations):
+ latent_states = []
+ temp_actions = []
+ no_inference_lst = []
+ reuse_lst = []
+ results = tree_muzero.ResultsWrapper(num=batch_size)
+
+ # Selection phase: traverse the tree to select a leaf node
+ if self._cfg.env_type == 'not_board_games':
+ latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_muzero.batch_traverse_with_reuse(
+ roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results,
+ to_play_batch, true_action_list, reuse_value_list
+ )
+ else:
+ latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_muzero.batch_traverse_with_reuse(
+ roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results,
+ copy.deepcopy(to_play_batch), true_action_list, reuse_value_list
+ )
+
+ # Collect latent states and actions for expansion
+ for count, (ix, iy) in enumerate(zip(latent_state_index_in_search_path, latent_state_index_in_batch)):
+ if ix != -1:
+ latent_states.append(latent_state_batch_in_search_path[ix][iy])
+ temp_actions.append(last_actions[count])
+ else:
+ no_inference_lst.append(iy)
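+ # ix == -1: the traversal stopped at the root because the true-action child is already expanded,
+ # so this root needs no model inference in this simulation.
+ # ix == 0 with the true action selected: the true-action child is a new leaf whose value will be
+ # replaced by the reused value during backpropagation.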
+ if ix == 0 and last_actions[count] == true_action_list[count]:
+ reuse_lst.append(count)
+
+ length = len(temp_actions)
+ latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device)
+ temp_actions = torch.from_numpy(np.asarray(temp_actions)).to(self._cfg.device).long()
+
+ # Expansion phase: expand the leaf node and evaluate the new node
+ if length != 0:
+ network_output = model.recurrent_inference(latent_states, temp_actions)
+ network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
+ network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
+ network_output.value = to_detach_cpu_numpy(
+ self.inverse_scalar_transform_handle(network_output.value))
+ network_output.reward = to_detach_cpu_numpy(
+ self.inverse_scalar_transform_handle(network_output.reward))
+
+ latent_state_batch_in_search_path.append(network_output.latent_state)
+ reward_batch = network_output.reward.reshape(-1).tolist()
+ value_batch = network_output.value.reshape(-1).tolist()
+ policy_logits_batch = network_output.policy_logits.tolist()
+ else:
+ latent_state_batch_in_search_path.append([])
+ reward_batch = []
+ value_batch = []
+ policy_logits_batch = []
+
+ # Backup phase: propagate the evaluation results back through the tree
+ current_latent_state_index = simulation_index + 1
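+ # Append a -1 sentinel so the C++ backpropagation never indexes past the end of these lists.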
+ no_inference_lst.append(-1)
+ reuse_lst.append(-1)
+ tree_muzero.batch_backpropagate_with_reuse(
+ current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch,
+ min_max_stats_lst, results, virtual_to_play_batch, no_inference_lst, reuse_lst, reuse_value_list
+ )
+ infer_sum += length
+
+ average_infer = infer_sum / self._cfg.num_simulations
+ return length, average_infer
+
class EfficientZeroMCTSCtree(object):
"""
@@ -288,6 +390,8 @@ def search(
# the data storage of latent states: storing the latent state of all the nodes in one search.
latent_state_batch_in_search_path = [latent_state_roots]
# the data storage of value prefix hidden states in LSTM
+ # print(f"reward_hidden_state_roots[0]={reward_hidden_state_roots[0]}")
+ # print(f"reward_hidden_state_roots[1]={reward_hidden_state_roots[1]}")
reward_hidden_state_c_batch = [reward_hidden_state_roots[0]]
reward_hidden_state_h_batch = [reward_hidden_state_roots[1]]
@@ -354,7 +458,8 @@ def search(
network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
- network_output.value_prefix = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value_prefix))
+ network_output.value_prefix = to_detach_cpu_numpy(
+ self.inverse_scalar_transform_handle(network_output.value_prefix))
network_output.reward_hidden_state = (
network_output.reward_hidden_state[0].detach().cpu().numpy(),
@@ -390,6 +495,133 @@ def search(
min_max_stats_lst, results, is_reset_list, virtual_to_play_batch
)
+ def search_with_reuse(
+ self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any],
+ reward_hidden_state_roots: List[Any], to_play_batch: Union[int, List[Any]],
+ true_action_list=None, reuse_value_list=None
+ ) -> None:
+ """
+ Perform Monte Carlo Tree Search (MCTS) for the root nodes in parallel, utilizing model inference in parallel.
+ This method uses the cpp ctree for efficiency.
+ Please refer to https://arxiv.org/abs/2404.16364 for more details.
+
+ Args:
+ roots (Any): A batch of expanded root nodes.
+ model (torch.nn.Module): The model to use for inference.
+ latent_state_roots (List[Any]): The hidden states of the root nodes.
+ reward_hidden_state_roots (List[Any]): The value prefix hidden states in the LSTM of the roots.
+ to_play_batch (Union[int, List[Any]]): The to_play_batch list used in self-play-mode board games.
+ true_action_list (Optional[List[Any]]): List of true actions for reuse.
+ reuse_value_list (Optional[List[Any]]): List of values for reuse.
+
+ Returns:
+ length (int): The number of nodes expanded in the last simulation.
+ average_infer (float): The average number of model inferences per simulation.
+ """
+ with torch.no_grad():
+ model.eval()
+
+ batch_size = roots.num
+ pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor
+
+ latent_state_batch_in_search_path = [latent_state_roots]
+ reward_hidden_state_c_batch = [reward_hidden_state_roots[0]]
+ reward_hidden_state_h_batch = [reward_hidden_state_roots[1]]
+
+ min_max_stats_lst = tree_efficientzero.MinMaxStatsList(batch_size)
+ min_max_stats_lst.set_delta(self._cfg.value_delta_max)
+
+ infer_sum = 0
+
+ for simulation_index in range(self._cfg.num_simulations):
+ latent_states, hidden_states_c_reward, hidden_states_h_reward = [], [], []
+ temp_actions, temp_search_lens, no_inference_lst, reuse_lst = [], [], [], []
+
+ results = tree_efficientzero.ResultsWrapper(num=batch_size)
+
+ if self._cfg.env_type == 'not_board_games':
+ latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_efficientzero.batch_traverse_with_reuse(
+ roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results,
+ to_play_batch, true_action_list, reuse_value_list
+ )
+ else:
+ latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_efficientzero.batch_traverse_with_reuse(
+ roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results,
+ copy.deepcopy(to_play_batch), true_action_list, reuse_value_list
+ )
+
+ search_lens = results.get_search_len()
+
+ for count, (ix, iy) in enumerate(zip(latent_state_index_in_search_path, latent_state_index_in_batch)):
+ if ix != -1:
+ latent_states.append(latent_state_batch_in_search_path[ix][iy])
+ hidden_states_c_reward.append(reward_hidden_state_c_batch[ix][0][iy])
+ hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy])
+ temp_actions.append(last_actions[count])
+ temp_search_lens.append(search_lens[count])
+ else:
+ no_inference_lst.append(iy)
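+ # As in the MuZero version: ix == -1 marks roots that stopped at an already-expanded true-action child
+ # (no inference needed); ix == 0 with the true action selected marks new leaves whose value is replaced
+ # by the reused value during backpropagation.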
+ if ix == 0 and last_actions[count] == true_action_list[count]:
+ reuse_lst.append(count)
+
+ length = len(temp_actions)
+ latent_states = torch.tensor(latent_states, device=self._cfg.device)
+ hidden_states_c_reward = torch.tensor(hidden_states_c_reward, device=self._cfg.device).unsqueeze(0)
+ hidden_states_h_reward = torch.tensor(hidden_states_h_reward, device=self._cfg.device).unsqueeze(0)
+ temp_actions = torch.tensor(temp_actions, device=self._cfg.device).long()
+
+ if length != 0:
+ network_output = model.recurrent_inference(
+ latent_states, (hidden_states_c_reward, hidden_states_h_reward), temp_actions
+ )
+
+ network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
+ network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
+ network_output.value = to_detach_cpu_numpy(
+ self.inverse_scalar_transform_handle(network_output.value))
+ network_output.value_prefix = to_detach_cpu_numpy(
+ self.inverse_scalar_transform_handle(network_output.value_prefix))
+
+ network_output.reward_hidden_state = (
+ network_output.reward_hidden_state[0].detach().cpu().numpy(),
+ network_output.reward_hidden_state[1].detach().cpu().numpy()
+ )
+
+ latent_state_batch_in_search_path.append(network_output.latent_state)
+ value_prefix_batch = network_output.value_prefix.reshape(-1).tolist()
+ value_batch = network_output.value.reshape(-1).tolist()
+ policy_logits_batch = network_output.policy_logits.tolist()
+
+ reward_latent_state_batch = network_output.reward_hidden_state
+ assert self._cfg.lstm_horizon_len > 0
+ reset_idx = (np.array(temp_search_lens) % self._cfg.lstm_horizon_len == 0)
+ reward_latent_state_batch[0][:, reset_idx, :] = 0
+ reward_latent_state_batch[1][:, reset_idx, :] = 0
+ is_reset_list = reset_idx.astype(np.int32).tolist()
+ reward_hidden_state_c_batch.append(reward_latent_state_batch[0])
+ reward_hidden_state_h_batch.append(reward_latent_state_batch[1])
+ else:
+ latent_state_batch_in_search_path.append([])
+ value_batch, policy_logits_batch, value_prefix_batch = [], [], []
+ reward_hidden_state_c_batch.append([])
+ reward_hidden_state_h_batch.append([])
+ assert self._cfg.lstm_horizon_len > 0
+ reset_idx = (np.array(search_lens) % self._cfg.lstm_horizon_len == 0)
+ assert len(reset_idx) == batch_size
+ is_reset_list = reset_idx.astype(np.int32).tolist()
+
+ current_latent_state_index = simulation_index + 1
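+ # Append a -1 sentinel so the C++ backpropagation never indexes past the end of these lists.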
+ no_inference_lst.append(-1)
+ reuse_lst.append(-1)
+ tree_efficientzero.batch_backpropagate_with_reuse(
+ current_latent_state_index, discount_factor, value_prefix_batch, value_batch, policy_logits_batch,
+ min_max_stats_lst, results, is_reset_list, virtual_to_play_batch, no_inference_lst, reuse_lst,
+ reuse_value_list
+ )
+ infer_sum += length
+
+ average_infer = infer_sum / self._cfg.num_simulations
+ return length, average_infer
+
class GumbelMuZeroMCTSCtree(object):
"""
@@ -446,7 +678,7 @@ def __init__(self, cfg: EasyDict = None) -> None:
self.inverse_scalar_transform_handle = InverseScalarTransform(
self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution
)
-
+
@classmethod
def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "gmz_ctree":
"""
@@ -462,8 +694,8 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "g
return tree_gumbel_muzero.Roots(active_collect_env_num, legal_actions)
def search(self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int,
- List[Any]]
- ) -> None:
+ List[Any]]
+ ) -> None:
"""
Overview:
Do MCTS for a batch of roots. Parallel in model inference. \
@@ -511,12 +743,14 @@ def search(self, roots: Any, model: torch.nn.Module, latent_state_roots: List[An
"""
if self._cfg.env_type == 'not_board_games':
latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_gumbel_muzero.batch_traverse(
- roots, self._cfg.num_simulations, self._cfg.max_num_considered_actions, discount_factor, results, to_play_batch
+ roots, self._cfg.num_simulations, self._cfg.max_num_considered_actions, discount_factor,
+ results, to_play_batch
)
else:
# the ``to_play_batch`` is only used in board games, here we need to deepcopy it to avoid changing the original data.
latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_gumbel_muzero.batch_traverse(
- roots, self._cfg.num_simulations, self._cfg.max_num_considered_actions, discount_factor, results, copy.deepcopy(to_play_batch)
+ roots, self._cfg.num_simulations, self._cfg.max_num_considered_actions, discount_factor,
+ results, copy.deepcopy(to_play_batch)
)
# obtain the states for leaf nodes
diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py
index c0b7e5d7a..6589198a8 100644
--- a/lzero/policy/alphazero.py
+++ b/lzero/policy/alphazero.py
@@ -50,10 +50,10 @@ class AlphaZeroPolicy(Policy):
# collect data -> update policy-> collect data -> ...
# For different env, we have different episode_length,
# we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
- # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically.
+ # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
update_per_collect=None,
# (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is None.
- model_update_ratio=0.1,
+ replay_ratio=0.25,
# (int) Minibatch size for one gradient descent.
batch_size=256,
# (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW']
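A minimal sketch of the `update_per_collect` fallback described in the comment above, assuming a hypothetical helper name (`resolve_update_per_collect`) and a stand-in `collected_transitions_num`: when `update_per_collect` is None, the number of gradient steps per collection is derived from the freshly collected transitions and `replay_ratio`.

```python
# Hedged sketch: a hypothetical helper mirroring the fallback described in the
# comment above; not the actual LightZero implementation.
def resolve_update_per_collect(update_per_collect, collected_transitions_num, replay_ratio=0.25):
    """Derive the number of gradient steps per collection when it is not given explicitly."""
    if update_per_collect is None:
        # replay_ratio scales the freshly collected transitions into training steps
        update_per_collect = max(int(collected_transitions_num * replay_ratio), 1)
    return update_per_collect


print(resolve_update_per_collect(None, collected_transitions_num=3200))  # -> 800
print(resolve_update_per_collect(50, collected_transitions_num=3200))    # -> 50 (explicit value wins)
```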
diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py
index 3a94baf51..9f747e0db 100644
--- a/lzero/policy/efficientzero.py
+++ b/lzero/policy/efficientzero.py
@@ -111,10 +111,10 @@ class EfficientZeroPolicy(MuZeroPolicy):
# collect data -> update policy-> collect data -> ...
# For different env, we have different episode_length,
# we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor
- # if we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically.
+ # if we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
update_per_collect=None,
# (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is None.
- model_update_ratio=0.1,
+ replay_ratio=0.25,
# (int) Minibatch size for one gradient descent.
batch_size=256,
# (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW']
@@ -165,7 +165,11 @@ class EfficientZeroPolicy(MuZeroPolicy):
# (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048.
use_ture_chance_label_in_chance_encoder=False,
# (bool) Whether to add noise to roots during reanalyze process.
- reanalyze_noise=False,
+ reanalyze_noise=True,
+ # (bool) Whether to reuse the root value between batch searches.
+ reuse_search=True,
+        # (bool) Whether to use the pure policy to collect data. If False, collect data with MCTS guided by the policy.
+ collect_with_pure_policy=False,
# ****** Priority ******
# (bool) Whether to use priority when sampling training data from the buffer.
@@ -582,58 +586,77 @@ def _forward_collect(
policy_logits = policy_logits.detach().cpu().numpy().tolist()
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
- # the only difference between collect and eval is the dirichlet noise.
- noises = [
- np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
- ).astype(np.float32).tolist() for j in range(active_collect_env_num)
- ]
- if self._cfg.mcts_ctree:
- # cpp mcts_tree
- roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
- else:
- # python mcts_tree
- roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
- roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play)
- self._mcts_collect.search(
- roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play
- )
- roots_visit_count_distributions = roots.get_distributions()
- roots_values = roots.get_values() # shape: {list: batch_size}
+ if not self._cfg.collect_with_pure_policy:
+ # the only difference between collect and eval is the dirichlet noise.
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
+ ).astype(np.float32).tolist() for j in range(active_collect_env_num)
+ ]
+ if self._cfg.mcts_ctree:
+ # cpp mcts_tree
+ roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
+ else:
+ # python mcts_tree
+ roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
+ roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play)
+ self._mcts_collect.search(
+ roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play
+ )
- data_id = [i for i in range(active_collect_env_num)]
- output = {i: None for i in data_id}
- if ready_env_id is None:
- ready_env_id = np.arange(active_collect_env_num)
+ roots_visit_count_distributions = roots.get_distributions()
+ roots_values = roots.get_values() # shape: {list: batch_size}
+
+ data_id = [i for i in range(active_collect_env_num)]
+ output = {i: None for i in data_id}
+ if ready_env_id is None:
+ ready_env_id = np.arange(active_collect_env_num)
+
+ for i, env_id in enumerate(ready_env_id):
+ distributions, value = roots_visit_count_distributions[i], roots_values[i]
+ if self._cfg.eps.eps_greedy_exploration_in_collect:
+ # eps-greedy collect
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+ distributions, temperature=self._collect_mcts_temperature, deterministic=True
+ )
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+ if np.random.rand() < self.collect_epsilon:
+ action = np.random.choice(legal_actions[i])
+ else:
+ # normal collect
+ # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
+ # the index within the legal action set, rather than the index in the entire action set.
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+ distributions, temperature=self._collect_mcts_temperature, deterministic=False
+ )
+ # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+ output[env_id] = {
+ 'action': action,
+ 'visit_count_distributions': distributions,
+ 'visit_count_distribution_entropy': visit_count_distribution_entropy,
+ 'searched_value': value,
+ 'predicted_value': pred_values[i],
+ 'predicted_policy_logits': policy_logits[i],
+ }
+ else:
+ data_id = [i for i in range(active_collect_env_num)]
+ output = {i: None for i in data_id}
- for i, env_id in enumerate(ready_env_id):
- distributions, value = roots_visit_count_distributions[i], roots_values[i]
- if self._cfg.eps.eps_greedy_exploration_in_collect:
- # eps-greedy collect
- action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
- distributions, temperature=self._collect_mcts_temperature, deterministic=True
- )
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
- if np.random.rand() < self.collect_epsilon:
- action = np.random.choice(legal_actions[i])
- else:
- # normal collect
- # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
- # the index within the legal action set, rather than the index in the entire action set.
- action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
- distributions, temperature=self._collect_mcts_temperature, deterministic=False
- )
- # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
- output[env_id] = {
- 'action': action,
- 'visit_count_distributions': distributions,
- 'visit_count_distribution_entropy': visit_count_distribution_entropy,
- 'searched_value': value,
- 'predicted_value': pred_values[i],
- 'predicted_policy_logits': policy_logits[i],
- }
+ if ready_env_id is None:
+ ready_env_id = np.arange(active_collect_env_num)
+ for i, env_id in enumerate(ready_env_id):
+ policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), dim=0).tolist()
+ policy_values = policy_values / np.sum(policy_values)
+ action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values)
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+ output[env_id] = {
+ 'action': action,
+ 'searched_value': pred_values[i],
+ 'predicted_value': pred_values[i],
+ 'predicted_policy_logits': policy_logits[i],
+ }
return output
def _init_eval(self) -> None:
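A minimal sketch of the new `collect_with_pure_policy` branch above: instead of running MCTS, an action is sampled from the softmax of the policy logits restricted to the legal actions and then mapped back to the full action space. Function and variable names below are simplified stand-ins, not the policy's actual API.

```python
# Hedged sketch with simplified stand-in names; not the policy's actual _forward_collect.
import numpy as np
import torch


def pure_policy_action(policy_logits, legal_actions, action_mask):
    # softmax over the legal subset of logits, then sample an index within that subset
    probs = torch.softmax(torch.tensor([policy_logits[a] for a in legal_actions]), dim=0).numpy()
    probs = probs / probs.sum()  # guard against floating-point drift before np.random.choice
    index_in_legal_set = np.random.choice(len(legal_actions), p=probs)
    # map the legal-set index back to the full action space via the action mask
    return int(np.where(np.array(action_mask) == 1.0)[0][index_in_legal_set])


print(pure_policy_action([0.1, 2.0, -1.0], legal_actions=[0, 1], action_mask=[1, 1, 0]))
```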
diff --git a/lzero/policy/gumbel_alphazero.py b/lzero/policy/gumbel_alphazero.py
index 0dcb53982..dae84e2c5 100644
--- a/lzero/policy/gumbel_alphazero.py
+++ b/lzero/policy/gumbel_alphazero.py
@@ -51,10 +51,10 @@ class GumbelAlphaZeroPolicy(Policy):
# collect data -> update policy-> collect data -> ...
# For different env, we have different episode_length,
# we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
- # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically.
+ # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
update_per_collect=None,
# (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is None.
- model_update_ratio=0.1,
+ replay_ratio=0.25,
# (int) Minibatch size for one gradient descent.
batch_size=256,
# (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW']
diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py
index c04fbb3c3..aeecfae67 100644
--- a/lzero/policy/gumbel_muzero.py
+++ b/lzero/policy/gumbel_muzero.py
@@ -111,10 +111,10 @@ class GumbelMuZeroPolicy(MuZeroPolicy):
# Bigger "update_per_collect" means bigger off-policy.
# collect data -> update policy-> collect data -> ...
# For different env, we have different episode_length,
- # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically.
+ # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
update_per_collect=None,
# (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is None.
- model_update_ratio=0.1,
+ replay_ratio=0.25,
# (int) Minibatch size for one gradient descent.
batch_size=256,
# (str) Optimizer for training policy network. ['SGD' or 'Adam']
@@ -162,7 +162,7 @@ class GumbelMuZeroPolicy(MuZeroPolicy):
# The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
fixed_temperature_value=0.25,
# (bool) Whether to add noise to roots during reanalyze process.
- reanalyze_noise=False,
+ reanalyze_noise=True,
# ****** Priority ******
# (bool) Whether to use priority when sampling training data from the buffer.
diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py
index 0acc66b07..66736047a 100644
--- a/lzero/policy/muzero.py
+++ b/lzero/policy/muzero.py
@@ -87,7 +87,7 @@ class MuZeroPolicy(Policy):
# (bool) Whether to monitor extra statistics in tensorboard.
monitor_extra_statistics=True,
# (int) The transition number of one ``GameSegment``.
- game_segment_length=200,
+ game_segment_length=400,
# (bool): Indicates whether to perform an offline evaluation of the checkpoint (ckpt).
# If set to True, the checkpoint will be evaluated after the training process is complete.
# IMPORTANT: Setting eval_offline to True requires configuring the saving of checkpoints to align with the evaluation frequency.
@@ -114,10 +114,10 @@ class MuZeroPolicy(Policy):
# collect data -> update policy-> collect data -> ...
# For different env, we have different episode_length,
# we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
- # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically.
+ # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
update_per_collect=None,
# (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is None.
- model_update_ratio=0.1,
+ replay_ratio=0.25,
# (int) Minibatch size for one gradient descent.
batch_size=256,
# (str) Optimizer for training policy network. ['SGD', 'Adam']
@@ -169,7 +169,11 @@ class MuZeroPolicy(Policy):
# (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048.
use_ture_chance_label_in_chance_encoder=False,
# (bool) Whether to add noise to roots during reanalyze process.
- reanalyze_noise=False,
+ reanalyze_noise=True,
+ # (bool) Whether to reuse the root value between batch searches.
+ reuse_search=True,
+        # (bool) Whether to use the pure policy to collect data. If False, collect data with MCTS guided by the policy.
+ collect_with_pure_policy=False,
# ****** Priority ******
# (bool) Whether to use priority when sampling training data from the buffer.
@@ -446,11 +450,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
self._cfg.policy_entropy_loss_weight * policy_entropy_loss
)
weighted_total_loss = (weights * loss).mean()
-
gradient_scale = 1 / self._cfg.num_unroll_steps
weighted_total_loss.register_hook(lambda grad: grad * gradient_scale)
self._optimizer.zero_grad()
weighted_total_loss.backward()
+
if self._cfg.multi_gpu:
self.sync_gradients(self._learn_model)
total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(),
@@ -556,58 +560,78 @@ def _forward_collect(
policy_logits = policy_logits.detach().cpu().numpy().tolist()
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
- # the only difference between collect and eval is the dirichlet noise
- noises = [
- np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
- ).astype(np.float32).tolist() for j in range(active_collect_env_num)
- ]
- if self._cfg.mcts_ctree:
- # cpp mcts_tree
- roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
- else:
- # python mcts_tree
- roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
-
- roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
- self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play)
- # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
- roots_visit_count_distributions = roots.get_distributions()
- roots_values = roots.get_values() # shape: {list: batch_size}
+ if not self._cfg.collect_with_pure_policy:
+ # the only difference between collect and eval is the dirichlet noise
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
+ ).astype(np.float32).tolist() for j in range(active_collect_env_num)
+ ]
+ if self._cfg.mcts_ctree:
+ # cpp mcts_tree
+ roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
+ else:
+ # python mcts_tree
+ roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
+
+ roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
+ self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play)
+
+ # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
+ roots_visit_count_distributions = roots.get_distributions()
+ roots_values = roots.get_values() # shape: {list: batch_size}
+
+ data_id = [i for i in range(active_collect_env_num)]
+ output = {i: None for i in data_id}
+
+ if ready_env_id is None:
+ ready_env_id = np.arange(active_collect_env_num)
+
+ for i, env_id in enumerate(ready_env_id):
+ distributions, value = roots_visit_count_distributions[i], roots_values[i]
+ if self._cfg.eps.eps_greedy_exploration_in_collect:
+ # eps greedy collect
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+ distributions, temperature=self._collect_mcts_temperature, deterministic=True
+ )
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+ if np.random.rand() < self.collect_epsilon:
+ action = np.random.choice(legal_actions[i])
+ else:
+ # normal collect
+ # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
+ # the index within the legal action set, rather than the index in the entire action set.
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+ distributions, temperature=self._collect_mcts_temperature, deterministic=False
+ )
+ # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+ output[env_id] = {
+ 'action': action,
+ 'visit_count_distributions': distributions,
+ 'visit_count_distribution_entropy': visit_count_distribution_entropy,
+ 'searched_value': value,
+ 'predicted_value': pred_values[i],
+ 'predicted_policy_logits': policy_logits[i],
+ }
+ else:
+ data_id = [i for i in range(active_collect_env_num)]
+ output = {i: None for i in data_id}
- data_id = [i for i in range(active_collect_env_num)]
- output = {i: None for i in data_id}
+ if ready_env_id is None:
+ ready_env_id = np.arange(active_collect_env_num)
- if ready_env_id is None:
- ready_env_id = np.arange(active_collect_env_num)
-
- for i, env_id in enumerate(ready_env_id):
- distributions, value = roots_visit_count_distributions[i], roots_values[i]
- if self._cfg.eps.eps_greedy_exploration_in_collect:
- # eps greedy collect
- action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
- distributions, temperature=self._collect_mcts_temperature, deterministic=True
- )
+ for i, env_id in enumerate(ready_env_id):
+ policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), dim=0).tolist()
+ policy_values = policy_values / np.sum(policy_values)
+ action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values)
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
- if np.random.rand() < self.collect_epsilon:
- action = np.random.choice(legal_actions[i])
- else:
- # normal collect
- # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
- # the index within the legal action set, rather than the index in the entire action set.
- action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
- distributions, temperature=self._collect_mcts_temperature, deterministic=False
- )
- # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
- output[env_id] = {
- 'action': action,
- 'visit_count_distributions': distributions,
- 'visit_count_distribution_entropy': visit_count_distribution_entropy,
- 'searched_value': value,
- 'predicted_value': pred_values[i],
- 'predicted_policy_logits': policy_logits[i],
- }
+ output[env_id] = {
+ 'action': action,
+ 'searched_value': pred_values[i],
+ 'predicted_value': pred_values[i],
+ 'predicted_policy_logits': policy_logits[i],
+ }
return output
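For context, a minimal sketch of the eps-greedy collect branch retained above, assuming a simplified deterministic pick (argmax of visit counts) in place of the library's `select_action` with temperature; with probability `collect_epsilon` a uniformly random legal action overrides the searched one.

```python
# Hedged sketch: the argmax stands in for select_action(..., deterministic=True);
# names and values are illustrative only.
import numpy as np


def eps_greedy_collect_action(distributions, legal_actions, action_mask, collect_epsilon):
    # deterministic pick: the most-visited child within the legal action set
    index_in_legal_set = int(np.argmax(distributions))
    action = int(np.where(np.array(action_mask) == 1.0)[0][index_in_legal_set])
    # with probability collect_epsilon, override with a uniformly random legal action
    if np.random.rand() < collect_epsilon:
        action = int(np.random.choice(legal_actions))
    return action


print(eps_greedy_collect_action([10, 2, 30], legal_actions=[0, 2, 3], action_mask=[1, 0, 1, 1], collect_epsilon=0.25))
```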
diff --git a/lzero/policy/sampled_alphazero.py b/lzero/policy/sampled_alphazero.py
index baa1d18b0..6c228602a 100644
--- a/lzero/policy/sampled_alphazero.py
+++ b/lzero/policy/sampled_alphazero.py
@@ -53,10 +53,10 @@ class SampledAlphaZeroPolicy(Policy):
# collect data -> update policy-> collect data -> ...
# For different env, we have different episode_length,
# we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
- # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically.
+ # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
update_per_collect=None,
# (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is None.
- model_update_ratio=0.1,
+ replay_ratio=0.25,
# (int) Minibatch size for one gradient descent.
batch_size=256,
# (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW']
diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py
index 7003f6808..2c813bea4 100644
--- a/lzero/policy/sampled_efficientzero.py
+++ b/lzero/policy/sampled_efficientzero.py
@@ -118,10 +118,10 @@ class SampledEfficientZeroPolicy(MuZeroPolicy):
# collect data -> update policy-> collect data -> ...
# For different env, we have different episode_length,
# we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
- # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically.
+ # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
update_per_collect=None,
# (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is None.
- model_update_ratio=0.1,
+ replay_ratio=0.25,
# (int) Minibatch size for one gradient descent.
batch_size=256,
# (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW']
@@ -180,7 +180,7 @@ class SampledEfficientZeroPolicy(MuZeroPolicy):
# (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048.
use_ture_chance_label_in_chance_encoder=False,
# (bool) Whether to add noise to roots during reanalyze process.
- reanalyze_noise=False,
+ reanalyze_noise=True,
# ****** Priority ******
# (bool) Whether to use priority when sampling training data from the buffer.
diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py
index 57d52ad1d..069eb96eb 100644
--- a/lzero/policy/stochastic_muzero.py
+++ b/lzero/policy/stochastic_muzero.py
@@ -109,7 +109,7 @@ class StochasticMuZeroPolicy(MuZeroPolicy):
# we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor
update_per_collect=100,
# (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is None.
- model_update_ratio=0.1,
+ replay_ratio=0.25,
# (int) Minibatch size for one gradient descent.
batch_size=256,
# (str) Optimizer for training policy network. ['SGD', 'Adam']
@@ -163,7 +163,7 @@ class StochasticMuZeroPolicy(MuZeroPolicy):
# (bool) Whether to use the true chance in MCTS. If False, use the predicted chance.
use_ture_chance_label_in_chance_encoder=False,
# (bool) Whether to add noise to roots during reanalyze process.
- reanalyze_noise=False,
+ reanalyze_noise=True,
# ****** Priority ******
# (bool) Whether to use priority when sampling training data from the buffer.
diff --git a/lzero/policy/tests/config/atari_muzero_config_for_test.py b/lzero/policy/tests/config/atari_muzero_config_for_test.py
index 9ed0fba8a..c0bc7ff53 100644
--- a/lzero/policy/tests/config/atari_muzero_config_for_test.py
+++ b/lzero/policy/tests/config/atari_muzero_config_for_test.py
@@ -31,7 +31,7 @@
atari_muzero_config = dict(
exp_name=
- f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ f'data_muzero/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
stop_value=int(1e6),
env_id=env_id,
diff --git a/lzero/policy/tests/config/cartpole_muzero_config_for_test.py b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py
index deaf1fddd..22fa1518a 100644
--- a/lzero/policy/tests/config/cartpole_muzero_config_for_test.py
+++ b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py
@@ -16,7 +16,7 @@
# ==============================================================
cartpole_muzero_config = dict(
- exp_name=f'data_mz_ctree/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ exp_name=f'data_muzero/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
env_id='CartPole-v0',
continuous=False,
diff --git a/lzero/worker/__init__.py b/lzero/worker/__init__.py
index 3000ee23f..b74e1e745 100644
--- a/lzero/worker/__init__.py
+++ b/lzero/worker/__init__.py
@@ -1,4 +1,4 @@
from .alphazero_collector import AlphaZeroCollector
from .alphazero_evaluator import AlphaZeroEvaluator
from .muzero_collector import MuZeroCollector
-from .muzero_evaluator import MuZeroEvaluator
\ No newline at end of file
+from .muzero_evaluator import MuZeroEvaluator
diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py
index 8bb5f51ba..8ebcbdb5d 100644
--- a/lzero/worker/muzero_collector.py
+++ b/lzero/worker/muzero_collector.py
@@ -80,6 +80,7 @@ def __init__(
self._tb_logger = None
self.policy_config = policy_config
+ self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy
self.reset(policy, env)
@@ -210,7 +211,7 @@ def _compute_priorities(self, i: int, pred_values_lst: List[float], search_value
if self.policy_config.use_priority:
# Calculate priorities. The priorities are the L1 losses between the predicted
# values and the search values. We use 'none' as the reduction parameter, which
- # means the loss is calculated for each element individually, instead of being summed or averaged.
+ # means the loss is calculated for each element individually, instead of being summed or averaged.
# A small constant (1e-6) is added to the results to avoid zero priorities. This
# is done because zero priorities could potentially cause issues in some scenarios.
pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1)
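A minimal sketch of the priority rule this comment describes (not the exact `_compute_priorities` body): per-transition priorities are the element-wise L1 distances between predicted and searched values, offset by 1e-6 so that no priority is exactly zero.

```python
# Hedged sketch of the priority rule described above; not the exact _compute_priorities body.
import numpy as np
import torch
import torch.nn.functional as F


def compute_priorities_sketch(pred_values, search_values, eps=1e-6):
    pred = torch.from_numpy(np.array(pred_values)).float().view(-1)
    target = torch.from_numpy(np.array(search_values)).float().view(-1)
    # element-wise L1 loss ('none' reduction) plus a small constant so no priority is zero
    return (F.l1_loss(pred, target, reduction='none') + eps).numpy()


print(compute_priorities_sketch([0.1, 0.5, -0.2], [0.3, 0.5, 0.1]))  # ~[0.2, 1e-06, 0.3]
```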
@@ -304,7 +305,8 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm
def collect(self,
n_episode: Optional[int] = None,
train_iter: int = 0,
- policy_kwargs: Optional[dict] = None) -> List[Any]:
+ policy_kwargs: Optional[dict] = None,
+ collect_with_pure_policy: bool = False) -> List[Any]:
"""
Overview:
Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations.
@@ -312,6 +314,7 @@ def collect(self,
- n_episode (:obj:`Optional[int]`): Number of episodes to collect.
- train_iter (:obj:`int`): Number of training iterations completed so far.
- policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy.
+ - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS.
Returns:
- return_data (:obj:`List[Any]`): Collected data in the form of a list.
"""
@@ -389,6 +392,8 @@ def collect(self,
ready_env_id = set()
remain_episode = n_episode
+ if collect_with_pure_policy:
+ temp_visit_list = [0.0 for i in range(self._env.action_space.n)]
while True:
with self._timer:
@@ -422,7 +427,6 @@ def collect(self,
policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon)
actions_no_env_id = {k: v['action'] for k, v in policy_output.items()}
- distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()}
if self.policy_config.sampled_algo:
root_sampled_actions_dict_no_env_id = {
k: v['root_sampled_actions']
@@ -430,39 +434,53 @@ def collect(self,
}
value_dict_no_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
pred_value_dict_no_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}
- visit_entropy_dict_no_env_id = {
- k: v['visit_count_distribution_entropy']
- for k, v in policy_output.items()
- }
-
- if self.policy_config.gumbel_algo:
- improved_policy_dict_no_env_id = {k: v['improved_policy_probs'] for k, v in policy_output.items()}
- completed_value_no_env_id = {
- k: v['roots_completed_value']
+
+ if not collect_with_pure_policy:
+ distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()}
+ if self.policy_config.sampled_algo:
+ root_sampled_actions_dict_no_env_id = {
+ k: v['root_sampled_actions']
+ for k, v in policy_output.items()
+ }
+ visit_entropy_dict_no_env_id = {
+ k: v['visit_count_distribution_entropy']
for k, v in policy_output.items()
}
+ if self.policy_config.gumbel_algo:
+ improved_policy_dict_no_env_id = {k: v['improved_policy_probs'] for k, v in
+ policy_output.items()}
+ completed_value_no_env_id = {
+ k: v['roots_completed_value']
+ for k, v in policy_output.items()
+ }
+
# TODO(pu): subprocess
actions = {}
- distributions_dict = {}
- if self.policy_config.sampled_algo:
- root_sampled_actions_dict = {}
value_dict = {}
pred_value_dict = {}
- visit_entropy_dict = {}
- if self.policy_config.gumbel_algo:
- improved_policy_dict = {}
- completed_value_dict = {}
+
+ if not collect_with_pure_policy:
+ distributions_dict = {}
+ if self.policy_config.sampled_algo:
+ root_sampled_actions_dict = {}
+ visit_entropy_dict = {}
+ if self.policy_config.gumbel_algo:
+ improved_policy_dict = {}
+ completed_value_dict = {}
+
for index, env_id in enumerate(ready_env_id):
actions[env_id] = actions_no_env_id.pop(index)
- distributions_dict[env_id] = distributions_dict_no_env_id.pop(index)
- if self.policy_config.sampled_algo:
- root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index)
value_dict[env_id] = value_dict_no_env_id.pop(index)
pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index)
- visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index)
- if self.policy_config.gumbel_algo:
- improved_policy_dict[env_id] = improved_policy_dict_no_env_id.pop(index)
- completed_value_dict[env_id] = completed_value_no_env_id.pop(index)
+
+ if not collect_with_pure_policy:
+ distributions_dict[env_id] = distributions_dict_no_env_id.pop(index)
+ if self.policy_config.sampled_algo:
+ root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index)
+ visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index)
+ if self.policy_config.gumbel_algo:
+ improved_policy_dict[env_id] = improved_policy_dict_no_env_id.pop(index)
+ completed_value_dict[env_id] = completed_value_no_env_id.pop(index)
# ==============================================================
# Interact with env.
@@ -483,15 +501,19 @@ def collect(self,
continue
obs, reward, done, info = timestep.obs, timestep.reward, timestep.done, timestep.info
- if self.policy_config.sampled_algo:
- game_segments[env_id].store_search_stats(
- distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id]
- )
- elif self.policy_config.gumbel_algo:
- game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id],
- improved_policy=improved_policy_dict[env_id])
+ if collect_with_pure_policy:
+ game_segments[env_id].store_search_stats(temp_visit_list, 0)
else:
- game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id])
+ if self.policy_config.sampled_algo:
+ game_segments[env_id].store_search_stats(
+ distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id]
+ )
+ elif self.policy_config.gumbel_algo:
+ game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id],
+ improved_policy=improved_policy_dict[env_id])
+ else:
+ game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id])
+
# append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t}
# in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment``
if self.policy_config.use_ture_chance_label_in_chance_encoder:
@@ -517,9 +539,10 @@ def collect(self,
else:
dones[env_id] = done
- visit_entropies_lst[env_id] += visit_entropy_dict[env_id]
- if self.policy_config.gumbel_algo:
- completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id]))
+ if not collect_with_pure_policy:
+ visit_entropies_lst[env_id] += visit_entropy_dict[env_id]
+ if self.policy_config.gumbel_algo:
+ completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id]))
eps_steps_lst[env_id] += 1
total_transitions += 1
@@ -527,7 +550,7 @@ def collect(self,
if self.policy_config.use_priority:
pred_values_lst[env_id].append(pred_value_dict[env_id])
search_values_lst[env_id].append(value_dict[env_id])
- if self.policy_config.gumbel_algo:
+ if self.policy_config.gumbel_algo and not collect_with_pure_policy:
improved_policy_lst[env_id].append(improved_policy_dict[env_id])
# append the newest obs
@@ -550,7 +573,7 @@ def collect(self,
priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst)
pred_values_lst[env_id] = []
search_values_lst[env_id] = []
- if self.policy_config.gumbel_algo:
+ if self.policy_config.gumbel_algo and not collect_with_pure_policy:
improved_policy_lst[env_id] = []
# the current game_segments become last_game_segment
@@ -575,10 +598,12 @@ def collect(self,
'reward': reward,
'time': self._env_info[env_id]['time'],
'step': self._env_info[env_id]['step'],
- 'visit_entropy': visit_entropies_lst[env_id] / eps_steps_lst[env_id],
}
- if self.policy_config.gumbel_algo:
- info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id]
+ if not collect_with_pure_policy:
+ info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id]
+ if self.policy_config.gumbel_algo:
+ info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id]
+
collected_episode += 1
self._episode_info.append(info)
@@ -650,7 +675,8 @@ def collect(self,
# log
self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id])
- self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id])
+ if not collect_with_pure_policy:
+ self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id])
self_play_moves += eps_steps_lst[env_id]
self_play_episodes += 1
@@ -707,7 +733,10 @@ def _output_log(self, train_iter: int) -> None:
envstep_count = sum([d['step'] for d in self._episode_info])
duration = sum([d['time'] for d in self._episode_info])
episode_reward = [d['reward'] for d in self._episode_info]
- visit_entropy = [d['visit_entropy'] for d in self._episode_info]
+ if not self.collect_with_pure_policy:
+ visit_entropy = [d['visit_entropy'] for d in self._episode_info]
+ else:
+ visit_entropy = [0.0]
if self.policy_config.gumbel_algo:
completed_value = [d['completed_value'] for d in self._episode_info]
self._total_duration += duration
@@ -726,7 +755,6 @@ def _output_log(self, train_iter: int) -> None:
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
'visit_entropy': np.mean(visit_entropy),
- # 'each_reward': episode_reward,
}
if self.policy_config.gumbel_algo:
info['completed_value'] = np.mean(completed_value)
@@ -738,4 +766,4 @@ def _output_log(self, train_iter: int) -> None:
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
if k in ['total_envstep_count']:
continue
- self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 4b3ce597f..ad1020b27 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
DI-engine>=0.4.7
gymnasium[atari]
-moviepy
numpy>=1.22.4
pympler
bsuite
diff --git a/zoo/atari/config/atari_gumbel_muzero_config.py b/zoo/atari/config/atari_gumbel_muzero_config.py
index 918a93236..4e3591727 100644
--- a/zoo/atari/config/atari_gumbel_muzero_config.py
+++ b/zoo/atari/config/atari_gumbel_muzero_config.py
@@ -31,7 +31,7 @@
atari_gumbel_muzero_config = dict(
exp_name=
- f'data_mz_ctree/{env_id[:-14]}_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ f'data_muzero/{env_id[:-14]}_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
stop_value=int(1e6),
env_id=env_id,
diff --git a/zoo/atari/config/atari_muzero_config.py b/zoo/atari/config/atari_muzero_config.py
index 1e5d49ce2..6d4d26c1c 100644
--- a/zoo/atari/config/atari_muzero_config.py
+++ b/zoo/atari/config/atari_muzero_config.py
@@ -31,7 +31,7 @@
# ==============================================================
atari_muzero_config = dict(
- exp_name=f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ exp_name=f'data_muzero/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
stop_value=int(1e6),
env_id=env_id,
@@ -52,7 +52,7 @@
norm_type='BN',
),
cuda=True,
- reanalyze_noise=False,
+ reanalyze_noise=True,
env_type='not_board_games',
game_segment_length=400,
random_collect_episode_num=0,
diff --git a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py b/zoo/atari/config/atari_muzero_multigpu_ddp_config.py
index 4c80b7f90..b447f934d 100644
--- a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py
+++ b/zoo/atari/config/atari_muzero_multigpu_ddp_config.py
@@ -32,7 +32,7 @@
# ==============================================================
atari_muzero_config = dict(
- exp_name=f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed0',
+ exp_name=f'data_muzero/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed0',
env=dict(
stop_value=int(1e6),
env_id=env_id,
diff --git a/zoo/atari/config/atari_rezero_ez_config.py b/zoo/atari/config/atari_rezero_ez_config.py
new file mode 100644
index 000000000..a03ef8d10
--- /dev/null
+++ b/zoo/atari/config/atari_rezero_ez_config.py
@@ -0,0 +1,95 @@
+from easydict import EasyDict
+# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
+env_id = 'PongNoFrameskip-v4'
+
+if env_id == 'PongNoFrameskip-v4':
+ action_space_size = 6
+elif env_id == 'QbertNoFrameskip-v4':
+ action_space_size = 6
+elif env_id == 'MsPacmanNoFrameskip-v4':
+ action_space_size = 9
+elif env_id == 'SpaceInvadersNoFrameskip-v4':
+ action_space_size = 6
+elif env_id == 'BreakoutNoFrameskip-v4':
+ action_space_size = 4
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = None
+batch_size = 256
+max_env_step = int(5e5)
+use_priority = False
+# ============= The key different params for ReZero =============
+reuse_search = True
+collect_with_pure_policy = True
+buffer_reanalyze_freq = 1
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+atari_efficientzero_config = dict(
+ exp_name=f'data_rezero_ez/{env_id[:-14]}_rezero_efficientzero_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0',
+ env=dict(
+ env_id=env_id,
+ obs_shape=(4, 96, 96),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=(4, 96, 96),
+ frame_stack_num=4,
+ action_space_size=action_space_size,
+ downsample=True,
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ ),
+ cuda=True,
+ env_type='not_board_games',
+ use_augmentation=True,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='SGD',
+ lr_piecewise_constant_decay=True,
+ learning_rate=0.2,
+ num_simulations=num_simulations,
+        reanalyze_ratio=0,  # NOTE: for ReZero, reanalyze_ratio should be 0.
+ n_episode=n_episode,
+ eval_freq=int(2e3),
+ replay_buffer_size=int(1e6),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ use_priority=use_priority,
+ # ============= The key different params for ReZero =============
+ reuse_search=reuse_search,
+ collect_with_pure_policy=collect_with_pure_policy,
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ ),
+)
+atari_efficientzero_config = EasyDict(atari_efficientzero_config)
+main_config = atari_efficientzero_config
+
+atari_efficientzero_create_config = dict(
+ env=dict(
+ type='atari_lightzero',
+ import_names=['zoo.atari.envs.atari_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='efficientzero',
+ import_names=['lzero.policy.efficientzero'],
+ ),
+)
+atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config)
+create_config = atari_efficientzero_create_config
+
+if __name__ == "__main__":
+ from lzero.entry import train_rezero
+ train_rezero([main_config, create_config], seed=0, max_env_step=max_env_step)
diff --git a/zoo/atari/config/atari_rezero_mz_config.py b/zoo/atari/config/atari_rezero_mz_config.py
new file mode 100644
index 000000000..a1d76e36c
--- /dev/null
+++ b/zoo/atari/config/atari_rezero_mz_config.py
@@ -0,0 +1,105 @@
+from easydict import EasyDict
+
+# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
+env_id = 'PongNoFrameskip-v4'
+
+if env_id == 'PongNoFrameskip-v4':
+ action_space_size = 6
+elif env_id == 'QbertNoFrameskip-v4':
+ action_space_size = 6
+elif env_id == 'MsPacmanNoFrameskip-v4':
+ action_space_size = 9
+elif env_id == 'SpaceInvadersNoFrameskip-v4':
+ action_space_size = 6
+elif env_id == 'BreakoutNoFrameskip-v4':
+ action_space_size = 4
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = None
+batch_size = 256
+replay_ratio = 0.25
+max_env_step = int(5e5)
+use_priority = False
+
+# ============= The key different params for ReZero =============
+reuse_search = True
+collect_with_pure_policy = True
+buffer_reanalyze_freq = 1
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+atari_muzero_config = dict(
+ exp_name=f'data_rezero_mz/{env_id[:-14]}_rezero_muzero_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0',
+ env=dict(
+ stop_value=int(1e6),
+ env_id=env_id,
+ obs_shape=(4, 96, 96),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=(4, 96, 96),
+ frame_stack_num=4,
+ action_space_size=action_space_size,
+ downsample=True,
+ self_supervised_learning_loss=True,
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ ),
+ cuda=True,
+ env_type='not_board_games',
+ use_augmentation=True,
+ update_per_collect=update_per_collect,
+ replay_ratio=replay_ratio,
+ batch_size=batch_size,
+ optim_type='SGD',
+ lr_piecewise_constant_decay=True,
+ learning_rate=0.2,
+ num_simulations=num_simulations,
+ ssl_loss_weight=2,
+ n_episode=n_episode,
+ eval_freq=int(2e3),
+ replay_buffer_size=int(1e6),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ use_priority=use_priority,
+ # ============= The key different params for ReZero =============
+ reuse_search=reuse_search,
+ collect_with_pure_policy=collect_with_pure_policy,
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ ),
+)
+atari_muzero_config = EasyDict(atari_muzero_config)
+main_config = atari_muzero_config
+
+atari_muzero_create_config = dict(
+ env=dict(
+ type='atari_lightzero',
+ import_names=['zoo.atari.envs.atari_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='muzero',
+ import_names=['lzero.policy.muzero'],
+ ),
+ collector=dict(
+ type='episode_muzero',
+ import_names=['lzero.worker.muzero_collector'],
+ )
+)
+atari_muzero_create_config = EasyDict(atari_muzero_create_config)
+create_config = atari_muzero_create_config
+
+if __name__ == "__main__":
+ from lzero.entry import train_rezero
+ train_rezero([main_config, create_config], seed=0, max_env_step=max_env_step)
diff --git a/zoo/atari/entry/test_memory_usage.py b/zoo/atari/entry/test_memory_usage.py
new file mode 100644
index 000000000..cfb3e9d5d
--- /dev/null
+++ b/zoo/atari/entry/test_memory_usage.py
@@ -0,0 +1,103 @@
+from lzero.mcts import ReZeroMZGameBuffer as GameBuffer
+from zoo.atari.config.atari_rezero_mz_config import main_config, create_config
+import torch
+import numpy as np
+from ding.config import compile_config
+from ding.policy import create_policy
+from tensorboardX import SummaryWriter
+import psutil
+
+
+def get_memory_usage():
+ """
+ Get the current memory usage of the process.
+
+ Returns:
+ int: Memory usage in bytes.
+ """
+ process = psutil.Process()
+ memory_info = process.memory_info()
+ return memory_info.rss
+
+
+def initialize_policy(cfg, create_cfg):
+ """
+ Initialize the policy based on the given configuration.
+
+ Args:
+ cfg (Config): Main configuration object.
+ create_cfg (Config): Creation configuration object.
+
+ Returns:
+ Policy: Initialized policy object.
+ """
+ if cfg.policy.cuda and torch.cuda.is_available():
+ cfg.policy.device = 'cuda'
+ else:
+ cfg.policy.device = 'cpu'
+
+ cfg = compile_config(cfg, seed=0, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+ policy = create_policy(cfg.policy, model=None, enable_field=['learn', 'collect', 'eval'])
+
+ model_path = '{template_path}/iteration_20000.pth.tar'
+ if model_path is not None:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ return policy, cfg
+
+
+def run_memory_test(replay_buffer, policy, writer, sample_batch_size):
+ """
+ Run memory usage test for sampling from the replay buffer.
+
+ Args:
+ replay_buffer (GameBuffer): The replay buffer to sample from.
+ policy (Policy): The policy object.
+ writer (SummaryWriter): TensorBoard summary writer.
+ sample_batch_size (int): The base batch size for sampling.
+ """
+ for i in range(2):
+ initial_memory = get_memory_usage()
+ print(f"Initial memory usage: {initial_memory} bytes")
+
+ replay_buffer.sample(sample_batch_size * (i + 1), policy)
+
+ final_memory = get_memory_usage()
+ memory_cost = final_memory - initial_memory
+
+ print(f"Memory usage after sampling: {final_memory} bytes")
+ print(f"Memory cost of sampling: {float(memory_cost) / 1e9:.2f} GB")
+
+ writer.add_scalar("Sampling Memory Usage (GB)", float(memory_cost) / 1e9, i + 1)
+
+ # Reset counters
+ replay_buffer.compute_target_re_time = 0
+ replay_buffer.origin_search_time = 0
+ replay_buffer.reuse_search_time = 0
+ replay_buffer.active_root_num = 0
+
+
+def main():
+ """
+ Main function to run the memory usage test.
+ """
+ cfg, create_cfg = main_config, create_config
+ policy, cfg = initialize_policy(cfg, create_cfg)
+
+ replay_buffer = GameBuffer(cfg.policy)
+
+ # Load and push data to the replay buffer
+ data = np.load('{template_path}/collected_data.npy', allow_pickle=True)
+ for _ in range(50):
+ replay_buffer.push_game_segments(data)
+
+ log_dir = "logs/memory_test2"
+ writer = SummaryWriter(log_dir)
+
+ run_memory_test(replay_buffer, policy, writer, sample_batch_size=25600)
+
+ writer.close()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py b/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py
index 72491bf3b..091ba98cb 100644
--- a/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py
+++ b/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py
@@ -16,7 +16,7 @@
# ==============================================================
connect4_muzero_config = dict(
- exp_name=f'data_mz_ctree/connect4_play-with-bot-mode_seed0',
+ exp_name=f'data_muzero/connect4_play-with-bot-mode_seed0',
env=dict(
battle_mode='play_with_bot_mode',
bot_action_type='rule',
@@ -79,5 +79,4 @@
if __name__ == "__main__":
from lzero.entry import train_muzero
-
- train_muzero([main_config, create_config], seed=1, max_env_step=max_env_step)
\ No newline at end of file
+ train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step)
\ No newline at end of file
diff --git a/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py b/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py
index b1c2ffd1f..10bcbcc86 100644
--- a/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py
+++ b/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py
@@ -16,7 +16,7 @@
# ==============================================================
connect4_muzero_config = dict(
- exp_name=f'data_mz_ctree/connect4_self-play-mode_seed0',
+ exp_name=f'data_muzero/connect4_self-play-mode_seed0',
env=dict(
battle_mode='self_play_mode',
bot_action_type='rule',
diff --git a/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py b/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py
new file mode 100644
index 000000000..f78bf043a
--- /dev/null
+++ b/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py
@@ -0,0 +1,97 @@
+from easydict import EasyDict
+# import torch
+# torch.cuda.set_device(0)
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 5
+num_simulations = 50
+update_per_collect = 50
+batch_size = 256
+max_env_step = int(1e6)
+
+reuse_search = True
+collect_with_pure_policy = True
+use_priority = False
+buffer_reanalyze_freq = 1
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+connect4_muzero_config = dict(
+ exp_name=f'data_rezero_mz/connect4_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0',
+ env=dict(
+ battle_mode='play_with_bot_mode',
+ bot_action_type='rule',
+ channel_last=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=(3, 6, 7),
+ action_space_size=7,
+ image_channel=3,
+ num_res_blocks=1,
+ num_channels=64,
+ support_scale=300,
+ reward_support_size=601,
+ value_support_size=601,
+ ),
+ cuda=True,
+ env_type='board_games',
+ action_type='varied_action_space',
+ game_segment_length=int(6 * 7 / 2), # for battle_mode='play_with_bot_mode'
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ grad_clip_value=0.5,
+ num_simulations=num_simulations,
+        # NOTE: In board_games, we set large td_steps to make sure the value target is the final outcome.
+ td_steps=int(6 * 7 / 2), # for battle_mode='play_with_bot_mode'
+        # NOTE: In board_games, we set discount_factor=1.
+ discount_factor=1,
+ n_episode=n_episode,
+ eval_freq=int(2e3),
+ replay_buffer_size=int(1e5),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ reanalyze_noise=True,
+ use_priority=use_priority,
+ # ============= The key different params for ReZero =============
+ reuse_search=reuse_search,
+ collect_with_pure_policy=collect_with_pure_policy,
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ ),
+)
+connect4_muzero_config = EasyDict(connect4_muzero_config)
+main_config = connect4_muzero_config
+
+connect4_muzero_create_config = dict(
+ env=dict(
+ type='connect4',
+ import_names=['zoo.board_games.connect4.envs.connect4_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='muzero',
+ import_names=['lzero.policy.muzero'],
+ ),
+)
+connect4_muzero_create_config = EasyDict(connect4_muzero_create_config)
+create_config = connect4_muzero_create_config
+
+if __name__ == "__main__":
+ # Define a list of seeds for multiple runs
+ seeds = [0, 1, 2] # You can add more seed values here
+ for seed in seeds:
+ # Update exp_name to include the current seed
+ main_config.exp_name = f'data_rezero_mz/connect4_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed{seed}'
+ from lzero.entry import train_rezero
+ train_rezero([main_config, create_config], seed=seed, max_env_step=max_env_step)
\ No newline at end of file
diff --git a/zoo/board_games/connect4/envs/test_bots.py b/zoo/board_games/connect4/envs/test_bots.py
index 9189b761e..9b1fbf063 100644
--- a/zoo/board_games/connect4/envs/test_bots.py
+++ b/zoo/board_games/connect4/envs/test_bots.py
@@ -1,6 +1,7 @@
import time
import numpy as np
+import psutil
import pytest
from easydict import EasyDict
@@ -8,6 +9,11 @@
from zoo.board_games.mcts_bot import MCTSBot
+def get_memory_usage():
+ process = psutil.Process()
+ memory_info = process.memory_info()
+ return memory_info.rss
+
@pytest.mark.unittest
class TestConnect4Bot():
"""
@@ -31,7 +37,7 @@ def setup(self) -> None:
prob_expert_agent=0,
bot_action_type='rule',
screen_scaling=9,
- render_mode='image_savefile_mode',
+            render_mode=None,
prob_random_action_in_bot=0,
)
@@ -50,6 +56,8 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None:
# Repeat the game for 10 rounds.
for i in range(10):
print('-' * 10 + str(i) + '-' * 10)
+ memory_usage = get_memory_usage()
+ print(f"Initial memory usage: {memory_usage} bytes")
# Initialize the game, where there are two players: player 1 and player 2.
env = Connect4Env(EasyDict(self.cfg))
# Reset the environment, set the board to a clean board and the start player to be player 1.
@@ -61,6 +69,7 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None:
player = MCTSBot(env_mcts, 'a', num_simulations) # player_index = 0, player = 1
# Set player 1 to move first.
player_index = 0
+ step = 1
while not env.get_done_reward()[0]:
"""
Overview:
@@ -70,7 +79,7 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None:
if player_index == 0:
t1 = time.time()
# action = env.bot_action()
- action = player.get_actions(state, player_index=player_index)
+ action, node = player.get_actions(state, step, player_index=player_index)
t2 = time.time()
# print("The time difference is :", t2-t1)
mcts_bot_time_list.append(t2 - t1)
@@ -86,7 +95,13 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None:
player_index = 0
env.step(action)
state = env.board
- # print(np.array(state).reshape(6, 7))
+ step += 1
+ print(np.array(state).reshape(6, 7))
+ temp = memory_usage
+ memory_usage = get_memory_usage()
+ memory_cost = memory_usage - temp
+ print(f"Memory usage after search: {memory_usage} bytes")
+ print(f"Increased memory usage due to searches: {memory_cost} bytes")
# Record the winner.
winner.append(env.get_done_winner()[1])
@@ -115,7 +130,7 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None:
def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations_2: int = 50) -> None:
"""
Overview:
- A tictactoe game between mcts_bot and rule_bot, where rule_bot take the first move.
+            A connect4 game between two mcts_bots.
Arguments:
- num_simulations_1 (:obj:`int`): The number of the simulations of player 1 required to find the best move.
- num_simulations_2 (:obj:`int`): The number of the simulations of player 2 required to find the best move.
@@ -126,17 +141,21 @@ def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations
winner = []
# Repeat the game for 10 rounds.
- for i in range(10):
+ for i in range(1):
print('-' * 10 + str(i) + '-' * 10)
+ memory_usage = get_memory_usage()
+            print(f"Initial memory usage: {memory_usage} bytes")
# Initialize the game, where there are two players: player 1 and player 2.
env = Connect4Env(EasyDict(self.cfg))
# Reset the environment, set the board to a clean board and the start player to be player 1.
env.reset()
state = env.board
player1 = MCTSBot(env, 'a', num_simulations_1) # player_index = 0, player = 1
- player2 = MCTSBot(env, 'a', num_simulations_2)
+ player2 = MCTSBot(env, 'b', num_simulations_2)
# Set player 1 to move first.
player_index = 0
+ step = 1
+ node = None
while not env.get_done_reward()[0]:
"""
Overview:
@@ -146,7 +165,7 @@ def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations
if player_index == 0:
t1 = time.time()
# action = env.bot_action()
- action = player1.get_actions(state, player_index=player_index)
+ action, node, visit = player1.get_actions(state, step, player_index)
t2 = time.time()
# print("The time difference is :", t2-t1)
mcts_bot1_time_list.append(t2 - t1)
@@ -155,14 +174,20 @@ def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations
else:
t1 = time.time()
# action = env.bot_action()
- action = player2.get_actions(state, player_index=player_index)
+ action, node, visit = player2.get_actions(state, step, player_index, num_simulation=visit)
t2 = time.time()
# print("The time difference is :", t2-t1)
mcts_bot2_time_list.append(t2 - t1)
player_index = 0
env.step(action)
+ step += 1
state = env.board
- # print(np.array(state).reshape(6, 7))
+ print(np.array(state).reshape(6, 7))
+ temp = memory_usage
+ memory_usage = get_memory_usage()
+ memory_cost = memory_usage - temp
+ print(f"Memory usage after search: {memory_usage} bytes")
+ print(f"Increased memory usage due to searches: {memory_cost} bytes")
# Record the winner.
winner.append(env.get_done_winner()[1])
@@ -175,11 +200,11 @@ def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations
mcts_bot2_var = np.var(mcts_bot2_time_list)
# Print the information of the games.
- print('num_simulations={}\n'.format(200))
+ print('num_simulations={}\n'.format(num_simulations_1))
print('mcts_bot1_time_list={}\n'.format(mcts_bot1_time_list))
print('mcts_bot1_mu={}, mcts_bot1_var={}\n'.format(mcts_bot1_mu, mcts_bot1_var))
- print('num_simulations={}\n'.format(1000))
+ print('num_simulations={}\n'.format(num_simulations_2))
print('mcts_bot2_time_list={}\n'.format(mcts_bot2_time_list))
print('mcts_bot2_mu={}, mcts_bot2_var={}\n'.format(mcts_bot2_mu, mcts_bot2_var))
@@ -204,6 +229,8 @@ def test_rule_bot_vs_rule_bot(self) -> None:
# Repeat the game for 10 rounds.
for i in range(10):
print('-' * 10 + str(i) + '-' * 10)
+ memory_usage = get_memory_usage()
+            print(f"Initial memory usage: {memory_usage} bytes")
# Initialize the game, where there are two players: player 1 and player 2.
env = Connect4Env(EasyDict(self.cfg))
# Reset the environment, set the board to a clean board and the start player to be player 1.
@@ -234,7 +261,12 @@ def test_rule_bot_vs_rule_bot(self) -> None:
player_index = 0
env.step(action)
state = env.board
- # print(np.array(state).reshape(6, 7))
+ print(np.array(state).reshape(6, 7))
+ temp = memory_usage
+ memory_usage = get_memory_usage()
+ memory_cost = memory_usage - temp
+ print(f"Memory usage after search: {memory_usage} bytes")
+ print(f"Increased memory usage due to searches: {memory_cost} bytes")
# Record the winner.
winner.append(env.get_done_winner()[1])
@@ -258,3 +290,9 @@ def test_rule_bot_vs_rule_bot(self) -> None:
winner, winner.count(-1), winner.count(1), winner.count(2)
)
)
+
+
+if __name__ == "__main__":
+ test = TestConnect4Bot()
+ test.setup()
+    test.test_mcts_bot_vs_mcts_bot(2000, 200)
\ No newline at end of file
diff --git a/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py
index 49d44e432..e9fb19301 100644
--- a/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py
+++ b/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py
@@ -21,7 +21,7 @@
gomoku_gumbel_muzero_config = dict(
exp_name=
- f'data_mz_ctree/gomoku_b{board_size}_rand{prob_random_action_in_bot}_gumbel_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ f'data_muzero/gomoku_b{board_size}_rand{prob_random_action_in_bot}_gumbel_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
board_size=board_size,
battle_mode='play_with_bot_mode',
diff --git a/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py
index a30a31437..36bf5ae5f 100644
--- a/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py
+++ b/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py
@@ -20,8 +20,7 @@
# ==============================================================
gomoku_muzero_config = dict(
- exp_name=
- f'data_mz_ctree/gomoku_b{board_size}_rand{prob_random_action_in_bot}_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ exp_name=f'data_muzero/gomoku_b{board_size}_rand{prob_random_action_in_bot}_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
board_size=board_size,
battle_mode='play_with_bot_mode',
@@ -85,5 +84,10 @@
create_config = gomoku_muzero_create_config
if __name__ == "__main__":
- from lzero.entry import train_muzero
- train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step)
+ # Define a list of seeds for multiple runs
+ seeds = [0] # You can add more seed values here
+ for seed in seeds:
+ # Update exp_name to include the current seed
+ main_config.exp_name = f'data_muzero/gomoku_b{board_size}_rand{prob_random_action_in_bot}_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}'
+ from lzero.entry import train_muzero
+ train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
diff --git a/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py b/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py
index 1fbaedb8b..7b12445f2 100644
--- a/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py
+++ b/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py
@@ -21,12 +21,12 @@
gomoku_muzero_config = dict(
exp_name=
- f'data_mz_ctree/gomoku_muzero_sp-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ f'data_muzero/gomoku_muzero_sp-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
battle_mode='self_play_mode',
bot_action_type=bot_action_type,
prob_random_action_in_bot=prob_random_action_in_bot,
- channel_last=True,
+ channel_last=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
diff --git a/zoo/board_games/gomoku/config/gomoku_rezero-mz_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_rezero-mz_bot_mode_config.py
new file mode 100644
index 000000000..031c7ef01
--- /dev/null
+++ b/zoo/board_games/gomoku/config/gomoku_rezero-mz_bot_mode_config.py
@@ -0,0 +1,100 @@
+from easydict import EasyDict
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 32
+n_episode = 32
+evaluator_env_num = 5
+num_simulations = 50
+update_per_collect = 50
+batch_size = 256
+max_env_step = int(1e6)
+board_size = 6 # default_size is 15
+bot_action_type = 'v0' # options={'v0', 'v1'}
+prob_random_action_in_bot = 0.5
+
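+# ReZero-specific options used below: reuse_search reuses the previous step's search tree,
+# collect_with_pure_policy collects data with the raw policy network (no MCTS during collection),
+# and buffer_reanalyze_freq sets how often the replay buffer is reanalyzed.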
+reuse_search = True
+collect_with_pure_policy = True
+use_priority = False
+buffer_reanalyze_freq = 1
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+gomoku_muzero_config = dict(
+ exp_name=f'data_rezero_mz/gomoku_b{board_size}_rand{prob_random_action_in_bot}_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0',
+ env=dict(
+ board_size=board_size,
+ battle_mode='play_with_bot_mode',
+ bot_action_type=bot_action_type,
+ prob_random_action_in_bot=prob_random_action_in_bot,
+ channel_last=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=(3, board_size, board_size),
+ action_space_size=int(board_size * board_size),
+ image_channel=3,
+ num_res_blocks=1,
+ num_channels=32,
+ support_scale=10,
+ reward_support_size=21,
+ value_support_size=21,
+ ),
+ cuda=True,
+ env_type='board_games',
+ action_type='varied_action_space',
+ game_segment_length=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode'
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ grad_clip_value=0.5,
+ num_simulations=num_simulations,
+        # NOTE: In board_games, we set a large td_steps to make sure the value target is the final outcome.
+ td_steps=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode'
+        # NOTE: In board_games, we set discount_factor=1.
+ discount_factor=1,
+ n_episode=n_episode,
+ eval_freq=int(2e3),
+ replay_buffer_size=int(1e5),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ reanalyze_noise=True,
+ use_priority=use_priority,
+ # ============= The key different params for ReZero =============
+ reuse_search=reuse_search,
+ collect_with_pure_policy=collect_with_pure_policy,
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ ),
+)
+gomoku_muzero_config = EasyDict(gomoku_muzero_config)
+main_config = gomoku_muzero_config
+
+gomoku_muzero_create_config = dict(
+ env=dict(
+ type='gomoku',
+ import_names=['zoo.board_games.gomoku.envs.gomoku_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='muzero',
+ import_names=['lzero.policy.muzero'],
+ ),
+)
+gomoku_muzero_create_config = EasyDict(gomoku_muzero_create_config)
+create_config = gomoku_muzero_create_config
+
+if __name__ == "__main__":
+ # Define a list of seeds for multiple runs
+ seeds = [0] # You can add more seed values here
+ for seed in seeds:
+ # Update exp_name to include the current seed
+        main_config.exp_name = f'data_rezero_mz/gomoku_b{board_size}_rand{prob_random_action_in_bot}_rezero-mz_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed{seed}'
+ from lzero.entry import train_rezero
+ train_rezero([main_config, create_config], seed=seed, max_env_step=max_env_step)
\ No newline at end of file
diff --git a/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py b/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py
index e34d1af02..5568148a2 100644
--- a/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py
+++ b/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py
@@ -8,7 +8,7 @@
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
"""
- model_path = './data_mz_ctree/gomoku_gumbel_muzero_visit50_value1_purevaluenetwork_deletepriority_deletetargetsoftmax_seed0_230517_131142/ckpt/iteration_113950.pth.tar'
+ model_path = './data_muzero/gomoku_gumbel_muzero_visit50_value1_purevaluenetwork_deletepriority_deletetargetsoftmax_seed0_230517_131142/ckpt/iteration_113950.pth.tar'
seeds = [0,1,2,3,4]
num_episodes_each_seed = 10
# If True, you can play with the agent.
diff --git a/zoo/board_games/mcts_bot.py b/zoo/board_games/mcts_bot.py
index 32609de54..706c0f468 100644
--- a/zoo/board_games/mcts_bot.py
+++ b/zoo/board_games/mcts_bot.py
@@ -10,8 +10,11 @@
"""
import time
+import copy
from abc import ABC, abstractmethod
from collections import defaultdict
+from graphviz import Digraph
+import os
import numpy as np
import copy
@@ -170,7 +173,8 @@ def value(self):
- If the parent's current player is player 2, then Q-value = 5 - 10 = -5.
This way, a higher Q-value for a node indicates a higher win rate for the parent's current player.
"""
-
+        if self.parent is None:
+ return 0
# Determine the number of wins and losses based on the current player at the parent node.
wins, loses = (self._results[1], self._results[-1]) if self.parent.env.current_player == 1 else (
self._results[-1], self._results[1])
@@ -344,6 +348,11 @@ def _tree_policy(self):
else:
current_node = current_node.best_child()
return current_node
+
+ def print_tree(self, node, indent="*"):
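+        # Debug helper: recursively print each node's board (assumed 6x7, i.e. Connect4) and visit count, indenting one '*' per depth level.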
+ print(indent + str(np.array(node.env.board).reshape(6,7)) + str(node.visit_count))
+ for child in node.children:
+ self.print_tree(child, indent + "*")
class MCTSBot:
@@ -365,7 +374,7 @@ def __init__(self, env, bot_name, num_simulation=50):
self.num_simulation = num_simulation
self.simulator_env = env
- def get_actions(self, state, player_index, best_action_type="UCB"):
+ def get_actions(self, state, step, player_index, root=None, num_simulation=None, best_action_type="UCB"):
"""
Overview:
This function gets the actions that the MCTS Bot will take.
@@ -380,8 +389,95 @@ def get_actions(self, state, player_index, best_action_type="UCB"):
"""
# Every time before make a decision, reset the environment to the current environment of the game.
self.simulator_env.reset(start_player_index=player_index, init_state=state)
- root = TwoPlayersMCTSNode(self.simulator_env)
+        if root is None:
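+            # No previous search tree was passed in, so build a fresh root node for the current state.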
+ root = TwoPlayersMCTSNode(self.simulator_env)
# Do the MCTS to find the best action to take.
mcts = MCTS(root)
- mcts.best_action(self.num_simulation, best_action_type=best_action_type)
- return root.best_action
+        if num_simulation is None:
+ child_node = mcts.best_action(self.num_simulation, best_action_type=best_action_type)
+ else:
+ child_node = mcts.best_action(num_simulation, best_action_type=best_action_type)
+ print(root.visit_count)
+        if step % 2 == 1:
+ self.plot_simulation_graph(child_node, step)
+ else:
+ self.plot_simulation_graph(root, step)
+ # if step == 3:
+ # self.plot_simulation_graph(root, step)
+
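+        # Return the chosen action, the selected child node (so callers can reuse the sub-tree), and that child's visit count.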
+ return root.best_action, child_node, int(child_node.visit_count)
+
+ def obtain_tree_topology(self, root, to_play=-1):
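+        # Walk the whole search tree with an explicit stack (iterative DFS), collecting per-node and per-edge metadata for visualization.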
+ node_stack = []
+ edge_topology_list = []
+ node_topology_list = []
+ node_id_list = []
+ node_stack.append(root)
+ while len(node_stack) > 0:
+ node = node_stack[-1]
+ node_stack.pop()
+ node_dict = {}
+ node_dict['node_id'] = np.array(node.env.board).reshape(6,7)
+ node_dict['visit_count'] = node.visit_count
+ # node_dict['policy_prior'] = node.prior
+ node_dict['value'] = node.value
+ node_topology_list.append(node_dict)
+
+ # node_id_list.append(node.simulation_index)
+ for child in node.children:
+ # child.parent_simulation_index = node.simulation_index
+ edge_dict = {}
+ edge_dict['parent_id'] = np.array(node.env.board).reshape(6,7)
+ edge_dict['child_id'] = np.array(child.env.board).reshape(6,7)
+ edge_topology_list.append(edge_dict)
+ node_stack.append(child)
+ return edge_topology_list, node_id_list, node_topology_list
+
+ def obtain_child_topology(self, root, to_play=-1):
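+        # Collect metadata only for the root and its direct children, i.e. a one-level view of the search tree.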
+ edge_topology_list = []
+ node_topology_list = []
+ node_id_list = []
+ node = root
+ node_dict = {}
+ node_dict['node_id'] = np.array(node.env.board).reshape(6,7)
+ node_dict['visit_count'] = node.visit_count
+ # node_dict['policy_prior'] = node.prior
+ node_dict['value'] = node.value
+ node_topology_list.append(node_dict)
+
+ for child in node.children:
+ # child.parent_simulation_index = node.simulation_index
+ edge_dict = {}
+ edge_dict['parent_id'] = np.array(node.env.board).reshape(6,7)
+ edge_dict['child_id'] = np.array(child.env.board).reshape(6,7)
+ edge_topology_list.append(edge_dict)
+ node_dict = {}
+ node_dict['node_id'] = np.array(child.env.board).reshape(6,7)
+ node_dict['visit_count'] = child.visit_count
+ # node_dict['policy_prior'] = node.prior
+ node_dict['value'] = child.value
+ node_topology_list.append(node_dict)
+ return edge_topology_list, node_id_list, node_topology_list
+
+ def plot_simulation_graph(self, env_root, current_step, type="child", graph_directory=None):
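+        # Render the collected topology with graphviz and save it as a PNG; type='child' plots only the root's children, type='tree' plots the full tree.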
+ if type == "child":
+ edge_topology_list, node_id_list, node_topology_list = self.obtain_child_topology(env_root)
+ elif type == "tree":
+ edge_topology_list, node_id_list, node_topology_list = self.obtain_tree_topology(env_root)
+ dot = Digraph(comment='this is direction')
+ for node_topology in node_topology_list:
+ node_name = str(node_topology['node_id'])
+ label = f"{node_topology['node_id']}, \n visit_count: {node_topology['visit_count']}, \n value: {round(node_topology['value'], 4)}"
+ dot.node(node_name, label=label)
+ for edge_topology in edge_topology_list:
+ parent_id = str(edge_topology['parent_id'])
+ child_id = str(edge_topology['child_id'])
+ label = parent_id + '-' + child_id
+ dot.edge(parent_id, child_id, label=None)
+ if graph_directory is None:
+ graph_directory = './data_visualize/'
+ if not os.path.exists(graph_directory):
+ os.makedirs(graph_directory)
+ graph_path = graph_directory + 'same_num_' + str(current_step) + 'step.gv'
+ dot.format = 'png'
+ dot.render(graph_path, view=False)
diff --git a/zoo/board_games/test_speed_win-rate_between_bots.py b/zoo/board_games/test_speed_win-rate_between_bots.py
index c31504420..7d9374e6d 100644
--- a/zoo/board_games/test_speed_win-rate_between_bots.py
+++ b/zoo/board_games/test_speed_win-rate_between_bots.py
@@ -9,11 +9,17 @@
import numpy as np
from easydict import EasyDict
+import psutil
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
from zoo.board_games.mcts_bot import MCTSBot
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
+def get_memory_usage():
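+    # Report the current process's resident set size (RSS, in bytes) via psutil so memory growth between searches can be logged.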
+ process = psutil.Process()
+ memory_info = process.memory_info()
+ return memory_info.rss
+
cfg_tictactoe = dict(
battle_mode='self_play_mode',
agent_vs_human=False,
@@ -361,6 +367,8 @@ def test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=50):
# Repeat the game for 10 rounds.
for i in range(10):
print('-' * 10 + str(i) + '-' * 10)
+ memory_usage = get_memory_usage()
+ print(f'Memory usage at the beginning of the game: {memory_usage} bytes')
# Initialize the game, where there are two players: player 1 and player 2.
env = TicTacToeEnv(EasyDict(cfg_tictactoe))
# Reset the environment, set the board to a clean board and the start player to be player 1.
@@ -369,6 +377,7 @@ def test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=50):
player = MCTSBot(env, 'a', num_simulations) # player_index = 0, player = 1
# Set player 1 to move first.
player_index = 0
+ step = 1
while not env.get_done_reward()[0]:
"""
Overview:
@@ -378,7 +387,7 @@ def test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=50):
if player_index == 0:
t1 = time.time()
# action = env.mcts_bot()
- action = player.get_actions(state, player_index=player_index, best_action_type = "most_visit")
+                    action = player.get_actions(state, step, player_index, best_action_type="most_visit")[0]
t2 = time.time()
# print("The time difference is :", t2-t1)
# mcts_bot_time_list.append(t2 - t1)
@@ -396,10 +405,17 @@ def test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=50):
alphabeta_pruning_time_list.append(t2 - t1)
player_index = 0
env.step(action)
+ step += 1
state = env.board
# Print the result of the game.
if env.get_done_reward()[0]:
print(state)
+
+ temp = memory_usage
+ memory_usage = get_memory_usage()
+ memory_cost = memory_usage - temp
+ print(f'Memory usage after searching: {memory_usage} bytes')
+ print(f'Memory increase after searching: {memory_cost} bytes')
# Record the winner.
winner.append(env.get_done_winner()[1])
diff --git a/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py
index 0ef0bf9f7..6e4ee623a 100644
--- a/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py
+++ b/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py
@@ -17,7 +17,7 @@
tictactoe_gumbel_muzero_config = dict(
exp_name=
- f'data_mz_ctree/tictactoe_gumbel_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ f'data_muzero/tictactoe_gumbel_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
battle_mode='play_with_bot_mode',
collector_env_num=collector_env_num,
diff --git a/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py
index da4bf2c0e..82912877e 100644
--- a/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py
+++ b/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py
@@ -17,7 +17,7 @@
tictactoe_muzero_config = dict(
exp_name=
- f'data_mz_ctree/tictactoe_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ f'data_muzero/tictactoe_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
battle_mode='play_with_bot_mode',
collector_env_num=collector_env_num,
diff --git a/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py
index 546fc4c0e..981cf0786 100644
--- a/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py
+++ b/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py
@@ -17,7 +17,7 @@
tictactoe_muzero_config = dict(
exp_name=
- f'data_mz_ctree/tictactoe_muzero_sp-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ f'data_muzero/tictactoe_muzero_sp-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
battle_mode='self_play_mode',
collector_env_num=collector_env_num,
diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py
index 929d5594f..ebc11c8b4 100644
--- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py
+++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py
@@ -10,7 +10,7 @@
# each_dim_disc_size = 4 # thus the total discrete action number is 4**4=256
# num_simulations = 50
# update_per_collect = None
-# model_update_ratio = 0.25
+# replay_ratio = 0.25
# batch_size = 256
# max_env_step = int(5e6)
# reanalyze_ratio = 0.
@@ -20,7 +20,7 @@
bipedalwalker_cont_disc_efficientzero_config = dict(
exp_name=
- f'data_sez_ctree/bipedalwalker_cont_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_seed0',
+ f'data_sez_ctree/bipedalwalker_cont_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}-mur{replay_ratio}_rr{reanalyze_ratio}_seed0',
env=dict(
stop_value=int(1e6),
env_id='BipedalWalker-v3',
@@ -61,7 +61,7 @@
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(2e3),
- model_update_ratio=model_update_ratio,
+ replay_ratio=replay_ratio,
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py
index f20558b3a..ecbfe1249 100644
--- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py
+++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py
@@ -11,7 +11,7 @@
K = 20 # num_of_sampled_actions
num_simulations = 50
update_per_collect = None
-model_update_ratio = 0.25
+replay_ratio = 0.25
batch_size = 256
max_env_step = int(5e6)
reanalyze_ratio = 0.
@@ -21,7 +21,7 @@
bipedalwalker_cont_disc_sampled_efficientzero_config = dict(
exp_name=
- f'data_sez_ctree/bipedalwalker_cont_disc_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_seed0',
+ f'data_sez_ctree/bipedalwalker_cont_disc_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{replay_ratio}_rr{reanalyze_ratio}_seed0',
env=dict(
stop_value=int(1e6),
env_id='BipedalWalker-v3',
@@ -63,7 +63,7 @@
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(2e3),
- model_update_ratio=model_update_ratio,
+ replay_ratio=replay_ratio,
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py
index a61856c44..d691e9d1c 100644
--- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py
+++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py
@@ -10,7 +10,7 @@
K = 20 # num_of_sampled_actions
num_simulations = 50
update_per_collect = None
-model_update_ratio = 0.25
+replay_ratio = 0.25
batch_size = 256
max_env_step = int(5e6)
reanalyze_ratio = 0.
@@ -20,7 +20,7 @@
bipedalwalker_cont_sampled_efficientzero_config = dict(
exp_name=
- f'data_sez_ctree/bipedalwalker_cont_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_seed0',
+ f'data_sez_ctree/bipedalwalker_cont_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{replay_ratio}_rr{reanalyze_ratio}_seed0',
env=dict(
env_id='BipedalWalker-v3',
env_type='normal',
@@ -61,7 +61,7 @@
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(2e3),
- model_update_ratio=model_update_ratio,
+ replay_ratio=replay_ratio,
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py
index 38e268c3e..ab274eb9e 100644
--- a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py
+++ b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py
@@ -1,5 +1,6 @@
from easydict import EasyDict
-
+import torch
+torch.cuda.set_device(0)
# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
@@ -7,17 +8,26 @@
n_episode = 8
evaluator_env_num = 3
num_simulations = 50
-update_per_collect = 200
+# update_per_collect = 200
+update_per_collect = None
+replay_ratio = 0.25
+
batch_size = 256
-max_env_step = int(5e6)
-reanalyze_ratio = 0.
+max_env_step = int(1e6)
+# reanalyze_ratio = 0.
+# reanalyze_ratio = 1
+reanalyze_ratio = 0.99
+
+
+seed = 0
+
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
lunarlander_disc_efficientzero_config = dict(
- exp_name=
- f'data_ez_ctree/lunarlander_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ # exp_name=f'data_ez_ctree/lunarlander_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
+    exp_name=f'data_ez_ctree_0129/lunarlander/ez_rr{reanalyze_ratio}_ns{num_simulations}_upc{update_per_collect}_seed{seed}',
env=dict(
env_id='LunarLander-v2',
continuous=False,
@@ -42,6 +52,7 @@
env_type='not_board_games',
game_segment_length=200,
update_per_collect=update_per_collect,
+ replay_ratio=replay_ratio,
batch_size=batch_size,
optim_type='Adam',
lr_piecewise_constant_decay=False,
diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py
index 3f4241685..0d8e1585d 100644
--- a/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py
+++ b/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py
@@ -16,7 +16,7 @@
# ==============================================================
lunarlander_gumbel_muzero_config = dict(
- exp_name=f'data_mz_ctree/lunarlander_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ exp_name=f'data_muzero/lunarlander_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
env_id='LunarLander-v2',
continuous=False,
diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py
index 19b7bcd86..9c9c2b277 100644
--- a/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py
+++ b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py
@@ -16,7 +16,7 @@
# ==============================================================
lunarlander_muzero_config = dict(
- exp_name=f'data_mz_ctree/lunarlander_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ exp_name=f'data_muzero/lunarlander_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
env_id='LunarLander-v2',
continuous=False,
diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_multi-seed_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_multi-seed_config.py
new file mode 100644
index 000000000..f13011418
--- /dev/null
+++ b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_multi-seed_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+import torch
+torch.cuda.set_device(5)
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = 200
+batch_size = 256
+max_env_step = int(5e6)
+reanalyze_ratio = 0.
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+lunarlander_muzero_config = dict(
+ exp_name=f'data_muzero/lunarlander_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ env=dict(
+ env_name='LunarLander-v2',
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=8,
+ action_space_size=4,
+ model_type='mlp',
+ lstm_hidden_size=256,
+ latent_state_dim=256,
+ self_supervised_learning_loss=True, # NOTE: default is False.
+ discrete_action_encoding_type='one_hot',
+ res_connection_in_dynamics=True,
+ norm_type='BN',
+ ),
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=200,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ ssl_loss_weight=2, # NOTE: default is 0.
+ grad_clip_value=0.5,
+ num_simulations=num_simulations,
+ reanalyze_ratio=reanalyze_ratio,
+ n_episode=n_episode,
+ eval_freq=int(1e3),
+ replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ ),
+)
+lunarlander_muzero_config = EasyDict(lunarlander_muzero_config)
+main_config = lunarlander_muzero_config
+
+lunarlander_muzero_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='muzero',
+ import_names=['lzero.policy.muzero'],
+ ),
+ collector=dict(
+ type='episode_muzero',
+ get_train_sample=True,
+ import_names=['lzero.worker.muzero_collector'],
+ )
+)
+lunarlander_muzero_create_config = EasyDict(lunarlander_muzero_create_config)
+create_config = lunarlander_muzero_create_config
+
+if __name__ == "__main__":
+ # Define a list of seeds for multiple runs
+ seeds = [0, 1] # You can add more seed values here
+
+ for seed in seeds:
+ # Update exp_name to include the current seed
+ main_config.exp_name = f'data_muzero_0128/lunarlander_muzero_ns{main_config.policy.num_simulations}_upc{main_config.policy.update_per_collect}_rr{main_config.policy.reanalyze_ratio}_seed{seed}'
+ from lzero.entry import train_muzero
+ train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
\ No newline at end of file
diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-ez_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-ez_config.py
new file mode 100644
index 000000000..1e180390d
--- /dev/null
+++ b/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-ez_config.py
@@ -0,0 +1,111 @@
+from easydict import EasyDict
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = None
+replay_ratio = 0.25
+batch_size = 256
+max_env_step = int(1e6)
+use_priority = False
+
+reuse_search = True
+collect_with_pure_policy = True
+buffer_reanalyze_freq = 1
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+lunarlander_muzero_config = dict(
+ exp_name=f'data_rezero-ez/lunarlander_rezero-ez_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0',
+ env=dict(
+ env_name='LunarLander-v2',
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+        # NOTE: avoid frequent checkpoint saving (checkpoints are only written after the run finishes).
+ learn=dict(
+ learner=dict(
+ train_iterations=1000000000,
+ dataloader=dict(
+ num_workers=0,
+ ),
+ log_policy=True,
+ hook=dict(
+ load_ckpt_before_run='',
+ log_show_after_iter=1000,
+ save_ckpt_after_iter=1000000,
+ save_ckpt_after_run=True,
+ ),
+ cfg_type='BaseLearnerDict',
+ ),
+ ),
+ model=dict(
+ observation_shape=8,
+ action_space_size=4,
+ model_type='mlp',
+ lstm_hidden_size=256,
+ latent_state_dim=256,
+ self_supervised_learning_loss=True,
+ discrete_action_encoding_type='one_hot',
+ res_connection_in_dynamics=True,
+ norm_type='BN',
+ ),
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=200,
+ update_per_collect=update_per_collect,
+ replay_ratio=replay_ratio,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ ssl_loss_weight=2,
+ grad_clip_value=0.5,
+ num_simulations=num_simulations,
+ n_episode=n_episode,
+ eval_freq=int(1e3),
+ replay_buffer_size=int(1e6),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ reanalyze_noise=True,
+ use_priority=use_priority,
+ # ============= The key different params for ReZero =============
+ reuse_search=reuse_search,
+ collect_with_pure_policy=collect_with_pure_policy,
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ ),
+)
+lunarlander_muzero_config = EasyDict(lunarlander_muzero_config)
+main_config = lunarlander_muzero_config
+
+lunarlander_muzero_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='efficientzero',
+ import_names=['lzero.policy.efficientzero'],
+ ),
+)
+lunarlander_muzero_create_config = EasyDict(lunarlander_muzero_create_config)
+create_config = lunarlander_muzero_create_config
+
+if __name__ == "__main__":
+ # Define a list of seeds for multiple runs
+ seeds = [0] # You can add more seed values here
+ for seed in seeds:
+ # Update exp_name to include the current seed
+ main_config.exp_name = f'data_rezero_ez/lunarlander_rezero-ez_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed{seed}'
+ from lzero.entry import train_rezero
+ train_rezero([main_config, create_config], seed=seed, max_env_step=max_env_step)
diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-mz_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-mz_config.py
new file mode 100644
index 000000000..1e99bc146
--- /dev/null
+++ b/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-mz_config.py
@@ -0,0 +1,94 @@
+from easydict import EasyDict
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = None
+replay_ratio = 0.25
+batch_size = 256
+max_env_step = int(1e6)
+use_priority = False
+
+reuse_search = True
+collect_with_pure_policy = True
+buffer_reanalyze_freq = 1
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+lunarlander_muzero_config = dict(
+ exp_name=f'data_rezero-mz/lunarlander_rezero-mz_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0',
+ env=dict(
+ env_name='LunarLander-v2',
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=8,
+ action_space_size=4,
+ model_type='mlp',
+ lstm_hidden_size=256,
+ latent_state_dim=256,
+ self_supervised_learning_loss=True,
+ discrete_action_encoding_type='one_hot',
+ res_connection_in_dynamics=True,
+ norm_type='BN',
+ ),
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=200,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ ssl_loss_weight=2,
+ grad_clip_value=0.5,
+ num_simulations=num_simulations,
+ reanalyze_ratio=0, # NOTE: for rezero, reanalyze_ratio should be 0.
+ n_episode=n_episode,
+ eval_freq=int(1e3),
+ replay_buffer_size=int(1e6),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ reanalyze_noise=True,
+ use_priority=use_priority,
+ # ============= The key different params for ReZero =============
+ reuse_search=reuse_search,
+ collect_with_pure_policy=collect_with_pure_policy,
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ ),
+)
+lunarlander_muzero_config = EasyDict(lunarlander_muzero_config)
+main_config = lunarlander_muzero_config
+
+lunarlander_muzero_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='muzero',
+ import_names=['lzero.policy.muzero'],
+ ),
+)
+lunarlander_muzero_create_config = EasyDict(lunarlander_muzero_create_config)
+create_config = lunarlander_muzero_create_config
+
+if __name__ == "__main__":
+ # Define a list of seeds for multiple runs
+ seeds = [0] # You can add more seed values here
+ for seed in seeds:
+ # Update exp_name to include the current seed
+ main_config.exp_name = f'data_rezero_mz/lunarlander_rezero-mz_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed{seed}'
+ from lzero.entry import train_rezero
+ train_rezero([main_config, create_config], seed=seed, max_env_step=max_env_step)
\ No newline at end of file
diff --git a/zoo/bsuite/config/bsuite_muzero_config.py b/zoo/bsuite/config/bsuite_muzero_config.py
index 5c40a204b..9e5682df7 100644
--- a/zoo/bsuite/config/bsuite_muzero_config.py
+++ b/zoo/bsuite/config/bsuite_muzero_config.py
@@ -38,7 +38,7 @@
# ==============================================================
bsuite_muzero_config = dict(
- exp_name=f'data_mz_ctree/bsuite_{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
+ exp_name=f'data_muzero/bsuite_{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
env=dict(
env_id=env_id,
stop_value=int(1e6),
diff --git a/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py
index 9f292575a..1db5629ee 100644
--- a/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py
+++ b/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py
@@ -16,7 +16,7 @@
# ==============================================================
cartpole_gumbel_muzero_config = dict(
- exp_name=f'data_mz_ctree/cartpole_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ exp_name=f'data_muzero/cartpole_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
env_id='CartPole-v0',
continuous=False,
diff --git a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py
index 56ee32acf..1918ca408 100644
--- a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py
+++ b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py
@@ -16,7 +16,7 @@
# ==============================================================
cartpole_muzero_config = dict(
- exp_name=f'data_mz_ctree/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ exp_name=f'data_muzero/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
env_id='CartPole-v0',
continuous=False,
diff --git a/zoo/classic_control/cartpole/config/cartpole_rezero_mz_config.py b/zoo/classic_control/cartpole/config/cartpole_rezero_mz_config.py
new file mode 100644
index 000000000..7405547f3
--- /dev/null
+++ b/zoo/classic_control/cartpole/config/cartpole_rezero_mz_config.py
@@ -0,0 +1,87 @@
+from easydict import EasyDict
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 25
+update_per_collect = 100
+batch_size = 256
+max_env_step = int(1e5)
+use_priority = False
+
+reuse_search = True
+collect_with_pure_policy = False
+buffer_reanalyze_freq = 1
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+cartpole_muzero_config = dict(
+ exp_name=f'data_rezero-mz/cartpole_rezero-mz_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0',
+ env=dict(
+ env_name='CartPole-v0',
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=4,
+ action_space_size=2,
+ model_type='mlp',
+ lstm_hidden_size=128,
+ latent_state_dim=128,
+ self_supervised_learning_loss=True, # NOTE: default is False.
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ ),
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=50,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ ssl_loss_weight=2,
+ num_simulations=num_simulations,
+ n_episode=n_episode,
+ eval_freq=int(2e2),
+ replay_buffer_size=int(1e6),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ reanalyze_noise=True,
+ use_priority=use_priority,
+ # ============= The key different params for ReZero =============
+ reuse_search=reuse_search,
+ collect_with_pure_policy=collect_with_pure_policy,
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ ),
+)
+
+cartpole_muzero_config = EasyDict(cartpole_muzero_config)
+main_config = cartpole_muzero_config
+
+cartpole_muzero_create_config = dict(
+ env=dict(
+ type='cartpole_lightzero',
+ import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='muzero',
+ import_names=['lzero.policy.muzero'],
+ ),
+)
+cartpole_muzero_create_config = EasyDict(cartpole_muzero_create_config)
+create_config = cartpole_muzero_create_config
+
+if __name__ == "__main__":
+ from lzero.entry import train_rezero
+ train_rezero([main_config, create_config], seed=0, max_env_step=max_env_step)
\ No newline at end of file
diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py
index f2c0749b6..bed3a192f 100644
--- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py
+++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py
@@ -20,7 +20,7 @@
# ==============================================================
pendulum_disc_gumbel_muzero_config = dict(
- exp_name=f'data_mz_ctree/pendulum_disc_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
+ exp_name=f'data_muzero/pendulum_disc_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0',
env=dict(
env_id='Pendulum-v1',
continuous=False,
diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py
index 45f1c3e05..5b7204499 100644
--- a/zoo/game_2048/config/muzero_2048_config.py
+++ b/zoo/game_2048/config/muzero_2048_config.py
@@ -21,7 +21,7 @@
# ==============================================================
atari_muzero_config = dict(
- exp_name=f'data_mz_ctree/game_2048_npct-{num_of_possible_chance_tile}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_seed0',
+ exp_name=f'data_muzero/game_2048_npct-{num_of_possible_chance_tile}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_seed0',
env=dict(
stop_value=int(1e6),
env_id=env_id,
diff --git a/zoo/memory/config/memory_muzero_config.py b/zoo/memory/config/memory_muzero_config.py
index 6b204d4f7..0477111f0 100644
--- a/zoo/memory/config/memory_muzero_config.py
+++ b/zoo/memory/config/memory_muzero_config.py
@@ -34,7 +34,7 @@
# ==============================================================
memory_muzero_config = dict(
- exp_name=f'data_mz_ctree/{env_id}_memlen-{memory_length}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_'
+ exp_name=f'data_muzero/{env_id}_memlen-{memory_length}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_'
f'collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}'
f'_pelw{policy_entropy_loss_weight}_seed{seed}',
env=dict(
diff --git a/zoo/minigrid/config/minigrid_muzero_config.py b/zoo/minigrid/config/minigrid_muzero_config.py
index 304d0860c..b66b311ee 100644
--- a/zoo/minigrid/config/minigrid_muzero_config.py
+++ b/zoo/minigrid/config/minigrid_muzero_config.py
@@ -25,7 +25,7 @@
# ==============================================================
minigrid_muzero_config = dict(
- exp_name=f'data_mz_ctree/{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_'
+ exp_name=f'data_muzero/{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_'
f'collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}_pelw{policy_entropy_loss_weight}_seed{seed}',
env=dict(
stop_value=int(1e6),