diff --git a/README.md b/README.md index 1b7558d5d..6372e2483 100644 --- a/README.md +++ b/README.md @@ -122,23 +122,25 @@ LightZero is a library with a [PyTorch](https://pytorch.org/) implementation of The environments and algorithms currently supported by LightZero are shown in the table below: -| Env./Algo. | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | -|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------| -| TicTacToe | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | โœ” | ๐Ÿ”’ |โœ”| -| Gomoku | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | โœ” | ๐Ÿ”’ |โœ”| -| Connect4 | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |โœ”| -| 2048 | --- | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | โœ” |โœ”| -| Chess | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’| -| Go | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’| -| CartPole | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”| -| Pendulum | --- | โœ” | โœ” | โœ” | โœ” | โœ” |๐Ÿ”’| -| LunarLander | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”| -| BipedalWalker | --- | โœ” | โœ” | โœ” | โœ” | ๐Ÿ”’ |๐Ÿ”’| -| Atari | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”| -| MuJoCo | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’| -| MiniGrid | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”| -| Bsuite | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”| -| Memory | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”| + +| Env./Algo. | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero |ReZero | +|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------|----------------| +| TicTacToe | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | โœ” | ๐Ÿ”’ |โœ”|๐Ÿ”’ | +| Gomoku | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | โœ” | ๐Ÿ”’ |โœ”|โœ” | +| Connect4 | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |โœ”|โœ” | +| 2048 | --- | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | โœ” |โœ”|๐Ÿ”’ | +| Chess | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’|๐Ÿ”’ | +| Go | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’|๐Ÿ”’ | +| CartPole | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”|โœ” | +| Pendulum | --- | โœ” | โœ” | โœ” | โœ” | โœ” |๐Ÿ”’|๐Ÿ”’ | +| LunarLander | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”|๐Ÿ”’ | +| BipedalWalker | --- | โœ” | โœ” | โœ” | โœ” | ๐Ÿ”’ |๐Ÿ”’|๐Ÿ”’ | +| Atari | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”|โœ” | +| MuJoCo | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’|๐Ÿ”’ | +| MiniGrid | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”|๐Ÿ”’ | +| Bsuite | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”|๐Ÿ”’ | +| Memory | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”|๐Ÿ”’ | + (1): "โœ”" means that the corresponding item is finished and well-tested. diff --git a/README.zh.md b/README.zh.md index f0770c7a0..6dc96a479 100644 --- a/README.zh.md +++ b/README.zh.md @@ -110,23 +110,23 @@ LightZero ๆ˜ฏๅŸบไบŽ [PyTorch](https://pytorch.org/) ๅฎž็Žฐ็š„ MCTS ็ฎ—ๆณ•ๅบ“๏ผŒ LightZero ็›ฎๅ‰ๆ”ฏๆŒ็š„็ŽฏๅขƒๅŠ็ฎ—ๆณ•ๅฆ‚ไธ‹่กจๆ‰€็คบ๏ผš -| Env./Algo. 
| AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | -|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------| -| TicTacToe | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | โœ” | ๐Ÿ”’ |โœ”| -| Gomoku | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | โœ” | ๐Ÿ”’ |โœ”| -| Connect4 | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |โœ”| -| 2048 | --- | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | โœ” |โœ”| -| Chess | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’| -| Go | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’| -| CartPole | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”| -| Pendulum | --- | โœ” | โœ” | โœ” | โœ” | โœ” |๐Ÿ”’| -| LunarLander | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”| -| BipedalWalker | --- | โœ” | โœ” | โœ” | โœ” | ๐Ÿ”’ |๐Ÿ”’| -| Atari | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”| -| MuJoCo | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’| -| MiniGrid | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”| -| Bsuite | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”| -| Memory | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”| +| Env./Algo. | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero |ReZero | +|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------|----------------| +| TicTacToe | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | โœ” | ๐Ÿ”’ |โœ”|๐Ÿ”’ | +| Gomoku | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | โœ” | ๐Ÿ”’ |โœ”|โœ” | +| Connect4 | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |โœ”|โœ” | +| 2048 | --- | โœ” | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | โœ” |โœ”|๐Ÿ”’ | +| Chess | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’|๐Ÿ”’ | +| Go | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’|๐Ÿ”’ | +| CartPole | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”|โœ” | +| Pendulum | --- | โœ” | โœ” | โœ” | โœ” | โœ” |๐Ÿ”’|๐Ÿ”’ | +| LunarLander | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”|๐Ÿ”’ | +| BipedalWalker | --- | โœ” | โœ” | โœ” | โœ” | ๐Ÿ”’ |๐Ÿ”’|๐Ÿ”’ | +| Atari | --- | โœ” | โœ” | โœ” | โœ” | โœ” |โœ”|โœ” | +| MuJoCo | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |๐Ÿ”’|๐Ÿ”’ | +| MiniGrid | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”|๐Ÿ”’ | +| Bsuite | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”|๐Ÿ”’ | +| Memory | --- | โœ” | โœ” | โœ” | ๐Ÿ”’ | ๐Ÿ”’ |โœ”|๐Ÿ”’ | (1): "โœ”" ่กจ็คบๅฏนๅบ”็š„้กน็›ฎๅทฒ็ปๅฎŒๆˆๅนถ็ป่ฟ‡่‰ฏๅฅฝ็š„ๆต‹่ฏ•ใ€‚ diff --git a/lzero/agent/alphazero.py b/lzero/agent/alphazero.py index d31e17bb9..3eb265bde 100644 --- a/lzero/agent/alphazero.py +++ b/lzero/agent/alphazero.py @@ -198,9 +198,9 @@ def train( new_data = sum(new_data, []) if self.cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. collected_transitions_num = len(new_data) - update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio) replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) # Learn policy from collected data diff --git a/lzero/agent/efficientzero.py b/lzero/agent/efficientzero.py index 421cea881..b2362d9cc 100644 --- a/lzero/agent/efficientzero.py +++ b/lzero/agent/efficientzero.py @@ -228,9 +228,9 @@ def train( # Collect data by default config n_sample/n_episode. 
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) if self.cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) - update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio) # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) # remove the oldest data if the replay buffer is full. diff --git a/lzero/agent/gumbel_muzero.py b/lzero/agent/gumbel_muzero.py index 0df583ab1..0ad0ff0fb 100644 --- a/lzero/agent/gumbel_muzero.py +++ b/lzero/agent/gumbel_muzero.py @@ -228,9 +228,9 @@ def train( # Collect data by default config n_sample/n_episode. new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) if self.cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) - update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio) # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) # remove the oldest data if the replay buffer is full. diff --git a/lzero/agent/muzero.py b/lzero/agent/muzero.py index 55dda5d00..7a77996b5 100644 --- a/lzero/agent/muzero.py +++ b/lzero/agent/muzero.py @@ -228,9 +228,9 @@ def train( # Collect data by default config n_sample/n_episode. new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) if self.cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) - update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio) # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) # remove the oldest data if the replay buffer is full. diff --git a/lzero/agent/sampled_alphazero.py b/lzero/agent/sampled_alphazero.py index dc76c16e5..2f762c352 100644 --- a/lzero/agent/sampled_alphazero.py +++ b/lzero/agent/sampled_alphazero.py @@ -198,9 +198,9 @@ def train( new_data = sum(new_data, []) if self.cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. 
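+                # Illustrative example: with replay_ratio = 0.25 and 1000 newly collected
+                # transitions, update_per_collect = int(1000 * 0.25) = 250 gradient updates
+                # are performed for this collect-train cycle.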
collected_transitions_num = len(new_data) - update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio) replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) # Learn policy from collected data diff --git a/lzero/agent/sampled_efficientzero.py b/lzero/agent/sampled_efficientzero.py index 079bdd11d..19542120e 100644 --- a/lzero/agent/sampled_efficientzero.py +++ b/lzero/agent/sampled_efficientzero.py @@ -228,9 +228,9 @@ def train( # Collect data by default config n_sample/n_episode. new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) if self.cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) - update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio) # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) # remove the oldest data if the replay buffer is full. diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index 9b50ed6ec..37928cd10 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -4,4 +4,5 @@ from .train_muzero_with_reward_model import train_muzero_with_reward_model from .eval_muzero import eval_muzero from .eval_muzero_with_gym_env import eval_muzero_with_gym_env -from .train_muzero_with_gym_env import train_muzero_with_gym_env \ No newline at end of file +from .train_muzero_with_gym_env import train_muzero_with_gym_env +from .train_rezero import train_rezero diff --git a/lzero/entry/train_alphazero.py b/lzero/entry/train_alphazero.py index 3b455adb1..b77d69a75 100644 --- a/lzero/entry/train_alphazero.py +++ b/lzero/entry/train_alphazero.py @@ -119,9 +119,9 @@ def train_alphazero( new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) new_data = sum(new_data, []) if cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. 
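+            # NOTE: ``replay_ratio`` replaces the former ``model_update_ratio``; when
+            # ``update_per_collect`` is left as ``None``, the number of updates per collection
+            # is derived from it as below.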
collected_transitions_num = len(new_data) - update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) # Learn policy from collected data diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index 55ff5cb70..f8f70f80d 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -8,12 +8,12 @@ from ding.envs import create_env_manager from ding.envs import get_vec_env_setting from ding.policy import create_policy -from ding.utils import set_pkg_seed, get_rank from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank from ding.worker import BaseLearner from tensorboardX import SummaryWriter -from lzero.entry.utils import log_buffer_memory_usage +from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time from lzero.policy import visit_count_temperature from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroCollector as Collector @@ -69,7 +69,6 @@ def train_muzero( cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) # Create main components: env, policy env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) @@ -138,6 +137,7 @@ def train_muzero( while True: log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger) collect_kwargs = {} # set temperature for visit count distributions according to the train_iter, # please refer to Appendix D in MuZero paper for details. @@ -172,9 +172,9 @@ def train_muzero( # Collect data by default config n_sample/n_episode. new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) if cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) - update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) # remove the oldest data if the replay buffer is full. diff --git a/lzero/entry/train_muzero_with_gym_env.py b/lzero/entry/train_muzero_with_gym_env.py index 3aa3906b9..f2ec552d8 100644 --- a/lzero/entry/train_muzero_with_gym_env.py +++ b/lzero/entry/train_muzero_with_gym_env.py @@ -136,9 +136,9 @@ def train_muzero_with_gym_env( # Collect data by default config n_sample/n_episode. new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) if cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. 
+ # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) - update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) # remove the oldest data if the replay buffer is full. diff --git a/lzero/entry/train_muzero_with_reward_model.py b/lzero/entry/train_muzero_with_reward_model.py index 2ae409601..20c9d05b5 100644 --- a/lzero/entry/train_muzero_with_reward_model.py +++ b/lzero/entry/train_muzero_with_reward_model.py @@ -171,9 +171,9 @@ def train_muzero_with_reward_model( reward_model.clear_old_data() if cfg.policy.update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) - update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) # save returned new_data collected by the collector replay_buffer.push_game_segments(new_data) # remove the oldest data if the replay buffer is full. diff --git a/lzero/entry/train_rezero.py b/lzero/entry/train_rezero.py new file mode 100644 index 000000000..565b5b444 --- /dev/null +++ b/lzero/entry/train_rezero.py @@ -0,0 +1,232 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator +from .utils import random_collect + + +def train_rezero( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Train entry for ReZero algorithms (ReZero-MuZero, ReZero-EfficientZero). + + Args: + - input_cfg (:obj:`Tuple[dict, dict]`): Configuration dictionaries (user_config, create_cfg). + - seed (:obj:`int`): Random seed for reproducibility. + - model (:obj:`Optional[torch.nn.Module]`): Pre-initialized model instance. + - model_path (:obj:`Optional[str]`): Path to pretrained model checkpoint. + - max_train_iter (:obj:`Optional[int]`): Maximum number of training iterations. + - max_env_step (:obj:`Optional[int]`): Maximum number of environment steps. + + Returns: + - Policy: Trained policy object. 
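+
+    Example:
+        A minimal usage sketch; the config module below is hypothetical, any ReZero config that
+        exposes ``main_config`` and ``create_config`` can be used instead.
+
+        >>> from zoo.atari.config.atari_rezero_mz_config import main_config, create_config
+        >>> train_rezero([main_config, create_config], seed=0, max_env_step=int(1e6))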
+ """ + cfg, create_cfg = input_cfg + assert create_cfg.policy.type in ['efficientzero', 'muzero'], \ + "train_rezero entry only supports 'efficientzero' and 'muzero' algorithms" + + # Import appropriate GameBuffer based on policy type + if create_cfg.policy.type == 'muzero': + from lzero.mcts import ReZeroMZGameBuffer as GameBuffer + elif create_cfg.policy.type == 'efficientzero': + from lzero.mcts import ReZeroEZGameBuffer as GameBuffer + + # Set device (CUDA if available and enabled, otherwise CPU) + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + + # Compile and finalize configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Create environment, policy, and core components + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + # Set seeds for reproducibility + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + # Adjust checkpoint saving frequency for offline evaluation + if cfg.policy.eval_offline: + cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq + + # Create and initialize policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + if model_path: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + # Initialize worker components + tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + replay_buffer = GameBuffer(cfg.policy) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=cfg.policy + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=cfg.policy + ) + + # Main training loop + learner.call_hook('before_run') + update_per_collect = cfg.policy.update_per_collect + + # Perform initial random data collection if specified + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + + # Initialize offline evaluation tracking if enabled + if cfg.policy.eval_offline: + eval_train_iter_list, eval_train_envstep_list = [], [] + + # Evaluate initial random agent + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + + buffer_reanalyze_count = 0 + train_epoch = 0 + while True: + # Log buffer metrics + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger) + + # Prepare collection parameters + collect_kwargs = { + 'temperature': visit_count_temperature( + cfg.policy.manual_temperature_decay, + cfg.policy.fixed_temperature_value, + cfg.policy.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': get_epsilon_greedy_fn( + cfg.policy.eps.start, cfg.policy.eps.end, + 
cfg.policy.eps.decay, cfg.policy.eps.type + )(collector.envstep) if cfg.policy.eps.eps_greedy_exploration_in_collect else 0.0 + } + + # Periodic evaluation + if evaluator.should_eval(learner.train_iter): + if cfg.policy.eval_offline: + eval_train_iter_list.append(learner.train_iter) + eval_train_envstep_list.append(collector.envstep) + else: + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect new data + new_data = collector.collect( + train_iter=learner.train_iter, + policy_kwargs=collect_kwargs, + collect_with_pure_policy=cfg.policy.collect_with_pure_policy + ) + + # Update collection frequency if not specified + if update_per_collect is None: + collected_transitions = sum(len(segment) for segment in new_data[0]) + update_per_collect = int(collected_transitions * cfg.policy.replay_ratio) + + # Update replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # Periodically reanalyze buffer + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # Reanalyze buffer each <1/buffer_reanalyze_freq> train_epoch + if train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions() > 2000: + # When reanalyzing the buffer, the samples in the entire buffer are processed in mini-batches with a batch size of 2000. + # This is an empirically selected value for optimal efficiency. + replay_buffer.reanalyze_buffer(2000, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + + # Training loop + for i in range(update_per_collect): + if cfg.policy.buffer_reanalyze_freq >= 1: + # Reanalyze buffer times in one train_epoch + if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() > 2000: + # When reanalyzing the buffer, the samples in the entire buffer are processed in mini-batches with a batch size of 2000. + # This is an empirically selected value for optimal efficiency. + replay_buffer.reanalyze_buffer(2000, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + + # Sample and train on mini-batch + if replay_buffer.get_num_of_transitions() > cfg.policy.batch_size: + train_data = replay_buffer.sample(cfg.policy.batch_size, policy) + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + else: + logging.warning('Insufficient data in replay buffer for sampling. Continuing collection...') + break + + train_epoch += 1 + # Check termination conditions + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + if cfg.policy.eval_offline: + perform_offline_evaluation(cfg, learner, policy, evaluator, eval_train_iter_list, + eval_train_envstep_list) + break + + learner.call_hook('after_run') + return policy + + +def perform_offline_evaluation(cfg, learner, policy, evaluator, eval_train_iter_list, eval_train_envstep_list): + """ + Perform offline evaluation of the trained model. + + Args: + cfg (dict): Configuration dictionary. + learner (BaseLearner): Learner object. + policy (Policy): Policy object. + evaluator (Evaluator): Evaluator object. + eval_train_iter_list (list): List of training iterations for evaluation. 
+ eval_train_envstep_list (list): List of environment steps for evaluation. + """ + logging.info('Starting offline evaluation...') + ckpt_dirname = f'./{learner.exp_name}/ckpt' + + for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list): + ckpt_path = os.path.join(ckpt_dirname, f'iteration_{train_iter}.pth.tar') + policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device)) + stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep) + logging.info(f'Offline eval at iter: {train_iter}, steps: {collector_envstep}, reward: {reward}') + + logging.info('Offline evaluation completed') \ No newline at end of file diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index dfcbbf559..04d80f264 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -1,9 +1,9 @@ import os +from typing import Optional, Callable import psutil from pympler.asizeof import asizeof from tensorboardX import SummaryWriter -from typing import Optional, Callable def random_collect( @@ -26,7 +26,8 @@ def random_collect( collect_kwargs = {'temperature': 1, 'epsilon': 0.0} # Collect data by default config n_sample/n_episode. - new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0, policy_kwargs=collect_kwargs) + new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0, + policy_kwargs=collect_kwargs) if postprocess_data_fn is not None: new_data = postprocess_data_fn(new_data) @@ -75,3 +76,41 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa # Record the memory usage of the process to TensorBoard. writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter) + + +def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: + """ + Overview: + Log the average runtime metrics of the buffer to TensorBoard. + Arguments: + - train_iter (:obj:`int`): The current training iteration. + - buffer (:obj:`GameBuffer`): The game buffer containing runtime metrics. + - writer (:obj:`SummaryWriter`): The TensorBoard writer for logging metrics. + + .. note:: + "writer is None" indicates that the function is being called in a slave process in the DDP setup. + """ + if writer is not None: + sample_times = buffer.sample_times + + if sample_times == 0: + return + + # Calculate and log average reanalyze time. + average_reanalyze_time = buffer.compute_target_re_time / sample_times + writer.add_scalar('Buffer/average_reanalyze_time', average_reanalyze_time, train_iter) + + # Calculate and log average origin search time. + average_origin_search_time = buffer.origin_search_time / sample_times + writer.add_scalar('Buffer/average_origin_search_time', average_origin_search_time, train_iter) + + # Calculate and log average reuse search time. + average_reuse_search_time = buffer.reuse_search_time / sample_times + writer.add_scalar('Buffer/average_reuse_search_time', average_reuse_search_time, train_iter) + + # Calculate and log average active root number. + average_active_root_num = buffer.active_root_num / sample_times + writer.add_scalar('Buffer/average_active_root_num', average_active_root_num, train_iter) + + # Reset the time records in the buffer. 
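+        # Resetting here makes each logged value an average over the interval since the previous
+        # log call, rather than a running average over the whole training run.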
+ buffer.reset_runtime_metrics() diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index 31680a75e..9e3c910a0 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -3,3 +3,5 @@ from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer +from .game_buffer_rezero_mz import ReZeroMZGameBuffer +from .game_buffer_rezero_ez import ReZeroEZGameBuffer diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index ec594c98e..81bf7b88c 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -30,7 +30,7 @@ def default_config(cls: type) -> EasyDict: # (int) The size/capacity of the replay buffer in terms of transitions. replay_buffer_size=int(1e6), # (float) The ratio of experiences required for the reanalyzing part in a minibatch. - reanalyze_ratio=0.3, + reanalyze_ratio=0, # (bool) Whether to consider outdated experiences for reanalyzing. If True, we first sort the data in the minibatch by the time it was produced # and only reanalyze the oldest ``reanalyze_ratio`` fraction. reanalyze_outdated=True, @@ -124,6 +124,8 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # sample according to transition index # TODO(pu): replace=True + # print(f"num transitions is {num_of_transitions}") + # print(f"length of probs is {len(probs)}") batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) if self._cfg.reanalyze_outdated is True: @@ -148,9 +150,50 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) return orig_data + + def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: + """ + Overview: + sample orig_data that contains: + game_segment_list: a list of game segments + pos_in_game_segment_list: transition index in game (relative index) + batch_index_list: the index of start transition of sampled minibatch in replay buffer + weights_list: the weight concerning the priority + make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Arguments: + - batch_size (:obj:`int`): batch size + - beta: float the parameter in PER for calculating the priority + """ + segment_length = (self.get_num_of_transitions()//2000) + assert self._beta > 0 + num_of_transitions = self.get_num_of_transitions() + sample_points = num_of_transitions // segment_length + + batch_index_list = np.random.choice(2000, batch_size, replace=False) + + if self._cfg.reanalyze_outdated is True: + # NOTE: used in reanalyze part + batch_index_list.sort() + + # TODO(xcy): use weighted sample + game_segment_list = [] + pos_in_game_segment_list = [] + + for idx in batch_index_list: + game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx*segment_length] + game_segment_idx -= self.base_idx + game_segment = self.game_segment_buffer[game_segment_idx] + + game_segment_list.append(game_segment) + pos_in_game_segment_list.append(pos_in_game_segment) + + make_time = [time.time() for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + return orig_data def _preprocess_to_play_and_action_mask( - self, game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + self, 
game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps = None ): """ Overview: @@ -158,15 +201,17 @@ def _preprocess_to_play_and_action_mask( - to_play: {list: game_segment_batch_size * (num_unroll_steps+1)} - action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)} """ + unroll_steps = unroll_steps if unroll_steps is not None else self._cfg.num_unroll_steps + to_play = [] for bs in range(game_segment_batch_size): to_play_tmp = list( to_play_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] + - self._cfg.num_unroll_steps + 1] + unroll_steps + 1] ) - if len(to_play_tmp) < self._cfg.num_unroll_steps + 1: + if len(to_play_tmp) < unroll_steps + 1: # NOTE: the effective to play index is {1,2}, for null padding data, we set to_play=-1 - to_play_tmp += [-1 for _ in range(self._cfg.num_unroll_steps + 1 - len(to_play_tmp))] + to_play_tmp += [-1 for _ in range(unroll_steps + 1 - len(to_play_tmp))] to_play.append(to_play_tmp) to_play = sum(to_play, []) @@ -178,12 +223,12 @@ def _preprocess_to_play_and_action_mask( for bs in range(game_segment_batch_size): action_mask_tmp = list( action_mask_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] + - self._cfg.num_unroll_steps + 1] + unroll_steps + 1] ) - if len(action_mask_tmp) < self._cfg.num_unroll_steps + 1: + if len(action_mask_tmp) < unroll_steps + 1: action_mask_tmp += [ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) - for _ in range(self._cfg.num_unroll_steps + 1 - len(action_mask_tmp)) + for _ in range(unroll_steps + 1 - len(action_mask_tmp)) ] action_mask.append(action_mask_tmp) action_mask = to_list(action_mask) @@ -334,6 +379,7 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None: valid_len = len(data) else: valid_len = len(data) - meta['unroll_plus_td_steps'] + # print(f'valid_len is {valid_len}') if meta['priorities'] is None: max_prio = self.game_pos_priorities.max() if self.game_segment_buffer else 1 @@ -354,6 +400,8 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None: self.game_segment_game_pos_look_up += [ (self.base_idx + len(self.game_segment_buffer) - 1, step_pos) for step_pos in range(len(data)) ] + # print(f'potioritys is {self.game_pos_priorities}') + # print(f'num of transitions is {len(self.game_segment_game_pos_look_up)}') def remove_oldest_data_to_fit(self) -> None: """ diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 873fdbdf8..420eb5922 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -2,7 +2,8 @@ import numpy as np import torch -from ding.utils import BUFFER_REGISTRY +from ding.utils import BUFFER_REGISTRY, EasyTimer +# from line_profiler import line_profiler from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree @@ -49,6 +50,28 @@ def __init__(self, cfg: dict): self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] + self._compute_target_timer = EasyTimer() + self._reuse_search_timer = EasyTimer() + self._origin_search_timer = EasyTimer() + self.buffer_reanalyze = False + + self.compute_target_re_time = 0 + self.reuse_search_time = 0 + self.origin_search_time = 0 + self.sample_times = 0 + self.active_root_num = 0 + + def reset_runtime_metrics(self): + """ + Overview: + Reset the runtime metrics of the buffer. 
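+            These counters are accumulated during sampling/reanalysis and consumed by
+            ``log_buffer_run_time`` in ``lzero/entry/utils.py``, which divides them by
+            ``sample_times`` and then calls this method to start a new logging interval.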
+ """ + self.compute_target_re_time = 0 + self.reuse_search_time = 0 + self.origin_search_time = 0 + self.sample_times = 0 + self.active_root_num = 0 + def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -72,8 +95,10 @@ def sample( batch_rewards, batch_target_values = self._compute_target_reward_value( reward_value_context, policy._target_model ) - # target policy - batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) + with self._compute_target_timer: + batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) + self.compute_target_re_time += self._compute_target_timer.value + batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( policy_non_re_context, self._cfg.model.action_space_size ) @@ -90,6 +115,8 @@ def sample( # a batch contains the current_batch and the target_batch train_data = [current_batch, target_batch] + if not self.buffer_reanalyze: + self.sample_times += 1 return train_data def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: @@ -488,6 +515,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_target_values = np.asarray(batch_target_values, dtype=object) return batch_rewards, batch_target_values + # @profile def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: """ Overview: @@ -549,7 +577,6 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() - # noises are not necessary for reanalyze noises = [ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) @@ -562,7 +589,9 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: else: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + with self._origin_search_timer: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + self.origin_search_time += self._origin_search_timer.value else: # python mcts_tree roots = MCTSPtree.roots(transition_batch_size, legal_actions) diff --git a/lzero/mcts/buffer/game_buffer_rezero_ez.py b/lzero/mcts/buffer/game_buffer_rezero_ez.py new file mode 100644 index 000000000..fdfae46df --- /dev/null +++ b/lzero/mcts/buffer/game_buffer_rezero_ez.py @@ -0,0 +1,329 @@ +from typing import Any, List, Union, TYPE_CHECKING + +import numpy as np +import torch +from ding.utils import BUFFER_REGISTRY, EasyTimer + +from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree +from lzero.mcts.utils import prepare_observation +from lzero.policy import to_detach_cpu_numpy, concat_output, inverse_scalar_transform +from .game_buffer_efficientzero import EfficientZeroGameBuffer +from .game_buffer_rezero_mz import ReZeroMZGameBuffer, compute_all_filters + +# from line_profiler import line_profiler + +if TYPE_CHECKING: + from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy + + +@BUFFER_REGISTRY.register('game_buffer_rezero_ez') +class 
ReZeroEZGameBuffer(EfficientZeroGameBuffer, ReZeroMZGameBuffer): + """ + Overview: + The specific game buffer for ReZero-EfficientZero policy. + """ + + def __init__(self, cfg: dict): + """ + Overview: + Initialize the ReZeroEZGameBuffer with the given configuration. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + """ + super().__init__(cfg) + + # Update the default configuration with the provided configuration. + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + + # Ensure the configuration values are valid. + assert self._cfg.env_type in ['not_board_games', 'board_games'] + assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space'] + + # Initialize various parameters from the configuration. + self.replay_buffer_size = self._cfg.replay_buffer_size + self.batch_size = self._cfg.batch_size + self._alpha = self._cfg.priority_prob_alpha + self._beta = self._cfg.priority_prob_beta + + self.keep_ratio = 1 + self.model_update_interval = 10 + self.num_of_collected_episodes = 0 + self.base_idx = 0 + self.clear_time = 0 + + self.game_segment_buffer = [] + self.game_pos_priorities = [] + self.game_segment_game_pos_look_up = [] + + # Timers for performance monitoring + self._compute_target_timer = EasyTimer() + self._reuse_search_timer = EasyTimer() + self._origin_search_timer = EasyTimer() + self.buffer_reanalyze = True + + # Performance metrics + self.compute_target_re_time = 0 + self.reuse_search_time = 0 + self.origin_search_time = 0 + self.sample_times = 0 + self.active_root_num = 0 + self.average_infer = 0 + + def sample( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] + ) -> List[Any]: + """ + Overview: + Sample data from the GameBuffer and prepare the current and target batch for training. + Arguments: + - batch_size (int): Batch size. + - policy (Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]): Policy. + Returns: + - train_data (List): List of train data, including current_batch and target_batch. 
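+
+        .. note::
+            The reanalyzed policy computation is wrapped in ``self._compute_target_timer`` so that
+            the accumulated reanalyze time can later be reported by ``log_buffer_run_time``.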
+ """ + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + + # Obtain the current_batch and prepare target context + reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( + batch_size, self._cfg.reanalyze_ratio + ) + + # Compute target values and policies + batch_value_prefixs, batch_target_values = self._compute_target_reward_value( + reward_value_context, policy._target_model + ) + + with self._compute_target_timer: + batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) + self.compute_target_re_time += self._compute_target_timer.value + + batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( + policy_non_re_context, self._cfg.model.action_space_size + ) + + # Fuse reanalyzed and non-reanalyzed target policies + if 0 < self._cfg.reanalyze_ratio < 1: + batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) + elif self._cfg.reanalyze_ratio == 1: + batch_target_policies = batch_target_policies_re + elif self._cfg.reanalyze_ratio == 0: + batch_target_policies = batch_target_policies_non_re + + target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies] + + # A batch contains the current_batch and the target_batch + train_data = [current_batch, target_batch] + if not self.buffer_reanalyze: + self.sample_times += 1 + return train_data + + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, length=None) -> np.ndarray: + """ + Overview: + Prepare policy targets from the reanalyzed context of policies. + Arguments: + - policy_re_context (List): List of policy context to be reanalyzed. + - model (Any): The model used for inference. + - length (int, optional): The length of unroll steps. + Returns: + - batch_target_policies_re (np.ndarray): The reanalyzed policy targets. 
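+
+        .. note::
+            When ``length`` is provided, the number of unroll steps becomes ``length - 1``;
+            otherwise ``self._cfg.num_unroll_steps`` is used.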
+ """ + if policy_re_context is None: + return [] + + batch_target_policies_re = [] + + unroll_steps = length - 1 if length is not None else self._cfg.num_unroll_steps + + policy_obs_list, true_action, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, root_values, game_segment_lens, action_mask_segment, to_play_segment = policy_re_context + + transition_batch_size = len(policy_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, length + ) + + if self._cfg.model.continuous_action_space: + action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + with torch.no_grad(): + policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device) + m_output = model.initial_inference(m_obs) + + if not model.training: + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + m_output.reward_hidden_state = ( + m_output.reward_hidden_state[0].detach().cpu().numpy(), + m_output.reward_hidden_state[1].detach().cpu().numpy() + ) + + network_output.append(m_output) + + _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( + network_output, data_type='efficientzero' + ) + value_prefix_pool = value_prefix_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size).astype( + np.float32).tolist() + for _ in range(transition_batch_size) + ] + + if self._cfg.mcts_ctree: + legal_actions_by_iter = compute_all_filters(legal_actions, unroll_steps) + noises_by_iter = compute_all_filters(noises, unroll_steps) + value_prefix_pool_by_iter = compute_all_filters(value_prefix_pool, unroll_steps) + policy_logits_pool_by_iter = compute_all_filters(policy_logits_pool, unroll_steps) + to_play_by_iter = compute_all_filters(to_play, unroll_steps) + latent_state_roots_by_iter = compute_all_filters(latent_state_roots, unroll_steps) + + batch1_core_by_iter = compute_all_filters(reward_hidden_state_roots[0][0], unroll_steps) + batch2_core_by_iter = compute_all_filters(reward_hidden_state_roots[1][0], unroll_steps) + true_action_by_iter = compute_all_filters(true_action, unroll_steps) + + temp_values = [] + temp_distributions = [] + mcts_ctree = MCTSCtree(self._cfg) + temp_search_time = 0 + temp_length = 0 + temp_infer = 0 + + if self._cfg.reuse_search: + for iter in range(unroll_steps + 1): + iter_batch_size = transition_batch_size / (unroll_steps + 1) + roots = MCTSCtree.roots(iter_batch_size, legal_actions_by_iter[iter]) + if self._cfg.reanalyze_noise: + 
roots.prepare(self._cfg.root_noise_weight, noises_by_iter[iter], + value_prefix_pool_by_iter[iter], policy_logits_pool_by_iter[iter], + to_play_by_iter[iter]) + else: + roots.prepare_no_noise(value_prefix_pool_by_iter[iter], policy_logits_pool_by_iter[iter], + to_play_by_iter[iter]) + + if iter == 0: + with self._origin_search_timer: + mcts_ctree.search(roots, model, latent_state_roots_by_iter[iter], + [[batch1_core_by_iter[iter]], [batch2_core_by_iter[iter]]], + to_play_by_iter[iter]) + self.origin_search_time += self._origin_search_timer.value + else: + with self._reuse_search_timer: + # ===================== Core implementation of ReZero: search_with_reuse ===================== + length, average_infer = mcts_ctree.search_with_reuse(roots, model, + latent_state_roots_by_iter[iter], + [[batch1_core_by_iter[iter]], + [batch2_core_by_iter[iter]]], + to_play_by_iter[iter], + true_action_list= + true_action_by_iter[iter], + reuse_value_list=iter_values) + temp_search_time += self._reuse_search_timer.value + temp_length += length + temp_infer += average_infer + + iter_values = roots.get_values() + iter_distributions = roots.get_distributions() + temp_values.append(iter_values) + temp_distributions.append(iter_distributions) + + else: + for iter in range(unroll_steps + 1): + iter_batch_size = transition_batch_size / (unroll_steps + 1) + roots = MCTSCtree.roots(iter_batch_size, legal_actions_by_iter[iter]) + if self._cfg.reanalyze_noise: + roots.prepare(self._cfg.root_noise_weight, noises_by_iter[iter], + value_prefix_pool_by_iter[iter], policy_logits_pool_by_iter[iter], + to_play_by_iter[iter]) + else: + roots.prepare_no_noise(value_prefix_pool_by_iter[iter], policy_logits_pool_by_iter[iter], + to_play_by_iter[iter]) + + with self._origin_search_timer: + mcts_ctree.search(roots, model, latent_state_roots_by_iter[iter], + [[batch1_core_by_iter[iter]], [batch2_core_by_iter[iter]]], + to_play_by_iter[iter]) + self.origin_search_time += self._origin_search_timer.value + + iter_values = roots.get_values() + iter_distributions = roots.get_distributions() + temp_values.append(iter_values) + temp_distributions.append(iter_distributions) + + self.origin_search_time = self.origin_search_time / (unroll_steps + 1) + + if unroll_steps == 0: + self.reuse_search_time = 0 + self.active_root_num = 0 + else: + self.reuse_search_time += (temp_search_time / unroll_steps) + self.active_root_num += (temp_length / unroll_steps) + self.average_infer += (temp_infer / unroll_steps) + + roots_legal_actions_list = legal_actions + temp_values.reverse() + temp_distributions.reverse() + roots_values = [] + roots_distributions = [] + [roots_values.extend(column) for column in zip(*temp_values)] + [roots_distributions.extend(column) for column in zip(*temp_distributions)] + + policy_index = 0 + for state_index, child_visit, root_value in zip(pos_in_game_segment_list, child_visits, root_values): + target_policies = [] + + for current_index in range(state_index, state_index + unroll_steps + 1): + distributions = roots_distributions[policy_index] + searched_value = roots_values[policy_index] + + if policy_mask[policy_index] == 0: + target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + else: + if distributions is None: + target_policies.append(list( + np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)) + else: + sim_num = sum(distributions) + child_visit[current_index] = [visit_count / sim_num for visit_count in distributions] + root_value[current_index] = searched_value + if 
self._cfg.action_type == 'fixed_action_space': + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + target_policies.append(policy) + else: + policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): + policy_tmp[legal_action] = policy[index] + target_policies.append(policy_tmp) + + policy_index += 1 + + batch_target_policies_re.append(target_policies) + + return np.array(batch_target_policies_re) diff --git a/lzero/mcts/buffer/game_buffer_rezero_mz.py b/lzero/mcts/buffer/game_buffer_rezero_mz.py new file mode 100644 index 000000000..9e864ac5e --- /dev/null +++ b/lzero/mcts/buffer/game_buffer_rezero_mz.py @@ -0,0 +1,401 @@ +from typing import Any, List, Tuple, Union, TYPE_CHECKING + +import numpy as np +import torch +from ding.torch_utils.data_helper import to_list +from ding.utils import BUFFER_REGISTRY, EasyTimer + +from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree +from lzero.mcts.utils import prepare_observation +from lzero.policy import to_detach_cpu_numpy, concat_output, inverse_scalar_transform +from .game_buffer_muzero import MuZeroGameBuffer + +# from line_profiler import line_profiler +if TYPE_CHECKING: + from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy + + +def compute_all_filters(data, num_unroll_steps): + data_by_iter = [] + for iter in range(num_unroll_steps + 1): + iter_data = [x for i, x in enumerate(data) + if (i + 1) % (num_unroll_steps + 1) == + ((num_unroll_steps + 1 - iter) % (num_unroll_steps + 1))] + data_by_iter.append(iter_data) + return data_by_iter + + +@BUFFER_REGISTRY.register('game_buffer_rezero_mz') +class ReZeroMZGameBuffer(MuZeroGameBuffer): + """ + Overview: + The specific game buffer for ReZero-MuZero policy. + """ + + def __init__(self, cfg: dict): + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. 
+ """ + super().__init__(cfg) + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + + # Ensure valid configuration values + assert self._cfg.env_type in ['not_board_games', 'board_games'] + assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space'] + + # Initialize buffer parameters + self.replay_buffer_size = self._cfg.replay_buffer_size + self.batch_size = self._cfg.batch_size + self._alpha = self._cfg.priority_prob_alpha + self._beta = self._cfg.priority_prob_beta + + self.keep_ratio = 1 + self.model_update_interval = 10 + self.num_of_collected_episodes = 0 + self.base_idx = 0 + self.clear_time = 0 + + self.game_segment_buffer = [] + self.game_pos_priorities = [] + self.game_segment_game_pos_look_up = [] + + self._compute_target_timer = EasyTimer() + self._reuse_search_timer = EasyTimer() + self._origin_search_timer = EasyTimer() + self.buffer_reanalyze = True + self.compute_target_re_time = 0 + self.reuse_search_time = 0 + self.origin_search_time = 0 + self.sample_times = 0 + self.active_root_num = 0 + self.average_infer = 0 + + def reanalyze_buffer( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] + ) -> List[Any]: + """ + Overview: + Sample data from ``GameBuffer`` and prepare the current and target batch for training. + Arguments: + - batch_size (:obj:`int`): Batch size. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): Policy. + Returns: + - train_data (:obj:`List`): List of train data, including current_batch and target_batch. + """ + assert self._cfg.mcts_ctree is True, "ReZero-MuZero only supports cpp mcts_ctree now!" + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + + # Obtain the current batch and prepare target context + policy_re_context = self._make_reanalyze_batch(batch_size) + + with self._compute_target_timer: + segment_length = self.get_num_of_transitions() // 2000 + batch_target_policies_re = self._compute_target_policy_reanalyzed( + policy_re_context, policy._target_model, segment_length + ) + + self.compute_target_re_time += self._compute_target_timer.value + + if self.buffer_reanalyze: + self.sample_times += 1 + + def _make_reanalyze_batch(self, batch_size: int) -> Tuple[Any]: + """ + Overview: + First sample orig_data through ``_sample_orig_data()``, then prepare the context of a batch: + reward_value_context: The context of reanalyzed value targets. + policy_re_context: The context of reanalyzed policy targets. + policy_non_re_context: The context of non-reanalyzed policy targets. + current_batch: The inputs of batch. + Arguments: + - batch_size (:obj:`int`): The batch size of orig_data from replay buffer. 
+ Returns: + - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch + """ + # Obtain the batch context from replay buffer + orig_data = self._sample_orig_reanalyze_data(batch_size) + game_segment_list, pos_in_game_segment_list, batch_index_list, _, make_time_list = orig_data + segment_length = self.get_num_of_transitions() // 2000 + policy_re_context = self._prepare_policy_reanalyzed_context( + [], game_segment_list, pos_in_game_segment_list, segment_length + ) + return policy_re_context + + def _prepare_policy_reanalyzed_context( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str], + length=None + ) -> List[Any]: + """ + Overview: + Prepare the context of policies for calculating policy target in reanalyzing part. + Arguments: + - batch_index_list (:obj:`list`): Start transition index in the replay buffer. + - game_segment_list (:obj:`list`): List of game segments. + - pos_in_game_segment_list (:obj:`list`): Position of transition index in one game history. + - length (:obj:`int`, optional): Length of segments. + Returns: + - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices, + child_visits, game_segment_lens, action_mask_segment, to_play_segment + """ + zero_obs = game_segment_list[0].zero_obs() + with torch.no_grad(): + policy_obs_list = [] + true_action = [] + policy_mask = [] + + unroll_steps = length - 1 if length is not None else self._cfg.num_unroll_steps + rewards, child_visits, game_segment_lens, root_values = [], [], [], [] + action_mask_segment, to_play_segment = [], [] + + for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + rewards.append(game_segment.reward_segment) + action_mask_segment.append(game_segment.action_mask_segment) + to_play_segment.append(game_segment.to_play_segment) + child_visits.append(game_segment.child_visit_segment) + root_values.append(game_segment.root_value_segment) + + # Prepare the corresponding observations + game_obs = game_segment.get_unroll_obs(state_index, unroll_steps) + for current_index in range(state_index, state_index + unroll_steps + 1): + if current_index < game_segment_len: + policy_mask.append(1) + beg_index = current_index - state_index + end_index = beg_index + self._cfg.model.frame_stack_num + obs = game_obs[beg_index:end_index] + action = game_segment.action_segment[current_index] + if current_index == game_segment_len - 1: + action = -2 # use the illegal negative action to represent the gi + else: + policy_mask.append(0) + obs = zero_obs + action = -2 # use the illegal negative action to represent the padding action + policy_obs_list.append(obs) + true_action.append(action) + + policy_re_context = [ + policy_obs_list, true_action, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, + root_values, game_segment_lens, action_mask_segment, to_play_segment + ] + return policy_re_context + + # @profile + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, length=None) -> np.ndarray: + """ + Overview: + Prepare policy targets from the reanalyzed context of policies. + Arguments: + - policy_re_context (:obj:`List`): List of policy context to be reanalyzed. + - model (:obj:`Any`): The model used for inference. + - length (:obj:`int`, optional): The length of the unroll steps. 
+ Returns: + - batch_target_policies_re (:obj:`np.ndarray`): The reanalyzed batch target policies. + """ + if policy_re_context is None: + return [] + + batch_target_policies_re = [] + + unroll_steps = length - 1 if length is not None else self._cfg.num_unroll_steps + + # Unpack the policy reanalyze context + ( + policy_obs_list, true_action, policy_mask, pos_in_game_segment_list, batch_index_list, + child_visits, root_values, game_segment_lens, action_mask_segment, to_play_segment + ) = policy_re_context + + transition_batch_size = len(policy_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps + ) + + if self._cfg.model.continuous_action_space: + action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + with torch.no_grad(): + policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device) + m_output = model.initial_inference(m_obs) + + if not model.training: + m_output.latent_state, m_output.value, m_output.policy_logits = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + + network_output.append(m_output) + + _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') + reward_pool = reward_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size).astype( + np.float32).tolist() + for _ in range(transition_batch_size) + ] + + if self._cfg.mcts_ctree: + legal_actions_by_iter = compute_all_filters(legal_actions, unroll_steps) + noises_by_iter = compute_all_filters(noises, unroll_steps) + reward_pool_by_iter = compute_all_filters(reward_pool, unroll_steps) + policy_logits_pool_by_iter = compute_all_filters(policy_logits_pool, unroll_steps) + to_play_by_iter = compute_all_filters(to_play, unroll_steps) + latent_state_roots_by_iter = compute_all_filters(latent_state_roots, unroll_steps) + true_action_by_iter = compute_all_filters(true_action, unroll_steps) + + temp_values, temp_distributions = [], [] + mcts_ctree = MCTSCtree(self._cfg) + temp_search_time, temp_length, temp_infer = 0, 0, 0 + + if self._cfg.reuse_search: + for iter in range(unroll_steps + 1): + iter_batch_size = transition_batch_size / (unroll_steps + 1) + roots = MCTSCtree.roots(iter_batch_size, legal_actions_by_iter[iter]) + + if self._cfg.reanalyze_noise: + roots.prepare( + self._cfg.root_noise_weight, + noises_by_iter[iter], + reward_pool_by_iter[iter], + policy_logits_pool_by_iter[iter], + to_play_by_iter[iter] + ) + else: + roots.prepare_no_noise( + reward_pool_by_iter[iter], + policy_logits_pool_by_iter[iter], + to_play_by_iter[iter] + ) + + if 
iter == 0: + with self._origin_search_timer: + mcts_ctree.search(roots, model, latent_state_roots_by_iter[iter], to_play_by_iter[iter]) + self.origin_search_time += self._origin_search_timer.value + else: + with self._reuse_search_timer: + # ===================== Core implementation of ReZero: search_with_reuse ===================== + length, average_infer = mcts_ctree.search_with_reuse( + roots, model, latent_state_roots_by_iter[iter], to_play_by_iter[iter], + true_action_list=true_action_by_iter[iter], reuse_value_list=iter_values + ) + temp_search_time += self._reuse_search_timer.value + temp_length += length + temp_infer += average_infer + + iter_values = roots.get_values() + iter_distributions = roots.get_distributions() + temp_values.append(iter_values) + temp_distributions.append(iter_distributions) + else: + for iter in range(unroll_steps + 1): + iter_batch_size = transition_batch_size / (unroll_steps + 1) + roots = MCTSCtree.roots(iter_batch_size, legal_actions_by_iter[iter]) + + if self._cfg.reanalyze_noise: + roots.prepare( + self._cfg.root_noise_weight, + noises_by_iter[iter], + reward_pool_by_iter[iter], + policy_logits_pool_by_iter[iter], + to_play_by_iter[iter] + ) + else: + roots.prepare_no_noise( + reward_pool_by_iter[iter], + policy_logits_pool_by_iter[iter], + to_play_by_iter[iter] + ) + + with self._origin_search_timer: + mcts_ctree.search(roots, model, latent_state_roots_by_iter[iter], to_play_by_iter[iter]) + self.origin_search_time += self._origin_search_timer.value + + iter_values = roots.get_values() + iter_distributions = roots.get_distributions() + temp_values.append(iter_values) + temp_distributions.append(iter_distributions) + + self.origin_search_time /= (unroll_steps + 1) + + if unroll_steps == 0: + self.reuse_search_time, self.active_root_num = 0, 0 + else: + self.reuse_search_time += (temp_search_time / unroll_steps) + self.active_root_num += (temp_length / unroll_steps) + self.average_infer += (temp_infer / unroll_steps) + + roots_legal_actions_list = legal_actions + temp_values.reverse() + temp_distributions.reverse() + roots_values, roots_distributions = [], [] + [roots_values.extend(column) for column in zip(*temp_values)] + [roots_distributions.extend(column) for column in zip(*temp_distributions)] + + policy_index = 0 + for state_index, child_visit, root_value in zip(pos_in_game_segment_list, child_visits, root_values): + target_policies = [] + + for current_index in range(state_index, state_index + unroll_steps + 1): + distributions = roots_distributions[policy_index] + searched_value = roots_values[policy_index] + + if policy_mask[policy_index] == 0: + target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + else: + if distributions is None: + target_policies.append( + list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + ) + else: + # ===================== Update the data in buffer ===================== + # After the reanalysis search, new target policies and root values are obtained. + # These target policies and root values are stored in the game segment, specifically in the ``child_visit_segment`` and ``root_value_segment``. + # We replace the data at the corresponding locations with the latest search results to maintain the most up-to-date targets. 
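+                            # Illustration with hypothetical numbers: if ``distributions`` is [20, 5, 25]
+                            # (50 simulations in total), the stored policy target becomes [0.4, 0.1, 0.5],
+                            # and ``root_value[current_index]`` is overwritten with the root value returned
+                            # by the latest reuse-aware search.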
+ sim_num = sum(distributions) + child_visit[current_index] = [visit_count / sim_num for visit_count in distributions] + root_value[current_index] = searched_value + + if self._cfg.action_type == 'fixed_action_space': + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + target_policies.append(policy) + else: + policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): + policy_tmp[legal_action] = policy[index] + target_policies.append(policy_tmp) + + policy_index += 1 + + batch_target_policies_re.append(target_policies) + + return np.array(batch_target_policies_re) + diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index da889d98f..c50d93e3d 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -210,12 +210,15 @@ def store_search_stats( store the visit count distributions and value of the root node after MCTS. """ sum_visits = sum(visit_counts) + if sum_visits == 0: + # if the sum of visit counts is 0, set it to a small value to avoid division by zero + sum_visits = 1e-6 if idx is None: self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts]) self.root_value_segment.append(root_value) if self.sampled_algo: self.root_sampled_actions.append(root_sampled_actions) - # store the improved policy in Gumbel Muzero: \pi'=softmax(logits + \sigma(CompletedQ)) + # store the improved policy in Gumbel MuZero: \pi'=softmax(logits + \sigma(CompletedQ)) if self.gumbel_algo: self.improved_policy_probs.append(improved_policy) else: diff --git a/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd b/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd index 915199349..1023150aa 100644 --- a/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd +++ b/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd @@ -79,9 +79,17 @@ cdef extern from "lib/cnode.h" namespace "tree": vector[float] values, vector[vector[float]] policies, CMinMaxStatsList *min_max_stats_lst, CSearchResults & results, vector[int] is_reset_list, vector[int] & to_play_batch) + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, + vector[float] values, vector[vector[float]] policies, + CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + vector[int] is_reset_list, vector[int] &to_play_batch, vector[int] &no_inference_lst, + vector[int] &reuse_lst, vector[float] &reuse_value_lst) void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, CMinMaxStatsList *min_max_stats_lst, CSearchResults & results, vector[int] & virtual_to_play_batch) + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, + CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + vector[int] &virtual_to_play_batch, vector[int] &true_action, vector[float] &reuse_value) cdef class MinMaxStatsList: cdef CMinMaxStatsList *cmin_max_stats_lst diff --git a/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx b/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx index 8149f8569..b7d28c1fd 100644 --- a/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx +++ b/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx @@ -91,6 +91,19 @@ def batch_backpropagate(int current_latent_state_index, float discount_factor, l 
cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies, min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch) +@cython.binding +def batch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies, + MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list is_reset_list, + list to_play_batch, list no_inference_lst, list reuse_lst, list reuse_value_lst): + cdef int i + cdef vector[float] cvalue_prefixs = value_prefixs + cdef vector[float] cvalues = values + cdef vector[vector[float]] cpolicies = policies + cdef vector[float] creuse_value_lst = reuse_value_lst + + cbatch_backpropagate_with_reuse(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies, + min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch, no_inference_lst, reuse_lst, creuse_value_lst) + @cython.binding def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list virtual_to_play_batch): @@ -98,3 +111,11 @@ def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount_f results.cresults, virtual_to_play_batch) return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs + +@cython.binding +def batch_traverse_with_reuse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst, + ResultsWrapper results, list virtual_to_play_batch, list true_action, list reuse_value): + cbatch_traverse_with_reuse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults, + virtual_to_play_batch, true_action, reuse_value) + + return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs diff --git a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp index 606744e0a..8a94a7ca9 100644 --- a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp @@ -600,6 +600,54 @@ namespace tree } } + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch, std::vector &no_inference_lst, std::vector &reuse_lst, std::vector &reuse_value_lst) + { + /* + Overview: + Expand the nodes along the search path and update the infos. + Arguments: + - current_latent_state_index: The index of latent state of the leaf node in the search path. + - discount_factor: the discount factor of reward. + - value_prefixs: the value prefixs of nodes along the search path. + - values: the values to propagate along the search path. + - policies: the policy logits of nodes along the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - results: the search results. + - to_play_batch: the batch of which player is playing on this node. + - no_inference_lst: the list of the nodes which does not need to expand. 
+ - reuse_lst: the list of the nodes which should use reuse-value to backpropagate. + - reuse_value_lst: the list of the reuse-value. + */ + int count_a = 0; + int count_b = 0; + int count_c = 0; + float value_propagate = 0; + for (int i = 0; i < results.num; ++i) + { + if (i == no_inference_lst[count_a]) + { + count_a = count_a + 1; + value_propagate = reuse_value_lst[i]; + } + else + { + results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, count_b, value_prefixs[count_b], policies[count_b]); + if (i == reuse_lst[count_c]) + { + value_propagate = reuse_value_lst[i]; + count_c = count_c + 1; + } + else + { + value_propagate = values[count_b]; + } + count_b = count_b + 1; + } + results.nodes[i]->is_reset = is_reset_list[i]; + cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], value_propagate, discount_factor); + } + } + int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players) { /* @@ -646,6 +694,65 @@ namespace tree return action; } + int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value) + { + /* + Overview: + Select the child node of the roots according to ucb scores. + Arguments: + - root: the roots to select the child node. + - min_max_stats: a tool used to min-max normalize the score. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - mean_q: the mean q value of the parent node. + - players: the number of players. + - true_action: the action chosen in the trajectory. + - reuse_value: the value obtained from the search of the next state in the trajectory. + Returns: + - action: the action to select. + */ + + float max_score = FLOAT_MIN; + const float epsilon = 0.000001; + std::vector max_index_lst; + for (auto a : root->legal_actions) + { + + CNode *child = root->get_child(a); + float temp_score = 0.0; + if (a == true_action) + { + temp_score = carm_score(child, min_max_stats, mean_q, root->is_reset, reuse_value, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players); + } + else + { + temp_score = cucb_score(child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players); + } + + if (max_score < temp_score) + { + max_score = temp_score; + + max_index_lst.clear(); + max_index_lst.push_back(a); + } + else if (temp_score >= max_score - epsilon) + { + max_index_lst.push_back(a); + } + } + + int action = 0; + if (max_index_lst.size() > 0) + { + int rand_index = rand() % max_index_lst.size(); + action = max_index_lst[rand_index]; + } + // printf("select root child ends"); + return action; + } + float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players) { /* @@ -706,6 +813,76 @@ namespace tree return prior_score + value_score; // ucb_value } + float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float reuse_value, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players) + { + /* + Overview: + Compute the ucb score of the child. 
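+                Unlike cucb_score, once the child has been visited its value term is built from reuse_value
+                (the root value already computed for the next state of the trajectory), discounted and, in
+                two-player games, negated for the opponent; the prior bonus is then dropped and only the
+                normalized value is returned.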
+ Arguments: + - child: the child node to compute ucb score. + - min_max_stats: a tool used to min-max normalize the score. + - parent_mean_q: the mean q value of the parent node. + - is_reset: whether the value prefix needs to be reset. + - total_children_visit_counts: the total visit counts of the child nodes of the parent node. + - parent_value_prefix: the value prefix of parent node. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - players: the number of players. + Returns: + - ucb_value: the ucb score of the child. + */ + float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; + pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init; + pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1)); + + prior_score = pb_c * child->prior; + if (child->visit_count == 0) + { + value_score = parent_mean_q; + } + else + { + float true_reward = child->value_prefix - parent_value_prefix; + if (is_reset == 1) + { + true_reward = child->value_prefix; + } + + if (players == 1) + { + value_score = true_reward + discount_factor * reuse_value; + } + else if (players == 2) + { + value_score = true_reward + discount_factor * (-reuse_value); + } + } + + value_score = min_max_stats.normalize(value_score); + + if (value_score < 0) + { + value_score = 0; + } + else if (value_score > 1) + { + value_score = 1; + } + + float ucb_value = 0.0; + if (child->visit_count == 0) + { + ucb_value = prior_score + value_score; + } + else + { + ucb_value = value_score; + } + // printf("carmscore ends"); + return ucb_value; + } + void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch) { /* @@ -784,4 +961,113 @@ namespace tree results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); } } + + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch, std::vector &true_action, std::vector &reuse_value) + { + /* + Overview: + Search node path from the roots. + Arguments: + - roots: the roots that search from. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - min_max_stats: a tool used to min-max normalize the score. + - results: the search results. + - virtual_to_play_batch: the batch of which player is playing on this node. + - true_action: the action chosen in the trajectory. + - reuse_value: the value obtained from the search of the next state in the trajectory. 
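+            Note:
+                When the root selects true_action, the traversal stops immediately (see the break below).
+                If that child is already expanded, the search path is reported with latent state index -1;
+                the Python side collects such roots into no_inference_lst, so no new model inference is run
+                for them and the stored reuse value is backpropagated instead.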
+ */ + // set seed + get_time_and_set_rand_seed(); + + int last_action = -1; + float parent_q = 0.0; + results.search_lens = std::vector(); + + int players = 0; + int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // 0 or 2 + if (largest_element == -1) + { + players = 1; + } + else + { + players = 2; + } + + for (int i = 0; i < results.num; ++i) + { + CNode *node = &(roots->roots[i]); + int is_root = 1; + int search_len = 0; + results.search_paths[i].push_back(node); + + while (node->expanded()) + { + float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor); + parent_q = mean_q; + + int action = 0; + if (is_root) + { + action = cselect_root_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players, true_action[i], reuse_value[i]); + } + else + { + action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players); + } + + if (players > 1) + { + assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); + if (virtual_to_play_batch[i] == 1) + { + virtual_to_play_batch[i] = 2; + } + else + { + virtual_to_play_batch[i] = 1; + } + } + + node->best_action = action; + // next + node = node->get_child(action); + last_action = action; + results.search_paths[i].push_back(node); + search_len += 1; + + if(is_root && action == true_action[i]) + { + break; + } + + is_root = 0; + } + + if (node->expanded()) + { + results.latent_state_index_in_search_path.push_back(-1); + results.latent_state_index_in_batch.push_back(i); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + else + { + CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; + + results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); + results.latent_state_index_in_batch.push_back(parent->batch_index); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + } + } } \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h index 52b6e6dfa..bca179151 100644 --- a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h +++ b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h @@ -83,9 +83,13 @@ namespace tree { void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players); void cbackpropagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor); void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch); + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch, std::vector &no_inference_lst, std::vector &reuse_lst, std::vector &reuse_value_lst); int cselect_child(CNode* root, 
tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players); + int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value); float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players); + float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float reuse_value, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players); void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch); + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch, std::vector &true_action, std::vector &reuse_value); } #endif \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp index 0fe9d8018..24eb3605c 100644 --- a/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp @@ -499,6 +499,55 @@ namespace tree } } + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &to_play_batch, std::vector &no_inference_lst, std::vector &reuse_lst, std::vector &reuse_value_lst) + { + /* + Overview: + Expand the nodes along the search path and update the infos. Details are similar to cbatch_backpropagate, but with reuse value. + Please refer to https://arxiv.org/abs/2404.16364 for more details. + Arguments: + - current_latent_state_index: The index of latent state of the leaf node in the search path. + - discount_factor: the discount factor of reward. + - value_prefixs: the value prefixs of nodes along the search path. + - values: the values to propagate along the search path. + - policies: the policy logits of nodes along the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - results: the search results. + - to_play_batch: the batch of which player is playing on this node. + - no_inference_lst: the list of the nodes which does not need to expand. + - reuse_lst: the list of the nodes which should use reuse-value to backpropagate. + - reuse_value_lst: the list of the reuse-value. 
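+            Note:
+                no_inference_lst and reuse_lst are expected to be sorted and terminated by a -1 sentinel
+                (appended on the Python side), so the single pass over results.num can advance the counters
+                safely. reuse_value_lst is indexed by the batch position i, while values and policies are
+                indexed only over the positions that actually ran a model inference (count_b).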
+ */ + int count_a = 0; + int count_b = 0; + int count_c = 0; + float value_propagate = 0; + for (int i = 0; i < results.num; ++i) + { + if (i == no_inference_lst[count_a]) + { + count_a = count_a + 1; + value_propagate = reuse_value_lst[i]; + } + else + { + results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, count_b, value_prefixs[count_b], policies[count_b]); + if (i == reuse_lst[count_c]) + { + value_propagate = reuse_value_lst[i]; + count_c = count_c + 1; + } + else + { + value_propagate = values[count_b]; + } + count_b = count_b + 1; + } + + cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], value_propagate, discount_factor); + } + } + int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players) { /* @@ -546,6 +595,63 @@ namespace tree return action; } + int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value) + { + /* + Overview: + Select the child node of the roots according to ucb scores. + Arguments: + - root: the roots to select the child node. + - min_max_stats: a tool used to min-max normalize the score. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - discount_factor: the discount factor of reward. + - mean_q: the mean q value of the parent node. + - players: the number of players. + - true_action: the action chosen in the trajectory. + - reuse_value: the value obtained from the search of the next state in the trajectory. + Returns: + - action: the action to select. + */ + float max_score = FLOAT_MIN; + const float epsilon = 0.000001; + std::vector max_index_lst; + for (auto a : root->legal_actions) + { + + CNode *child = root->get_child(a); + float temp_score = 0.0; + if (a == true_action) + { + temp_score = carm_score(child, min_max_stats, mean_q, reuse_value, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players); + } + else + { + temp_score = cucb_score(child, min_max_stats, mean_q, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players); + } + + if (max_score < temp_score) + { + max_score = temp_score; + + max_index_lst.clear(); + max_index_lst.push_back(a); + } + else if (temp_score >= max_score - epsilon) + { + max_index_lst.push_back(a); + } + } + + int action = 0; + if (max_index_lst.size() > 0) + { + int rand_index = rand() % max_index_lst.size(); + action = max_index_lst[rand_index]; + } + return action; + } + float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players) { /* @@ -592,6 +698,60 @@ namespace tree return ucb_value; } + + float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float reuse_value, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players) + { + /* + Overview: + Compute the ucb score of the child. + Arguments: + - child: the child node to compute ucb score. + - min_max_stats: a tool used to min-max normalize the score. + - mean_q: the mean q value of the parent node. + - total_children_visit_counts: the total visit counts of the child nodes of the parent node. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. 
+ - players: the number of players. + Returns: + - ucb_value: the ucb score of the child. + */ + float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; + pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init; + pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1)); + + prior_score = pb_c * child->prior; + if (child->visit_count == 0) + { + value_score = parent_mean_q; + } + else + { + float true_reward = child->reward; + if (players == 1) + value_score = true_reward + discount_factor * reuse_value; + else if (players == 2) + value_score = true_reward + discount_factor * (-reuse_value); + } + + value_score = min_max_stats.normalize(value_score); + + if (value_score < 0) + value_score = 0; + if (value_score > 1) + value_score = 1; + float ucb_value = 0.0; + if (child->visit_count == 0) + { + ucb_value = prior_score + value_score; + } + else + { + ucb_value = value_score; + } + return ucb_value; + } + void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch) { /* @@ -663,4 +823,108 @@ namespace tree } } + + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch, std::vector &true_action, std::vector &reuse_value) + { + /* + Overview: + Search node path from the roots. Details are similar to cbatch_traverse, but with reuse value. + Please refer to https://arxiv.org/abs/2404.16364 for more details. + Arguments: + - roots: the roots that search from. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - min_max_stats: a tool used to min-max normalize the score. + - results: the search results. + - virtual_to_play_batch: the batch of which player is playing on this node. + - true_action: the action chosen in the trajectory. + - reuse_value: the value obtained from the search of the next state in the trajectory. 
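To make the root-level selection above easier to follow, here is a minimal Python sketch of what cselect_root_child computes with carm_score and cucb_score, under simplifying assumptions: the single-player sign convention is used, and the min-max normalization and [0, 1] clipping of the value term are omitted. The function name and its parameters are illustrative stand-ins, not part of the library API.

import math

def rezero_root_child_score(prior, visit_count, reward, child_value, parent_mean_q,
                            reuse_value, is_true_action, total_visits,
                            pb_c_base=19652, pb_c_init=1.25, discount=0.997):
    # PUCT exploration bonus, shared by cucb_score and carm_score.
    pb_c = math.log((total_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= math.sqrt(total_visits) / (visit_count + 1)
    prior_score = pb_c * prior

    if visit_count == 0:
        # Both scoring functions fall back to the parent's mean Q for an unvisited child.
        return prior_score + parent_mean_q
    if is_true_action:
        # carm_score: the child on the recorded trajectory is valued with the reused root value
        # of the next state instead of its own backed-up value, and the prior bonus is dropped.
        return reward + discount * reuse_value
    # cucb_score: ordinary PUCT value term based on the child's own backed-up value.
    return prior_score + reward + discount * child_value

Because the trajectory action keeps the reused value regardless of its own visit statistics, the root tends to re-descend the recorded trajectory, which is what makes the early termination in the traversal below, and the corresponding inference savings, possible.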
+ */ + // set seed + get_time_and_set_rand_seed(); + + int last_action = -1; + float parent_q = 0.0; + results.search_lens = std::vector(); + + int players = 0; + int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // 0 or 2 + if (largest_element == -1) + players = 1; + else + players = 2; + + for (int i = 0; i < results.num; ++i) + { + CNode *node = &(roots->roots[i]); + int is_root = 1; + int search_len = 0; + results.search_paths[i].push_back(node); + + while (node->expanded()) + { + float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor); + parent_q = mean_q; + + int action = 0; + if (is_root) + { + action = cselect_root_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players, true_action[i], reuse_value[i]); + } + else + { + action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players); + } + + if (players > 1) + { + assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); + if (virtual_to_play_batch[i] == 1) + virtual_to_play_batch[i] = 2; + else + virtual_to_play_batch[i] = 1; + } + + node->best_action = action; + // next + node = node->get_child(action); + last_action = action; + results.search_paths[i].push_back(node); + search_len += 1; + + if(is_root && action == true_action[i]) + { + break; + } + + is_root = 0; + + } + + if (node->expanded()) + { + results.latent_state_index_in_search_path.push_back(-1); + results.latent_state_index_in_batch.push_back(i); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + else + { + CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; + + results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); + results.latent_state_index_in_batch.push_back(parent->batch_index); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + } + } + } \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_muzero/lib/cnode.h b/lzero/mcts/ctree/ctree_muzero/lib/cnode.h index 1c5fdea6e..c8f1c9091 100644 --- a/lzero/mcts/ctree/ctree_muzero/lib/cnode.h +++ b/lzero/mcts/ctree/ctree_muzero/lib/cnode.h @@ -83,9 +83,13 @@ namespace tree { void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players); void cbackpropagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor); void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector &rewards, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &to_play_batch); + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &to_play_batch, std::vector &no_inference_lst, std::vector &reuse_lst, std::vector &reuse_value_lst); int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players); + int 
cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value); float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players); + float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float reuse_value, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players); void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch); + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch, std::vector &true_action, std::vector &reuse_value); } #endif \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_muzero/mz_tree.pxd b/lzero/mcts/ctree/ctree_muzero/mz_tree.pxd index b46f86912..6c65f4466 100644 --- a/lzero/mcts/ctree/ctree_muzero/mz_tree.pxd +++ b/lzero/mcts/ctree/ctree_muzero/mz_tree.pxd @@ -70,4 +70,7 @@ cdef extern from "lib/cnode.h" namespace "tree": cdef void cbackpropagate(vector[CNode*] &search_path, CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor) void cbatch_backpropagate(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, vector[float] values, vector[vector[float]] policies, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &to_play_batch) + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, vector[float] values, vector[vector[float]] policies, + CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &to_play_batch, vector[int] &no_inference_lst, vector[int] &reuse_lst, vector[float] &reuse_value_lst) void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &virtual_to_play_batch) + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &virtual_to_play_batch, vector[int] &true_action, vector[float] &reuse_value) diff --git a/lzero/mcts/ctree/ctree_muzero/mz_tree.pyx b/lzero/mcts/ctree/ctree_muzero/mz_tree.pyx index 80e5c4505..506e69d57 100644 --- a/lzero/mcts/ctree/ctree_muzero/mz_tree.pyx +++ b/lzero/mcts/ctree/ctree_muzero/mz_tree.pyx @@ -81,9 +81,27 @@ def batch_backpropagate(int current_latent_state_index, float discount_factor, l cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies, min_max_stats_lst.cmin_max_stats_lst, results.cresults, to_play_batch) +def batch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies, + MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list to_play_batch, list no_inference_lst, list reuse_lst, list reuse_value_lst): + cdef int i + cdef vector[float] cvalue_prefixs = value_prefixs + cdef vector[float] cvalues = values + cdef vector[vector[float]] cpolicies = policies + cdef vector[float] creuse_value_lst = 
reuse_value_lst + + cbatch_backpropagate_with_reuse(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies, + min_max_stats_lst.cmin_max_stats_lst, results.cresults, to_play_batch, no_inference_lst, reuse_lst, creuse_value_lst) + def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list virtual_to_play_batch): cbatch_traverse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults, virtual_to_play_batch) return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs + +def batch_traverse_with_reuse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst, + ResultsWrapper results, list virtual_to_play_batch, list true_action, list reuse_value): + cbatch_traverse_with_reuse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults, + virtual_to_play_batch, true_action, reuse_value) + + return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs diff --git a/lzero/mcts/ptree/ptree_mz.py b/lzero/mcts/ptree/ptree_mz.py index 909d48e4f..afb6f7667 100644 --- a/lzero/mcts/ptree/ptree_mz.py +++ b/lzero/mcts/ptree/ptree_mz.py @@ -3,7 +3,7 @@ """ import math import random -from typing import List, Dict, Any, Tuple, Union +from typing import List, Any, Tuple, Union import numpy as np import torch @@ -19,6 +19,7 @@ class Node: ``__init__``, ``expand``, ``add_exploration_noise``, ``compute_mean_q``, ``get_trajectory``, \ ``get_children_distribution``, ``get_child``, ``expanded``, ``value``. 
""" + def __init__(self, prior: float, legal_actions: List = None, action_space_size: int = 9) -> None: """ Overview: @@ -412,7 +413,7 @@ def compute_ucb_score( value_score = 0 if value_score > 1: value_score = 1 - + ucb_score = prior_score + value_score return ucb_score diff --git a/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py b/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py index 8734298e3..ef5d79233 100644 --- a/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py +++ b/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py @@ -16,7 +16,7 @@ # ============================================================== tictactoe_muzero_config = dict( - exp_name='data_mz_ctree/tictactoe_muzero_bot_mode_seed0', + exp_name='data_muzero/tictactoe_muzero_bot_mode_seed0', env=dict( battle_mode='play_with_bot_mode', collector_env_num=collector_env_num, diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 76131d01a..e7543ab9d 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -6,10 +6,13 @@ from easydict import EasyDict from lzero.mcts.ctree.ctree_efficientzero import ez_tree as tree_efficientzero -from lzero.mcts.ctree.ctree_muzero import mz_tree as tree_muzero from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as tree_gumbel_muzero +from lzero.mcts.ctree.ctree_muzero import mz_tree as tree_muzero from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy +# from line_profiler import line_profiler +# import cProfile + if TYPE_CHECKING: from lzero.mcts.ctree.ctree_efficientzero import ez_tree as ez_ctree from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree @@ -186,6 +189,105 @@ def search( min_max_stats_lst, results, virtual_to_play_batch ) + def search_with_reuse( + self, + roots: Any, + model: torch.nn.Module, + latent_state_roots: List[Any], + to_play_batch: Union[int, List[Any]], + true_action_list=None, + reuse_value_list=None + ) -> None: + """ + Overview: + Perform Monte Carlo Tree Search (MCTS) for the root nodes in parallel. Utilizes the cpp ctree for efficiency. + Please refer to https://arxiv.org/abs/2404.16364 for more details. + Arguments: + - roots (:obj:`Any`): A batch of expanded root nodes. + - model (:obj:`torch.nn.Module`): The neural network model. + - latent_state_roots (:obj:`list`): The hidden states of the root nodes. + - to_play_batch (:obj:`Union[int, list]`): The list or batch indicator for players in self-play mode. + - true_action_list (:obj:`list`, optional): A list of true actions for reuse. + - reuse_value_list (:obj:`list`, optional): A list of values for reuse. 
+ """ + + with torch.no_grad(): + model.eval() + + # Initialize constants and variables + batch_size = roots.num + pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor + latent_state_batch_in_search_path = [latent_state_roots] + min_max_stats_lst = tree_muzero.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(self._cfg.value_delta_max) + infer_sum = 0 + + for simulation_index in range(self._cfg.num_simulations): + latent_states = [] + temp_actions = [] + no_inference_lst = [] + reuse_lst = [] + results = tree_muzero.ResultsWrapper(num=batch_size) + + # Selection phase: traverse the tree to select a leaf node + if self._cfg.env_type == 'not_board_games': + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_muzero.batch_traverse_with_reuse( + roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, + to_play_batch, true_action_list, reuse_value_list + ) + else: + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_muzero.batch_traverse_with_reuse( + roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, + copy.deepcopy(to_play_batch), true_action_list, reuse_value_list + ) + + # Collect latent states and actions for expansion + for count, (ix, iy) in enumerate(zip(latent_state_index_in_search_path, latent_state_index_in_batch)): + if ix != -1: + latent_states.append(latent_state_batch_in_search_path[ix][iy]) + temp_actions.append(last_actions[count]) + else: + no_inference_lst.append(iy) + if ix == 0 and last_actions[count] == true_action_list[count]: + reuse_lst.append(count) + + length = len(temp_actions) + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + temp_actions = torch.from_numpy(np.asarray(temp_actions)).to(self._cfg.device).long() + + # Expansion phase: expand the leaf node and evaluate the new node + if length != 0: + network_output = model.recurrent_inference(latent_states, temp_actions) + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) + network_output.value = to_detach_cpu_numpy( + self.inverse_scalar_transform_handle(network_output.value)) + network_output.reward = to_detach_cpu_numpy( + self.inverse_scalar_transform_handle(network_output.reward)) + + latent_state_batch_in_search_path.append(network_output.latent_state) + reward_batch = network_output.reward.reshape(-1).tolist() + value_batch = network_output.value.reshape(-1).tolist() + policy_logits_batch = network_output.policy_logits.tolist() + else: + latent_state_batch_in_search_path.append([]) + reward_batch = [] + value_batch = [] + policy_logits_batch = [] + + # Backup phase: propagate the evaluation results back through the tree + current_latent_state_index = simulation_index + 1 + no_inference_lst.append(-1) + reuse_lst.append(-1) + tree_muzero.batch_backpropagate_with_reuse( + current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch, + min_max_stats_lst, results, virtual_to_play_batch, no_inference_lst, reuse_lst, reuse_value_list + ) + infer_sum += length + + average_infer = infer_sum / self._cfg.num_simulations + return length, average_infer + class EfficientZeroMCTSCtree(object): """ @@ -288,6 +390,8 @@ def search( # the data storage of latent states: storing the latent state of all the nodes in one search. 
latent_state_batch_in_search_path = [latent_state_roots] # the data storage of value prefix hidden states in LSTM + # print(f"reward_hidden_state_roots[0]={reward_hidden_state_roots[0]}") + # print(f"reward_hidden_state_roots[1]={reward_hidden_state_roots[1]}") reward_hidden_state_c_batch = [reward_hidden_state_roots[0]] reward_hidden_state_h_batch = [reward_hidden_state_roots[1]] @@ -354,7 +458,8 @@ def search( network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) - network_output.value_prefix = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value_prefix)) + network_output.value_prefix = to_detach_cpu_numpy( + self.inverse_scalar_transform_handle(network_output.value_prefix)) network_output.reward_hidden_state = ( network_output.reward_hidden_state[0].detach().cpu().numpy(), @@ -390,6 +495,133 @@ def search( min_max_stats_lst, results, is_reset_list, virtual_to_play_batch ) + def search_with_reuse( + self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], + reward_hidden_state_roots: List[Any], to_play_batch: Union[int, List[Any]], + true_action_list=None, reuse_value_list=None + ) -> None: + """ + Perform Monte Carlo Tree Search (MCTS) for the root nodes in parallel, utilizing model inference in parallel. + This method uses the cpp ctree for efficiency. + Please refer to https://arxiv.org/abs/2404.16364 for more details. + + Args: + roots (Any): A batch of expanded root nodes. + model (torch.nn.Module): The model to use for inference. + latent_state_roots (List[Any]): The hidden states of the root nodes. + reward_hidden_state_roots (List[Any]): The value prefix hidden states in the LSTM of the roots. + to_play_batch (Union[int, List[Any]]): The to_play_batch list used in self-play-mode board games. + true_action_list (Optional[List[Any]]): List of true actions for reuse. + reuse_value_list (Optional[List[Any]]): List of values for reuse. 
+ + Returns: + None + """ + with torch.no_grad(): + model.eval() + + batch_size = roots.num + pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor + + latent_state_batch_in_search_path = [latent_state_roots] + reward_hidden_state_c_batch = [reward_hidden_state_roots[0]] + reward_hidden_state_h_batch = [reward_hidden_state_roots[1]] + + min_max_stats_lst = tree_efficientzero.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(self._cfg.value_delta_max) + + infer_sum = 0 + + for simulation_index in range(self._cfg.num_simulations): + latent_states, hidden_states_c_reward, hidden_states_h_reward = [], [], [] + temp_actions, temp_search_lens, no_inference_lst, reuse_lst = [], [], [], [] + + results = tree_efficientzero.ResultsWrapper(num=batch_size) + + if self._cfg.env_type == 'not_board_games': + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_efficientzero.batch_traverse_with_reuse( + roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, + to_play_batch, true_action_list, reuse_value_list + ) + else: + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_efficientzero.batch_traverse_with_reuse( + roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, + copy.deepcopy(to_play_batch), true_action_list, reuse_value_list + ) + + search_lens = results.get_search_len() + + for count, (ix, iy) in enumerate(zip(latent_state_index_in_search_path, latent_state_index_in_batch)): + if ix != -1: + latent_states.append(latent_state_batch_in_search_path[ix][iy]) + hidden_states_c_reward.append(reward_hidden_state_c_batch[ix][0][iy]) + hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy]) + temp_actions.append(last_actions[count]) + temp_search_lens.append(search_lens[count]) + else: + no_inference_lst.append(iy) + if ix == 0 and last_actions[count] == true_action_list[count]: + reuse_lst.append(count) + + length = len(temp_actions) + latent_states = torch.tensor(latent_states, device=self._cfg.device) + hidden_states_c_reward = torch.tensor(hidden_states_c_reward, device=self._cfg.device).unsqueeze(0) + hidden_states_h_reward = torch.tensor(hidden_states_h_reward, device=self._cfg.device).unsqueeze(0) + temp_actions = torch.tensor(temp_actions, device=self._cfg.device).long() + + if length != 0: + network_output = model.recurrent_inference( + latent_states, (hidden_states_c_reward, hidden_states_h_reward), temp_actions + ) + + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) + network_output.value = to_detach_cpu_numpy( + self.inverse_scalar_transform_handle(network_output.value)) + network_output.value_prefix = to_detach_cpu_numpy( + self.inverse_scalar_transform_handle(network_output.value_prefix)) + + network_output.reward_hidden_state = ( + network_output.reward_hidden_state[0].detach().cpu().numpy(), + network_output.reward_hidden_state[1].detach().cpu().numpy() + ) + + latent_state_batch_in_search_path.append(network_output.latent_state) + value_prefix_batch = network_output.value_prefix.reshape(-1).tolist() + value_batch = network_output.value.reshape(-1).tolist() + policy_logits_batch = network_output.policy_logits.tolist() + + reward_latent_state_batch = network_output.reward_hidden_state + assert self._cfg.lstm_horizon_len > 0 + reset_idx = 
(np.array(temp_search_lens) % self._cfg.lstm_horizon_len == 0) + reward_latent_state_batch[0][:, reset_idx, :] = 0 + reward_latent_state_batch[1][:, reset_idx, :] = 0 + is_reset_list = reset_idx.astype(np.int32).tolist() + reward_hidden_state_c_batch.append(reward_latent_state_batch[0]) + reward_hidden_state_h_batch.append(reward_latent_state_batch[1]) + else: + latent_state_batch_in_search_path.append([]) + value_batch, policy_logits_batch, value_prefix_batch = [], [], [] + reward_hidden_state_c_batch.append([]) + reward_hidden_state_h_batch.append([]) + assert self._cfg.lstm_horizon_len > 0 + reset_idx = (np.array(search_lens) % self._cfg.lstm_horizon_len == 0) + assert len(reset_idx) == batch_size + is_reset_list = reset_idx.astype(np.int32).tolist() + + current_latent_state_index = simulation_index + 1 + no_inference_lst.append(-1) + reuse_lst.append(-1) + tree_efficientzero.batch_backpropagate_with_reuse( + current_latent_state_index, discount_factor, value_prefix_batch, value_batch, policy_logits_batch, + min_max_stats_lst, results, is_reset_list, virtual_to_play_batch, no_inference_lst, reuse_lst, + reuse_value_list + ) + infer_sum += length + + average_infer = infer_sum / self._cfg.num_simulations + return length, average_infer + class GumbelMuZeroMCTSCtree(object): """ @@ -446,7 +678,7 @@ def __init__(self, cfg: EasyDict = None) -> None: self.inverse_scalar_transform_handle = InverseScalarTransform( self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution ) - + @classmethod def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "gmz_ctree": """ @@ -462,8 +694,8 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "g return tree_gumbel_muzero.Roots(active_collect_env_num, legal_actions) def search(self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] - ) -> None: + List[Any]] + ) -> None: """ Overview: Do MCTS for a batch of roots. Parallel in model inference. \ @@ -511,12 +743,14 @@ def search(self, roots: Any, model: torch.nn.Module, latent_state_roots: List[An """ if self._cfg.env_type == 'not_board_games': latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_gumbel_muzero.batch_traverse( - roots, self._cfg.num_simulations, self._cfg.max_num_considered_actions, discount_factor, results, to_play_batch + roots, self._cfg.num_simulations, self._cfg.max_num_considered_actions, discount_factor, + results, to_play_batch ) else: # the ``to_play_batch`` is only used in board games, here we need to deepcopy it to avoid changing the original data. latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_gumbel_muzero.batch_traverse( - roots, self._cfg.num_simulations, self._cfg.max_num_considered_actions, discount_factor, results, copy.deepcopy(to_play_batch) + roots, self._cfg.num_simulations, self._cfg.max_num_considered_actions, discount_factor, + results, copy.deepcopy(to_play_batch) ) # obtain the states for leaf nodes diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py index c0b7e5d7a..6589198a8 100644 --- a/lzero/policy/alphazero.py +++ b/lzero/policy/alphazero.py @@ -50,10 +50,10 @@ class AlphaZeroPolicy(Policy): # collect data -> update policy-> collect data -> ... # For different env, we have different episode_length, # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. 
- # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. update_per_collect=None, # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, + replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 3a94baf51..9f747e0db 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -111,10 +111,10 @@ class EfficientZeroPolicy(MuZeroPolicy): # collect data -> update policy-> collect data -> ... # For different env, we have different episode_length, # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor - # if we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + # if we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. update_per_collect=None, # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, + replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] @@ -165,7 +165,11 @@ class EfficientZeroPolicy(MuZeroPolicy): # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. use_ture_chance_label_in_chance_encoder=False, # (bool) Whether to add noise to roots during reanalyze process. - reanalyze_noise=False, + reanalyze_noise=True, + # (bool) Whether to reuse the root value between batch searches. + reuse_search=True, + # (bool) whether to use the pure policy to collect data. If False, use the MCTS guided with policy. + collect_with_pure_policy=False, # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. @@ -582,58 +586,77 @@ def _forward_collect( policy_logits = policy_logits.detach().cpu().numpy().tolist() legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] - # the only difference between collect and eval is the dirichlet noise. - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(active_collect_env_num) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(active_collect_env_num, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(active_collect_env_num, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) - self._mcts_collect.search( - roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play - ) - roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} + if not self._cfg.collect_with_pure_policy: + # the only difference between collect and eval is the dirichlet noise. 
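+                # For example, with root_dirichlet_alpha = 0.3 and a root that has 4 legal actions,
+                # each entry of ``noises`` is a length-4 sample from Dirichlet(0.3, 0.3, 0.3, 0.3),
+                # i.e. a random probability vector that sums to 1.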
+ noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) + self._mcts_collect.search( + roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play + ) - data_id = [i for i in range(active_collect_env_num)] - output = {i: None for i in data_id} - if ready_env_id is None: - ready_env_id = np.arange(active_collect_env_num) + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: None for i in data_id} + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps-greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + else: + data_id = [i for i in range(active_collect_env_num)] + output = {i: None for i in data_id} - for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - if self._cfg.eps.eps_greedy_exploration_in_collect: - # eps-greedy collect - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=True - ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - if np.random.rand() < self.collect_epsilon: - action = np.random.choice(legal_actions[i]) - else: - # normal collect - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. 
- action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=False - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - } + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + for i, env_id in enumerate(ready_env_id): + policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), dim=0).tolist() + policy_values = policy_values / np.sum(policy_values) + action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'searched_value': pred_values[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } return output def _init_eval(self) -> None: diff --git a/lzero/policy/gumbel_alphazero.py b/lzero/policy/gumbel_alphazero.py index 0dcb53982..dae84e2c5 100644 --- a/lzero/policy/gumbel_alphazero.py +++ b/lzero/policy/gumbel_alphazero.py @@ -51,10 +51,10 @@ class GumbelAlphaZeroPolicy(Policy): # collect data -> update policy-> collect data -> ... # For different env, we have different episode_length, # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. update_per_collect=None, # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, + replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py index c04fbb3c3..aeecfae67 100644 --- a/lzero/policy/gumbel_muzero.py +++ b/lzero/policy/gumbel_muzero.py @@ -111,10 +111,10 @@ class GumbelMuZeroPolicy(MuZeroPolicy): # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... # For different env, we have different episode_length, - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. update_per_collect=None, # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, + replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD' or 'Adam'] @@ -162,7 +162,7 @@ class GumbelMuZeroPolicy(MuZeroPolicy): # The larger the value, the more exploration. 
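The new ``collect_with_pure_policy`` branch above skips MCTS and samples directly from the policy head. A standalone sketch of that sampling step (assuming the same inputs as in the hunk: per-environment ``policy_logits`` over the full action space and a binary ``action_mask``):

    import numpy as np
    import torch

    def pure_policy_action(policy_logits, action_mask):
        # Restrict logits to legal actions, sample from their softmax, and map the
        # sampled legal index back into the full action space via the action mask.
        legal_actions = [i for i, x in enumerate(action_mask) if x == 1]
        probs = torch.softmax(torch.tensor([policy_logits[a] for a in legal_actions]), dim=0).tolist()
        probs = np.array(probs, dtype=np.float64)
        probs = probs / probs.sum()  # re-normalize to guard against rounding error
        idx_in_legal = np.random.choice(len(legal_actions), p=probs)
        return np.where(np.asarray(action_mask) == 1.0)[0][idx_in_legal]

    print(pure_policy_action([0.3, 2.1, -0.5, 0.0], [1, 1, 0, 1]))  # one of 0, 1, 3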
This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, # (bool) Whether to add noise to roots during reanalyze process. - reanalyze_noise=False, + reanalyze_noise=True, # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 0acc66b07..66736047a 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -87,7 +87,7 @@ class MuZeroPolicy(Policy): # (bool) Whether to monitor extra statistics in tensorboard. monitor_extra_statistics=True, # (int) The transition number of one ``GameSegment``. - game_segment_length=200, + game_segment_length=400, # (bool): Indicates whether to perform an offline evaluation of the checkpoint (ckpt). # If set to True, the checkpoint will be evaluated after the training process is complete. # IMPORTANT: Setting eval_offline to True requires configuring the saving of checkpoints to align with the evaluation frequency. @@ -114,10 +114,10 @@ class MuZeroPolicy(Policy): # collect data -> update policy-> collect data -> ... # For different env, we have different episode_length, # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. update_per_collect=None, # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, + replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD', 'Adam'] @@ -169,7 +169,11 @@ class MuZeroPolicy(Policy): # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. use_ture_chance_label_in_chance_encoder=False, # (bool) Whether to add noise to roots during reanalyze process. - reanalyze_noise=False, + reanalyze_noise=True, + # (bool) Whether to reuse the root value between batch searches. + reuse_search=True, + # (bool) whether to use the pure policy to collect data. If False, use the MCTS guided with policy. + collect_with_pure_policy=False, # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. 
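The policy defaults above also flip ``reanalyze_noise`` to True, so Dirichlet noise is injected at the root during reanalysis as well as collection. The blending rule below is the usual AlphaZero convention and is assumed here rather than copied from LightZero:

    import numpy as np

    def noisy_root_prior(prior, root_dirichlet_alpha=0.3, root_noise_weight=0.25):
        # Mix the root prior over legal actions with a Dirichlet sample; the result
        # still sums to 1 but is flattened toward the noise, encouraging exploration.
        prior = np.asarray(prior, dtype=np.float64)
        noise = np.random.dirichlet([root_dirichlet_alpha] * len(prior))
        return (1 - root_noise_weight) * prior + root_noise_weight * noise

    print(noisy_root_prior([0.7, 0.2, 0.1]))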
@@ -446,11 +450,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in self._cfg.policy_entropy_loss_weight * policy_entropy_loss ) weighted_total_loss = (weights * loss).mean() - gradient_scale = 1 / self._cfg.num_unroll_steps weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) self._optimizer.zero_grad() weighted_total_loss.backward() + if self._cfg.multi_gpu: self.sync_gradients(self._learn_model) total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), @@ -556,58 +560,78 @@ def _forward_collect( policy_logits = policy_logits.detach().cpu().numpy().tolist() legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] - # the only difference between collect and eval is the dirichlet noise - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(active_collect_env_num) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(active_collect_env_num, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(active_collect_env_num, legal_actions) - - roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) - self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` - roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} + if not self._cfg.collect_with_pure_policy: + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: None for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. 
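When ``eps_greedy_exploration_in_collect`` is enabled, the hunk above picks the greedy action from the visit counts and then, with probability ``collect_epsilon``, replaces it with a uniformly random legal action. A simplified sketch (the greedy pick is reduced to an argmax; LightZero's ``select_action`` helper also returns the visit-count distribution entropy, omitted here):

    import numpy as np

    def eps_greedy_collect_action(visit_counts, legal_actions, action_mask, epsilon):
        # Visit counts are indexed over legal actions only, as in the hunk above.
        greedy_idx_in_legal = int(np.argmax(visit_counts))
        action = np.where(np.asarray(action_mask) == 1.0)[0][greedy_idx_in_legal]
        if np.random.rand() < epsilon:
            action = np.random.choice(legal_actions)  # explore uniformly over legal actions
        return action

    print(eps_greedy_collect_action([10, 35, 5], legal_actions=[0, 1, 3],
                                    action_mask=[1, 1, 0, 1], epsilon=0.1))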
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + else: + data_id = [i for i in range(active_collect_env_num)] + output = {i: None for i in data_id} - data_id = [i for i in range(active_collect_env_num)] - output = {i: None for i in data_id} + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) - if ready_env_id is None: - ready_env_id = np.arange(active_collect_env_num) - - for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - if self._cfg.eps.eps_greedy_exploration_in_collect: - # eps greedy collect - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=True - ) + for i, env_id in enumerate(ready_env_id): + policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), dim=0).tolist() + policy_values = policy_values / np.sum(policy_values) + action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values) action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - if np.random.rand() < self.collect_epsilon: - action = np.random.choice(legal_actions[i]) - else: - # normal collect - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=False - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - } + output[env_id] = { + 'action': action, + 'searched_value': pred_values[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } return output diff --git a/lzero/policy/sampled_alphazero.py b/lzero/policy/sampled_alphazero.py index baa1d18b0..6c228602a 100644 --- a/lzero/policy/sampled_alphazero.py +++ b/lzero/policy/sampled_alphazero.py @@ -53,10 +53,10 @@ class SampledAlphaZeroPolicy(Policy): # collect data -> update policy-> collect data -> ... # For different env, we have different episode_length, # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. 
+ # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. update_per_collect=None, # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, + replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 7003f6808..2c813bea4 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -118,10 +118,10 @@ class SampledEfficientZeroPolicy(MuZeroPolicy): # collect data -> update policy-> collect data -> ... # For different env, we have different episode_length, # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. update_per_collect=None, # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, + replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] @@ -180,7 +180,7 @@ class SampledEfficientZeroPolicy(MuZeroPolicy): # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. use_ture_chance_label_in_chance_encoder=False, # (bool) Whether to add noise to roots during reanalyze process. - reanalyze_noise=False, + reanalyze_noise=True, # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 57d52ad1d..069eb96eb 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -109,7 +109,7 @@ class StochasticMuZeroPolicy(MuZeroPolicy): # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor update_per_collect=100, # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, + replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD', 'Adam'] @@ -163,7 +163,7 @@ class StochasticMuZeroPolicy(MuZeroPolicy): # (bool) Whether to use the true chance in MCTS. If False, use the predicted chance. use_ture_chance_label_in_chance_encoder=False, # (bool) Whether to add noise to roots during reanalyze process. - reanalyze_noise=False, + reanalyze_noise=True, # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. 
diff --git a/lzero/policy/tests/config/atari_muzero_config_for_test.py b/lzero/policy/tests/config/atari_muzero_config_for_test.py index 9ed0fba8a..c0bc7ff53 100644 --- a/lzero/policy/tests/config/atari_muzero_config_for_test.py +++ b/lzero/policy/tests/config/atari_muzero_config_for_test.py @@ -31,7 +31,7 @@ atari_muzero_config = dict( exp_name= - f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_muzero/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( stop_value=int(1e6), env_id=env_id, diff --git a/lzero/policy/tests/config/cartpole_muzero_config_for_test.py b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py index deaf1fddd..22fa1518a 100644 --- a/lzero/policy/tests/config/cartpole_muzero_config_for_test.py +++ b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py @@ -16,7 +16,7 @@ # ============================================================== cartpole_muzero_config = dict( - exp_name=f'data_mz_ctree/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_muzero/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( env_id='CartPole-v0', continuous=False, diff --git a/lzero/worker/__init__.py b/lzero/worker/__init__.py index 3000ee23f..b74e1e745 100644 --- a/lzero/worker/__init__.py +++ b/lzero/worker/__init__.py @@ -1,4 +1,4 @@ from .alphazero_collector import AlphaZeroCollector from .alphazero_evaluator import AlphaZeroEvaluator from .muzero_collector import MuZeroCollector -from .muzero_evaluator import MuZeroEvaluator \ No newline at end of file +from .muzero_evaluator import MuZeroEvaluator diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 8bb5f51ba..8ebcbdb5d 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -80,6 +80,7 @@ def __init__( self._tb_logger = None self.policy_config = policy_config + self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy self.reset(policy, env) @@ -210,7 +211,7 @@ def _compute_priorities(self, i: int, pred_values_lst: List[float], search_value if self.policy_config.use_priority: # Calculate priorities. The priorities are the L1 losses between the predicted # values and the search values. We use 'none' as the reduction parameter, which - # means the loss is calculated for each element individually, instead of being summed or averaged. + # means the loss is calculated for each element individually, instead of being summed or averaged. # A small constant (1e-6) is added to the results to avoid zero priorities. This # is done because zero priorities could potentially cause issues in some scenarios. pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) @@ -304,7 +305,8 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm def collect(self, n_episode: Optional[int] = None, train_iter: int = 0, - policy_kwargs: Optional[dict] = None) -> List[Any]: + policy_kwargs: Optional[dict] = None, + collect_with_pure_policy: bool = False) -> List[Any]: """ Overview: Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations. @@ -312,6 +314,7 @@ def collect(self, - n_episode (:obj:`Optional[int]`): Number of episodes to collect. - train_iter (:obj:`int`): Number of training iterations completed so far. 
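The ``_compute_priorities`` comment above (in ``muzero_collector.py``) states that a transition's priority is the element-wise L1 distance between the predicted value and the searched value, plus a small constant so no priority is exactly zero. A minimal sketch of that computation:

    import torch

    def compute_priorities(pred_values, search_values, eps=1e-6):
        pred = torch.tensor(pred_values, dtype=torch.float32)
        target = torch.tensor(search_values, dtype=torch.float32)
        # 'none' reduction keeps one priority per transition.
        return (torch.nn.functional.l1_loss(pred, target, reduction='none') + eps).numpy()

    print(compute_priorities([0.5, -0.2, 1.0], [0.8, -0.2, 0.4]))  # ~[0.3, 1e-6, 0.6]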
- policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. + - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. Returns: - return_data (:obj:`List[Any]`): Collected data in the form of a list. """ @@ -389,6 +392,8 @@ def collect(self, ready_env_id = set() remain_episode = n_episode + if collect_with_pure_policy: + temp_visit_list = [0.0 for i in range(self._env.action_space.n)] while True: with self._timer: @@ -422,7 +427,6 @@ def collect(self, policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { k: v['root_sampled_actions'] @@ -430,39 +434,53 @@ def collect(self, } value_dict_no_env_id = {k: v['searched_value'] for k, v in policy_output.items()} pred_value_dict_no_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - visit_entropy_dict_no_env_id = { - k: v['visit_count_distribution_entropy'] - for k, v in policy_output.items() - } - - if self.policy_config.gumbel_algo: - improved_policy_dict_no_env_id = {k: v['improved_policy_probs'] for k, v in policy_output.items()} - completed_value_no_env_id = { - k: v['roots_completed_value'] + + if not collect_with_pure_policy: + distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} + if self.policy_config.sampled_algo: + root_sampled_actions_dict_no_env_id = { + k: v['root_sampled_actions'] + for k, v in policy_output.items() + } + visit_entropy_dict_no_env_id = { + k: v['visit_count_distribution_entropy'] for k, v in policy_output.items() } + if self.policy_config.gumbel_algo: + improved_policy_dict_no_env_id = {k: v['improved_policy_probs'] for k, v in + policy_output.items()} + completed_value_no_env_id = { + k: v['roots_completed_value'] + for k, v in policy_output.items() + } + # TODO(pu): subprocess actions = {} - distributions_dict = {} - if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} value_dict = {} pred_value_dict = {} - visit_entropy_dict = {} - if self.policy_config.gumbel_algo: - improved_policy_dict = {} - completed_value_dict = {} + + if not collect_with_pure_policy: + distributions_dict = {} + if self.policy_config.sampled_algo: + root_sampled_actions_dict = {} + visit_entropy_dict = {} + if self.policy_config.gumbel_algo: + improved_policy_dict = {} + completed_value_dict = {} + for index, env_id in enumerate(ready_env_id): actions[env_id] = actions_no_env_id.pop(index) - distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) value_dict[env_id] = value_dict_no_env_id.pop(index) pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) - visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) - if self.policy_config.gumbel_algo: - improved_policy_dict[env_id] = improved_policy_dict_no_env_id.pop(index) - completed_value_dict[env_id] = completed_value_no_env_id.pop(index) + + if not collect_with_pure_policy: + distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) + if self.policy_config.sampled_algo: + root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) + visit_entropy_dict[env_id] = 
visit_entropy_dict_no_env_id.pop(index) + if self.policy_config.gumbel_algo: + improved_policy_dict[env_id] = improved_policy_dict_no_env_id.pop(index) + completed_value_dict[env_id] = completed_value_no_env_id.pop(index) # ============================================================== # Interact with env. @@ -483,15 +501,19 @@ def collect(self, continue obs, reward, done, info = timestep.obs, timestep.reward, timestep.done, timestep.info - if self.policy_config.sampled_algo: - game_segments[env_id].store_search_stats( - distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] - ) - elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], - improved_policy=improved_policy_dict[env_id]) + if collect_with_pure_policy: + game_segments[env_id].store_search_stats(temp_visit_list, 0) else: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + if self.policy_config.sampled_algo: + game_segments[env_id].store_search_stats( + distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] + ) + elif self.policy_config.gumbel_algo: + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], + improved_policy=improved_policy_dict[env_id]) + else: + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` if self.policy_config.use_ture_chance_label_in_chance_encoder: @@ -517,9 +539,10 @@ def collect(self, else: dones[env_id] = done - visit_entropies_lst[env_id] += visit_entropy_dict[env_id] - if self.policy_config.gumbel_algo: - completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) + if not collect_with_pure_policy: + visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + if self.policy_config.gumbel_algo: + completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 total_transitions += 1 @@ -527,7 +550,7 @@ def collect(self, if self.policy_config.use_priority: pred_values_lst[env_id].append(pred_value_dict[env_id]) search_values_lst[env_id].append(value_dict[env_id]) - if self.policy_config.gumbel_algo: + if self.policy_config.gumbel_algo and not collect_with_pure_policy: improved_policy_lst[env_id].append(improved_policy_dict[env_id]) # append the newest obs @@ -550,7 +573,7 @@ def collect(self, priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) pred_values_lst[env_id] = [] search_values_lst[env_id] = [] - if self.policy_config.gumbel_algo: + if self.policy_config.gumbel_algo and not collect_with_pure_policy: improved_policy_lst[env_id] = [] # the current game_segments become last_game_segment @@ -575,10 +598,12 @@ def collect(self, 'reward': reward, 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step'], - 'visit_entropy': visit_entropies_lst[env_id] / eps_steps_lst[env_id], } - if self.policy_config.gumbel_algo: - info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] + if not collect_with_pure_policy: + info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] + if self.policy_config.gumbel_algo: + info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] + collected_episode += 1 self._episode_info.append(info) @@ -650,7 
+675,8 @@ def collect(self, # log self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) - self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) + if not collect_with_pure_policy: + self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) self_play_moves += eps_steps_lst[env_id] self_play_episodes += 1 @@ -707,7 +733,10 @@ def _output_log(self, train_iter: int) -> None: envstep_count = sum([d['step'] for d in self._episode_info]) duration = sum([d['time'] for d in self._episode_info]) episode_reward = [d['reward'] for d in self._episode_info] - visit_entropy = [d['visit_entropy'] for d in self._episode_info] + if not self.collect_with_pure_policy: + visit_entropy = [d['visit_entropy'] for d in self._episode_info] + else: + visit_entropy = [0.0] if self.policy_config.gumbel_algo: completed_value = [d['completed_value'] for d in self._episode_info] self._total_duration += duration @@ -726,7 +755,6 @@ def _output_log(self, train_iter: int) -> None: 'total_episode_count': self._total_episode_count, 'total_duration': self._total_duration, 'visit_entropy': np.mean(visit_entropy), - # 'each_reward': episode_reward, } if self.policy_config.gumbel_algo: info['completed_value'] = np.mean(completed_value) @@ -738,4 +766,4 @@ def _output_log(self, train_iter: int) -> None: self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) if k in ['total_envstep_count']: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4b3ce597f..ad1020b27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ DI-engine>=0.4.7 gymnasium[atari] -moviepy numpy>=1.22.4 pympler bsuite diff --git a/zoo/atari/config/atari_gumbel_muzero_config.py b/zoo/atari/config/atari_gumbel_muzero_config.py index 918a93236..4e3591727 100644 --- a/zoo/atari/config/atari_gumbel_muzero_config.py +++ b/zoo/atari/config/atari_gumbel_muzero_config.py @@ -31,7 +31,7 @@ atari_gumbel_muzero_config = dict( exp_name= - f'data_mz_ctree/{env_id[:-14]}_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_muzero/{env_id[:-14]}_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( stop_value=int(1e6), env_id=env_id, diff --git a/zoo/atari/config/atari_muzero_config.py b/zoo/atari/config/atari_muzero_config.py index 1e5d49ce2..6d4d26c1c 100644 --- a/zoo/atari/config/atari_muzero_config.py +++ b/zoo/atari/config/atari_muzero_config.py @@ -31,7 +31,7 @@ # ============================================================== atari_muzero_config = dict( - exp_name=f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_muzero/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( stop_value=int(1e6), env_id=env_id, @@ -52,7 +52,7 @@ norm_type='BN', ), cuda=True, - reanalyze_noise=False, + reanalyze_noise=True, env_type='not_board_games', game_segment_length=400, random_collect_episode_num=0, diff --git a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py b/zoo/atari/config/atari_muzero_multigpu_ddp_config.py index 4c80b7f90..b447f934d 100644 --- a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py +++ 
b/zoo/atari/config/atari_muzero_multigpu_ddp_config.py @@ -32,7 +32,7 @@ # ============================================================== atari_muzero_config = dict( - exp_name=f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed0', + exp_name=f'data_muzero/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed0', env=dict( stop_value=int(1e6), env_id=env_id, diff --git a/zoo/atari/config/atari_rezero_ez_config.py b/zoo/atari/config/atari_rezero_ez_config.py new file mode 100644 index 000000000..a03ef8d10 --- /dev/null +++ b/zoo/atari/config/atari_rezero_ez_config.py @@ -0,0 +1,95 @@ +from easydict import EasyDict +# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} +env_id = 'PongNoFrameskip-v4' + +if env_id == 'PongNoFrameskip-v4': + action_space_size = 6 +elif env_id == 'QbertNoFrameskip-v4': + action_space_size = 6 +elif env_id == 'MsPacmanNoFrameskip-v4': + action_space_size = 9 +elif env_id == 'SpaceInvadersNoFrameskip-v4': + action_space_size = 6 +elif env_id == 'BreakoutNoFrameskip-v4': + action_space_size = 4 + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +batch_size = 256 +max_env_step = int(5e5) +use_priority = False +# ============= The key different params for ReZero ============= +reuse_search = True +collect_with_pure_policy = True +buffer_reanalyze_freq = 1 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_efficientzero_config = dict( + exp_name=f'data_rezero_ez/{env_id[:-14]}_rezero_efficientzero_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0', + env=dict( + env_id=env_id, + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=action_space_size, + downsample=True, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=0, # NOTE: for rezero, reanalyze_ratio should be 0. 
+ n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + use_priority=use_priority, + # ============= The key different params for ReZero ============= + reuse_search=reuse_search, + collect_with_pure_policy=collect_with_pure_policy, + buffer_reanalyze_freq=buffer_reanalyze_freq, + ), +) +atari_efficientzero_config = EasyDict(atari_efficientzero_config) +main_config = atari_efficientzero_config + +atari_efficientzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), +) +atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config) +create_config = atari_efficientzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_rezero + train_rezero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/atari/config/atari_rezero_mz_config.py b/zoo/atari/config/atari_rezero_mz_config.py new file mode 100644 index 000000000..a1d76e36c --- /dev/null +++ b/zoo/atari/config/atari_rezero_mz_config.py @@ -0,0 +1,105 @@ +from easydict import EasyDict + +# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} +env_id = 'PongNoFrameskip-v4' + +if env_id == 'PongNoFrameskip-v4': + action_space_size = 6 +elif env_id == 'QbertNoFrameskip-v4': + action_space_size = 6 +elif env_id == 'MsPacmanNoFrameskip-v4': + action_space_size = 9 +elif env_id == 'SpaceInvadersNoFrameskip-v4': + action_space_size = 6 +elif env_id == 'BreakoutNoFrameskip-v4': + action_space_size = 4 + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +batch_size = 256 +replay_ratio = 0.25 +max_env_step = int(5e5) +use_priority = False + +# ============= The key different params for ReZero ============= +reuse_search = True +collect_with_pure_policy = True +buffer_reanalyze_freq = 1 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_muzero_config = dict( + exp_name=f'data_rezero_mz/{env_id[:-14]}_rezero_muzero_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0', + env=dict( + stop_value=int(1e6), + env_id=env_id, + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=action_space_size, + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + use_augmentation=True, + update_per_collect=update_per_collect, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + ssl_loss_weight=2, + n_episode=n_episode, + 
eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + use_priority=use_priority, + # ============= The key different params for ReZero ============= + reuse_search=reuse_search, + collect_with_pure_policy=collect_with_pure_policy, + buffer_reanalyze_freq=buffer_reanalyze_freq, + ), +) +atari_muzero_config = EasyDict(atari_muzero_config) +main_config = atari_muzero_config + +atari_muzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +atari_muzero_create_config = EasyDict(atari_muzero_create_config) +create_config = atari_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_rezero + train_rezero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/atari/entry/test_memory_usage.py b/zoo/atari/entry/test_memory_usage.py new file mode 100644 index 000000000..cfb3e9d5d --- /dev/null +++ b/zoo/atari/entry/test_memory_usage.py @@ -0,0 +1,103 @@ +from lzero.mcts import ReZeroMZGameBuffer as GameBuffer +from zoo.atari.config.atari_rezero_mz_config import main_config, create_config +import torch +import numpy as np +from ding.config import compile_config +from ding.policy import create_policy +from tensorboardX import SummaryWriter +import psutil + + +def get_memory_usage(): + """ + Get the current memory usage of the process. + + Returns: + int: Memory usage in bytes. + """ + process = psutil.Process() + memory_info = process.memory_info() + return memory_info.rss + + +def initialize_policy(cfg, create_cfg): + """ + Initialize the policy based on the given configuration. + + Args: + cfg (Config): Main configuration object. + create_cfg (Config): Creation configuration object. + + Returns: + Policy: Initialized policy object. + """ + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + cfg = compile_config(cfg, seed=0, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy = create_policy(cfg.policy, model=None, enable_field=['learn', 'collect', 'eval']) + + model_path = '{template_path}/iteration_20000.pth.tar' + if model_path is not None: + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + + return policy, cfg + + +def run_memory_test(replay_buffer, policy, writer, sample_batch_size): + """ + Run memory usage test for sampling from the replay buffer. + + Args: + replay_buffer (GameBuffer): The replay buffer to sample from. + policy (Policy): The policy object. + writer (SummaryWriter): TensorBoard summary writer. + sample_batch_size (int): The base batch size for sampling. 
+ """ + for i in range(2): + initial_memory = get_memory_usage() + print(f"Initial memory usage: {initial_memory} bytes") + + replay_buffer.sample(sample_batch_size * (i + 1), policy) + + final_memory = get_memory_usage() + memory_cost = final_memory - initial_memory + + print(f"Memory usage after sampling: {final_memory} bytes") + print(f"Memory cost of sampling: {float(memory_cost) / 1e9:.2f} GB") + + writer.add_scalar("Sampling Memory Usage (GB)", float(memory_cost) / 1e9, i + 1) + + # Reset counters + replay_buffer.compute_target_re_time = 0 + replay_buffer.origin_search_time = 0 + replay_buffer.reuse_search_time = 0 + replay_buffer.active_root_num = 0 + + +def main(): + """ + Main function to run the memory usage test. + """ + cfg, create_cfg = main_config, create_config + policy, cfg = initialize_policy(cfg, create_cfg) + + replay_buffer = GameBuffer(cfg.policy) + + # Load and push data to the replay buffer + data = np.load('{template_path}/collected_data.npy', allow_pickle=True) + for _ in range(50): + replay_buffer.push_game_segments(data) + + log_dir = "logs/memory_test2" + writer = SummaryWriter(log_dir) + + run_memory_test(replay_buffer, policy, writer, sample_batch_size=25600) + + writer.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py b/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py index 72491bf3b..091ba98cb 100644 --- a/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py +++ b/zoo/board_games/connect4/config/connect4_muzero_bot_mode_config.py @@ -16,7 +16,7 @@ # ============================================================== connect4_muzero_config = dict( - exp_name=f'data_mz_ctree/connect4_play-with-bot-mode_seed0', + exp_name=f'data_muzero/connect4_play-with-bot-mode_seed0', env=dict( battle_mode='play_with_bot_mode', bot_action_type='rule', @@ -79,5 +79,4 @@ if __name__ == "__main__": from lzero.entry import train_muzero - - train_muzero([main_config, create_config], seed=1, max_env_step=max_env_step) \ No newline at end of file + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py b/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py index b1c2ffd1f..10bcbcc86 100644 --- a/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py +++ b/zoo/board_games/connect4/config/connect4_muzero_sp_mode_config.py @@ -16,7 +16,7 @@ # ============================================================== connect4_muzero_config = dict( - exp_name=f'data_mz_ctree/connect4_self-play-mode_seed0', + exp_name=f'data_muzero/connect4_self-play-mode_seed0', env=dict( battle_mode='self_play_mode', bot_action_type='rule', diff --git a/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py b/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py new file mode 100644 index 000000000..f78bf043a --- /dev/null +++ b/zoo/board_games/connect4/config/connect4_rezero_mz_bot_mode_config.py @@ -0,0 +1,97 @@ +from easydict import EasyDict +# import torch +# torch.cuda.set_device(0) +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 5 +num_simulations = 50 +update_per_collect = 50 
+batch_size = 256 +max_env_step = int(1e6) + +reuse_search = True +collect_with_pure_policy = True +use_priority = False +buffer_reanalyze_freq = 1 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +connect4_muzero_config = dict( + exp_name=f'data_rezero_mz/connect4_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0', + env=dict( + battle_mode='play_with_bot_mode', + bot_action_type='rule', + channel_last=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, 6, 7), + action_space_size=7, + image_channel=3, + num_res_blocks=1, + num_channels=64, + support_scale=300, + reward_support_size=601, + value_support_size=601, + ), + cuda=True, + env_type='board_games', + action_type='varied_action_space', + game_segment_length=int(6 * 7 / 2), # for battle_mode='play_with_bot_mode' + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + # NOTE๏ผšIn board_games, we set large td_steps to make sure the value target is the final outcome. + td_steps=int(6 * 7 / 2), # for battle_mode='play_with_bot_mode' + # NOTE๏ผšIn board_games, we set discount_factor=1. + discount_factor=1, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e5), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + reanalyze_noise=True, + use_priority=use_priority, + # ============= The key different params for ReZero ============= + reuse_search=reuse_search, + collect_with_pure_policy=collect_with_pure_policy, + buffer_reanalyze_freq=buffer_reanalyze_freq, + ), +) +connect4_muzero_config = EasyDict(connect4_muzero_config) +main_config = connect4_muzero_config + +connect4_muzero_create_config = dict( + env=dict( + type='connect4', + import_names=['zoo.board_games.connect4.envs.connect4_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), +) +connect4_muzero_create_config = EasyDict(connect4_muzero_create_config) +create_config = connect4_muzero_create_config + +if __name__ == "__main__": + # Define a list of seeds for multiple runs + seeds = [0, 1, 2] # You can add more seed values here + for seed in seeds: + # Update exp_name to include the current seed + main_config.exp_name = f'data_rezero_mz/connect4_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed{seed}' + from lzero.entry import train_rezero + train_rezero([main_config, create_config], seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/board_games/connect4/envs/test_bots.py b/zoo/board_games/connect4/envs/test_bots.py index 9189b761e..9b1fbf063 100644 --- a/zoo/board_games/connect4/envs/test_bots.py +++ b/zoo/board_games/connect4/envs/test_bots.py @@ -1,6 +1,7 @@ import time import numpy as np +import psutil import pytest from easydict import EasyDict @@ -8,6 +9,11 @@ from zoo.board_games.mcts_bot import MCTSBot +def get_memory_usage(): + process = psutil.Process() + memory_info = process.memory_info() + return memory_info.rss + @pytest.mark.unittest class TestConnect4Bot(): """ @@ 
-31,7 +37,7 @@ def setup(self) -> None: prob_expert_agent=0, bot_action_type='rule', screen_scaling=9, - render_mode='image_savefile_mode', + render_mode= None, prob_random_action_in_bot=0, ) @@ -50,6 +56,8 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None: # Repeat the game for 10 rounds. for i in range(10): print('-' * 10 + str(i) + '-' * 10) + memory_usage = get_memory_usage() + print(f"Initial memory usage: {memory_usage} bytes") # Initialize the game, where there are two players: player 1 and player 2. env = Connect4Env(EasyDict(self.cfg)) # Reset the environment, set the board to a clean board and the start player to be player 1. @@ -61,6 +69,7 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None: player = MCTSBot(env_mcts, 'a', num_simulations) # player_index = 0, player = 1 # Set player 1 to move first. player_index = 0 + step = 1 while not env.get_done_reward()[0]: """ Overview: @@ -70,7 +79,7 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None: if player_index == 0: t1 = time.time() # action = env.bot_action() - action = player.get_actions(state, player_index=player_index) + action, node = player.get_actions(state, step, player_index=player_index) t2 = time.time() # print("The time difference is :", t2-t1) mcts_bot_time_list.append(t2 - t1) @@ -86,7 +95,13 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None: player_index = 0 env.step(action) state = env.board - # print(np.array(state).reshape(6, 7)) + step += 1 + print(np.array(state).reshape(6, 7)) + temp = memory_usage + memory_usage = get_memory_usage() + memory_cost = memory_usage - temp + print(f"Memory usage after search: {memory_usage} bytes") + print(f"Increased memory usage due to searches: {memory_cost} bytes") # Record the winner. winner.append(env.get_done_winner()[1]) @@ -115,7 +130,7 @@ def test_mcts_bot_vs_rule_bot(self, num_simulations: int = 200) -> None: def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations_2: int = 50) -> None: """ Overview: - A tictactoe game between mcts_bot and rule_bot, where rule_bot take the first move. + A tictactoe game between two mcts_bots. Arguments: - num_simulations_1 (:obj:`int`): The number of the simulations of player 1 required to find the best move. - num_simulations_2 (:obj:`int`): The number of the simulations of player 2 required to find the best move. @@ -126,17 +141,21 @@ def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations winner = [] # Repeat the game for 10 rounds. - for i in range(10): + for i in range(1): print('-' * 10 + str(i) + '-' * 10) + memory_usage = get_memory_usage() + print(f"ๅˆๅง‹ๅ†…ๅญ˜ไฝฟ็”จ้‡: {memory_usage} bytes") # Initialize the game, where there are two players: player 1 and player 2. env = Connect4Env(EasyDict(self.cfg)) # Reset the environment, set the board to a clean board and the start player to be player 1. env.reset() state = env.board player1 = MCTSBot(env, 'a', num_simulations_1) # player_index = 0, player = 1 - player2 = MCTSBot(env, 'a', num_simulations_2) + player2 = MCTSBot(env, 'b', num_simulations_2) # Set player 1 to move first. 
player_index = 0 + step = 1 + node = None while not env.get_done_reward()[0]: """ Overview: @@ -146,7 +165,7 @@ def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations if player_index == 0: t1 = time.time() # action = env.bot_action() - action = player1.get_actions(state, player_index=player_index) + action, node, visit = player1.get_actions(state, step, player_index) t2 = time.time() # print("The time difference is :", t2-t1) mcts_bot1_time_list.append(t2 - t1) @@ -155,14 +174,20 @@ def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations else: t1 = time.time() # action = env.bot_action() - action = player2.get_actions(state, player_index=player_index) + action, node, visit = player2.get_actions(state, step, player_index, num_simulation=visit) t2 = time.time() # print("The time difference is :", t2-t1) mcts_bot2_time_list.append(t2 - t1) player_index = 0 env.step(action) + step += 1 state = env.board - # print(np.array(state).reshape(6, 7)) + print(np.array(state).reshape(6, 7)) + temp = memory_usage + memory_usage = get_memory_usage() + memory_cost = memory_usage - temp + print(f"Memory usage after search: {memory_usage} bytes") + print(f"Increased memory usage due to searches: {memory_cost} bytes") # Record the winner. winner.append(env.get_done_winner()[1]) @@ -175,11 +200,11 @@ def test_mcts_bot_vs_mcts_bot(self, num_simulations_1: int = 50, num_simulations mcts_bot2_var = np.var(mcts_bot2_time_list) # Print the information of the games. - print('num_simulations={}\n'.format(200)) + print('num_simulations={}\n'.format(num_simulations_1)) print('mcts_bot1_time_list={}\n'.format(mcts_bot1_time_list)) print('mcts_bot1_mu={}, mcts_bot1_var={}\n'.format(mcts_bot1_mu, mcts_bot1_var)) - print('num_simulations={}\n'.format(1000)) + print('num_simulations={}\n'.format(num_simulations_2)) print('mcts_bot2_time_list={}\n'.format(mcts_bot2_time_list)) print('mcts_bot2_mu={}, mcts_bot2_var={}\n'.format(mcts_bot2_mu, mcts_bot2_var)) @@ -204,6 +229,8 @@ def test_rule_bot_vs_rule_bot(self) -> None: # Repeat the game for 10 rounds. for i in range(10): print('-' * 10 + str(i) + '-' * 10) + memory_usage = get_memory_usage() + print(f"ๅˆๅง‹ๅ†…ๅญ˜ไฝฟ็”จ้‡: {memory_usage} bytes") # Initialize the game, where there are two players: player 1 and player 2. env = Connect4Env(EasyDict(self.cfg)) # Reset the environment, set the board to a clean board and the start player to be player 1. @@ -234,7 +261,12 @@ def test_rule_bot_vs_rule_bot(self) -> None: player_index = 0 env.step(action) state = env.board - # print(np.array(state).reshape(6, 7)) + print(np.array(state).reshape(6, 7)) + temp = memory_usage + memory_usage = get_memory_usage() + memory_cost = memory_usage - temp + print(f"Memory usage after search: {memory_usage} bytes") + print(f"Increased memory usage due to searches: {memory_cost} bytes") # Record the winner. 
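The bot tests above now track process memory with ``psutil`` and print the RSS delta around each search. A standalone version of that bookkeeping pattern:

    import psutil

    def rss_bytes():
        # Resident set size of the current process, in bytes.
        return psutil.Process().memory_info().rss

    def report_search_memory(run_search):
        before = rss_bytes()
        result = run_search()
        after = rss_bytes()
        print(f"Memory usage after search: {after} bytes")
        print(f"Increased memory usage due to searches: {after - before} bytes")
        return result

    # Example with a dummy "search" that allocates a large list.
    report_search_memory(lambda: [0] * 1_000_000)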
winner.append(env.get_done_winner()[1]) @@ -258,3 +290,9 @@ def test_rule_bot_vs_rule_bot(self) -> None: winner, winner.count(-1), winner.count(1), winner.count(2) ) ) + + +if __name__ == "__main__": + test = TestConnect4Bot() + test.setup() + test.test_mcts_bot_vs_mcts_bot(2000,200) \ No newline at end of file diff --git a/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py index 49d44e432..e9fb19301 100644 --- a/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_bot_mode_config.py @@ -21,7 +21,7 @@ gomoku_gumbel_muzero_config = dict( exp_name= - f'data_mz_ctree/gomoku_b{board_size}_rand{prob_random_action_in_bot}_gumbel_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_muzero/gomoku_b{board_size}_rand{prob_random_action_in_bot}_gumbel_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( board_size=board_size, battle_mode='play_with_bot_mode', diff --git a/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py index a30a31437..36bf5ae5f 100644 --- a/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py @@ -20,8 +20,7 @@ # ============================================================== gomoku_muzero_config = dict( - exp_name= - f'data_mz_ctree/gomoku_b{board_size}_rand{prob_random_action_in_bot}_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_muzero/gomoku_b{board_size}_rand{prob_random_action_in_bot}_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( board_size=board_size, battle_mode='play_with_bot_mode', @@ -85,5 +84,10 @@ create_config = gomoku_muzero_create_config if __name__ == "__main__": - from lzero.entry import train_muzero - train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) + # Define a list of seeds for multiple runs + seeds = [0] # You can add more seed values here + for seed in seeds: + # Update exp_name to include the current seed + main_config.exp_name = f'data_muzero/gomoku_b{board_size}_rand{prob_random_action_in_bot}_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}' + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py b/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py index 1fbaedb8b..7b12445f2 100644 --- a/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_muzero_sp_mode_config.py @@ -21,12 +21,12 @@ gomoku_muzero_config = dict( exp_name= - f'data_mz_ctree/gomoku_muzero_sp-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_muzero/gomoku_muzero_sp-mode_rand{prob_random_action_in_bot}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( battle_mode='self_play_mode', bot_action_type=bot_action_type, prob_random_action_in_bot=prob_random_action_in_bot, - channel_last=True, + 
channel_last=False, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, diff --git a/zoo/board_games/gomoku/config/gomoku_rezero-mz_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_rezero-mz_bot_mode_config.py new file mode 100644 index 000000000..031c7ef01 --- /dev/null +++ b/zoo/board_games/gomoku/config/gomoku_rezero-mz_bot_mode_config.py @@ -0,0 +1,100 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 32 +n_episode = 32 +evaluator_env_num = 5 +num_simulations = 50 +update_per_collect = 50 +batch_size = 256 +max_env_step = int(1e6) +board_size = 6 # default_size is 15 +bot_action_type = 'v0' # options={'v0', 'v1'} +prob_random_action_in_bot = 0.5 + +reuse_search = True +collect_with_pure_policy = True +use_priority = False +buffer_reanalyze_freq = 1 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +gomoku_muzero_config = dict( + exp_name=f'data_rezero_mz/gomoku_b{board_size}_rand{prob_random_action_in_bot}_muzero_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0', + env=dict( + board_size=board_size, + battle_mode='play_with_bot_mode', + bot_action_type=bot_action_type, + prob_random_action_in_bot=prob_random_action_in_bot, + channel_last=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, board_size, board_size), + action_space_size=int(board_size * board_size), + image_channel=3, + num_res_blocks=1, + num_channels=32, + support_scale=10, + reward_support_size=21, + value_support_size=21, + ), + cuda=True, + env_type='board_games', + action_type='varied_action_space', + game_segment_length=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode' + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + # NOTE๏ผšIn board_games, we set large td_steps to make sure the value target is the final outcome. + td_steps=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode' + # NOTE๏ผšIn board_games, we set discount_factor=1. 
+ discount_factor=1, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e5), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + reanalyze_noise=True, + use_priority=use_priority, + # ============= The key different params for ReZero ============= + reuse_search=reuse_search, + collect_with_pure_policy=collect_with_pure_policy, + buffer_reanalyze_freq=buffer_reanalyze_freq, + ), +) +gomoku_muzero_config = EasyDict(gomoku_muzero_config) +main_config = gomoku_muzero_config + +gomoku_muzero_create_config = dict( + env=dict( + type='gomoku', + import_names=['zoo.board_games.gomoku.envs.gomoku_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), +) +gomoku_muzero_create_config = EasyDict(gomoku_muzero_create_config) +create_config = gomoku_muzero_create_config + +if __name__ == "__main__": + # Define a list of seeds for multiple runs + seeds = [0] # You can add more seed values here + for seed in seeds: + # Update exp_name to include the current seed + main_config.exp_name = f'data_rezero-mz/gomoku_b{board_size}_rand{prob_random_action_in_bot}_rezero-mz_bot-mode_type-{bot_action_type}_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed{seed}' + from lzero.entry import train_rezero + train_rezero([main_config, create_config], seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py b/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py index e34d1af02..5568148a2 100644 --- a/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py +++ b/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py @@ -8,7 +8,7 @@ point to the ckpt file of the pretrained model, and an absolute path is recommended. In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - model_path = './data_mz_ctree/gomoku_gumbel_muzero_visit50_value1_purevaluenetwork_deletepriority_deletetargetsoftmax_seed0_230517_131142/ckpt/iteration_113950.pth.tar' + model_path = './data_muzero/gomoku_gumbel_muzero_visit50_value1_purevaluenetwork_deletepriority_deletetargetsoftmax_seed0_230517_131142/ckpt/iteration_113950.pth.tar' seeds = [0,1,2,3,4] num_episodes_each_seed = 10 # If True, you can play with the agent. diff --git a/zoo/board_games/mcts_bot.py b/zoo/board_games/mcts_bot.py index 32609de54..706c0f468 100644 --- a/zoo/board_games/mcts_bot.py +++ b/zoo/board_games/mcts_bot.py @@ -10,8 +10,11 @@ """ import time +import copy from abc import ABC, abstractmethod from collections import defaultdict +from graphviz import Digraph +import os import numpy as np import copy @@ -170,7 +173,8 @@ def value(self): - If the parent's current player is player 2, then Q-value = 5 - 10 = -5. This way, a higher Q-value for a node indicates a higher win rate for the parent's current player. """ - + if self.parent == None: + return 0 # Determine the number of wins and losses based on the current player at the parent node. 
wins, loses = (self._results[1], self._results[-1]) if self.parent.env.current_player == 1 else ( self._results[-1], self._results[1]) @@ -344,6 +348,11 @@ def _tree_policy(self): else: current_node = current_node.best_child() return current_node + + def print_tree(self, node, indent="*"): + print(indent + str(np.array(node.env.board).reshape(6,7)) + str(node.visit_count)) + for child in node.children: + self.print_tree(child, indent + "*") class MCTSBot: @@ -365,7 +374,7 @@ def __init__(self, env, bot_name, num_simulation=50): self.num_simulation = num_simulation self.simulator_env = env - def get_actions(self, state, player_index, best_action_type="UCB"): + def get_actions(self, state, step, player_index, root=None, num_simulation=None, best_action_type="UCB"): """ Overview: This function gets the actions that the MCTS Bot will take. @@ -380,8 +389,95 @@ def get_actions(self, state, player_index, best_action_type="UCB"): """ # Every time before make a decision, reset the environment to the current environment of the game. self.simulator_env.reset(start_player_index=player_index, init_state=state) - root = TwoPlayersMCTSNode(self.simulator_env) + if root == None: + root = TwoPlayersMCTSNode(self.simulator_env) # Do the MCTS to find the best action to take. mcts = MCTS(root) - mcts.best_action(self.num_simulation, best_action_type=best_action_type) - return root.best_action + if num_simulation == None: + child_node = mcts.best_action(self.num_simulation, best_action_type=best_action_type) + else: + child_node = mcts.best_action(num_simulation, best_action_type=best_action_type) + print(root.visit_count) + if step%2 == 1: + self.plot_simulation_graph(child_node, step) + else: + self.plot_simulation_graph(root, step) + # if step == 3: + # self.plot_simulation_graph(root, step) + + return root.best_action, child_node, int(child_node.visit_count) + + def obtain_tree_topology(self, root, to_play=-1): + node_stack = [] + edge_topology_list = [] + node_topology_list = [] + node_id_list = [] + node_stack.append(root) + while len(node_stack) > 0: + node = node_stack[-1] + node_stack.pop() + node_dict = {} + node_dict['node_id'] = np.array(node.env.board).reshape(6,7) + node_dict['visit_count'] = node.visit_count + # node_dict['policy_prior'] = node.prior + node_dict['value'] = node.value + node_topology_list.append(node_dict) + + # node_id_list.append(node.simulation_index) + for child in node.children: + # child.parent_simulation_index = node.simulation_index + edge_dict = {} + edge_dict['parent_id'] = np.array(node.env.board).reshape(6,7) + edge_dict['child_id'] = np.array(child.env.board).reshape(6,7) + edge_topology_list.append(edge_dict) + node_stack.append(child) + return edge_topology_list, node_id_list, node_topology_list + + def obtain_child_topology(self, root, to_play=-1): + edge_topology_list = [] + node_topology_list = [] + node_id_list = [] + node = root + node_dict = {} + node_dict['node_id'] = np.array(node.env.board).reshape(6,7) + node_dict['visit_count'] = node.visit_count + # node_dict['policy_prior'] = node.prior + node_dict['value'] = node.value + node_topology_list.append(node_dict) + + for child in node.children: + # child.parent_simulation_index = node.simulation_index + edge_dict = {} + edge_dict['parent_id'] = np.array(node.env.board).reshape(6,7) + edge_dict['child_id'] = np.array(child.env.board).reshape(6,7) + edge_topology_list.append(edge_dict) + node_dict = {} + node_dict['node_id'] = np.array(child.env.board).reshape(6,7) + node_dict['visit_count'] = 
child.visit_count + # node_dict['policy_prior'] = node.prior + node_dict['value'] = child.value + node_topology_list.append(node_dict) + return edge_topology_list, node_id_list, node_topology_list + + def plot_simulation_graph(self, env_root, current_step, type="child", graph_directory=None): + if type == "child": + edge_topology_list, node_id_list, node_topology_list = self.obtain_child_topology(env_root) + elif type == "tree": + edge_topology_list, node_id_list, node_topology_list = self.obtain_tree_topology(env_root) + dot = Digraph(comment='this is direction') + for node_topology in node_topology_list: + node_name = str(node_topology['node_id']) + label = f"{node_topology['node_id']}, \n visit_count: {node_topology['visit_count']}, \n value: {round(node_topology['value'], 4)}" + dot.node(node_name, label=label) + for edge_topology in edge_topology_list: + parent_id = str(edge_topology['parent_id']) + child_id = str(edge_topology['child_id']) + label = parent_id + '-' + child_id + dot.edge(parent_id, child_id, label=None) + if graph_directory is None: + graph_directory = './data_visualize/' + if not os.path.exists(graph_directory): + os.makedirs(graph_directory) + graph_path = graph_directory + 'same_num_' + str(current_step) + 'step.gv' + dot.format = 'png' + dot.render(graph_path, view=False) diff --git a/zoo/board_games/test_speed_win-rate_between_bots.py b/zoo/board_games/test_speed_win-rate_between_bots.py index c31504420..7d9374e6d 100644 --- a/zoo/board_games/test_speed_win-rate_between_bots.py +++ b/zoo/board_games/test_speed_win-rate_between_bots.py @@ -9,11 +9,17 @@ import numpy as np from easydict import EasyDict +import psutil from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv from zoo.board_games.mcts_bot import MCTSBot from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv +def get_memory_usage(): + process = psutil.Process() + memory_info = process.memory_info() + return memory_info.rss + cfg_tictactoe = dict( battle_mode='self_play_mode', agent_vs_human=False, @@ -361,6 +367,8 @@ def test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=50): # Repeat the game for 10 rounds. for i in range(10): print('-' * 10 + str(i) + '-' * 10) + memory_usage = get_memory_usage() + print(f'Memory usage at the beginning of the game: {memory_usage} bytes') # Initialize the game, where there are two players: player 1 and player 2. env = TicTacToeEnv(EasyDict(cfg_tictactoe)) # Reset the environment, set the board to a clean board and the start player to be player 1. @@ -369,6 +377,7 @@ def test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=50): player = MCTSBot(env, 'a', num_simulations) # player_index = 0, player = 1 # Set player 1 to move first. player_index = 0 + step = 1 while not env.get_done_reward()[0]: """ Overview: @@ -378,7 +387,7 @@ def test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=50): if player_index == 0: t1 = time.time() # action = env.mcts_bot() - action = player.get_actions(state, player_index=player_index, best_action_type = "most_visit") + action = player.get_actions(state, step, player_index, best_action_type = "most_visit")[0] t2 = time.time() # print("The time difference is :", t2-t1) # mcts_bot_time_list.append(t2 - t1) @@ -396,10 +405,17 @@ def test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=50): alphabeta_pruning_time_list.append(t2 - t1) player_index = 0 env.step(action) + step += 1 state = env.board # Print the result of the game. 
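# --- Editor's note: illustrative sketch, not part of the patch above ---
# plot_simulation_graph() renders the search tree with graphviz. The minimal
# rendering pattern it relies on looks like the sketch below; node labels and
# the output directory here are placeholders, and rendering to PNG requires the
# Graphviz system binaries to be installed in addition to the Python package.
import os
from graphviz import Digraph

def render_tiny_tree(graph_directory: str = './data_visualize/') -> None:
    # Build a two-node graph: a root and one child, labelled with toy statistics.
    dot = Digraph(comment='minimal MCTS visualization sketch')
    dot.node('root', label='root\nvisit_count: 50')
    dot.node('child', label='child\nvisit_count: 12')
    dot.edge('root', 'child')
    if not os.path.exists(graph_directory):
        os.makedirs(graph_directory)
    dot.format = 'png'
    # Writes tiny_tree.gv and tiny_tree.gv.png into graph_directory.
    dot.render(os.path.join(graph_directory, 'tiny_tree.gv'), view=False)
# --- end editor's note ---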
if env.get_done_reward()[0]: print(state) + + temp = memory_usage + memory_usage = get_memory_usage() + memory_cost = memory_usage - temp + print(f'Memory usage after searching: {memory_usage} bytes') + print(f'Memory increase after searching: {memory_cost} bytes') # Record the winner. winner.append(env.get_done_winner()[1]) diff --git a/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py index 0ef0bf9f7..6e4ee623a 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_gumbel_muzero_bot_mode_config.py @@ -17,7 +17,7 @@ tictactoe_gumbel_muzero_config = dict( exp_name= - f'data_mz_ctree/tictactoe_gumbel_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_muzero/tictactoe_gumbel_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( battle_mode='play_with_bot_mode', collector_env_num=collector_env_num, diff --git a/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py index da4bf2c0e..82912877e 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py @@ -17,7 +17,7 @@ tictactoe_muzero_config = dict( exp_name= - f'data_mz_ctree/tictactoe_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_muzero/tictactoe_muzero_bot-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( battle_mode='play_with_bot_mode', collector_env_num=collector_env_num, diff --git a/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py index 546fc4c0e..981cf0786 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py @@ -17,7 +17,7 @@ tictactoe_muzero_config = dict( exp_name= - f'data_mz_ctree/tictactoe_muzero_sp-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_muzero/tictactoe_muzero_sp-mode_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( battle_mode='self_play_mode', collector_env_num=collector_env_num, diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py index 929d5594f..ebc11c8b4 100644 --- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py +++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py @@ -10,7 +10,7 @@ # each_dim_disc_size = 4 # thus the total discrete action number is 4**4=256 # num_simulations = 50 # update_per_collect = None -# model_update_ratio = 0.25 +# replay_ratio = 0.25 # batch_size = 256 # max_env_step = int(5e6) # reanalyze_ratio = 0. 
@@ -20,7 +20,7 @@ bipedalwalker_cont_disc_efficientzero_config = dict( exp_name= - f'data_sez_ctree/bipedalwalker_cont_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_seed0', + f'data_sez_ctree/bipedalwalker_cont_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}-mur{replay_ratio}_rr{reanalyze_ratio}_seed0', env=dict( stop_value=int(1e6), env_id='BipedalWalker-v3', @@ -61,7 +61,7 @@ reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, eval_freq=int(2e3), - model_update_ratio=model_update_ratio, + replay_ratio=replay_ratio, replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py index f20558b3a..ecbfe1249 100644 --- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py +++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py @@ -11,7 +11,7 @@ K = 20 # num_of_sampled_actions num_simulations = 50 update_per_collect = None -model_update_ratio = 0.25 +replay_ratio = 0.25 batch_size = 256 max_env_step = int(5e6) reanalyze_ratio = 0. @@ -21,7 +21,7 @@ bipedalwalker_cont_disc_sampled_efficientzero_config = dict( exp_name= - f'data_sez_ctree/bipedalwalker_cont_disc_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_seed0', + f'data_sez_ctree/bipedalwalker_cont_disc_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{replay_ratio}_rr{reanalyze_ratio}_seed0', env=dict( stop_value=int(1e6), env_id='BipedalWalker-v3', @@ -63,7 +63,7 @@ reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, eval_freq=int(2e3), - model_update_ratio=model_update_ratio, + replay_ratio=replay_ratio, replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py index a61856c44..d691e9d1c 100644 --- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py +++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py @@ -10,7 +10,7 @@ K = 20 # num_of_sampled_actions num_simulations = 50 update_per_collect = None -model_update_ratio = 0.25 +replay_ratio = 0.25 batch_size = 256 max_env_step = int(5e6) reanalyze_ratio = 0. @@ -20,7 +20,7 @@ bipedalwalker_cont_sampled_efficientzero_config = dict( exp_name= - f'data_sez_ctree/bipedalwalker_cont_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_seed0', + f'data_sez_ctree/bipedalwalker_cont_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{replay_ratio}_rr{reanalyze_ratio}_seed0', env=dict( env_id='BipedalWalker-v3', env_type='normal', @@ -61,7 +61,7 @@ reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, eval_freq=int(2e3), - model_update_ratio=model_update_ratio, + replay_ratio=replay_ratio, replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py index 38e268c3e..ab274eb9e 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py @@ -1,5 +1,6 @@ from easydict import EasyDict - +import torch +torch.cuda.set_device(0) # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== @@ -7,17 +8,26 @@ n_episode = 8 evaluator_env_num = 3 num_simulations = 50 -update_per_collect = 200 +# update_per_collect = 200 +update_per_collect = None +replay_ratio = 0.25 + batch_size = 256 -max_env_step = int(5e6) -reanalyze_ratio = 0. +max_env_step = int(1e6) +# reanalyze_ratio = 0. +# reanalyze_ratio = 1 +reanalyze_ratio = 0.99 + + +seed = 0 + # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== lunarlander_disc_efficientzero_config = dict( - exp_name= - f'data_ez_ctree/lunarlander_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + # exp_name=f'data_ez_ctree/lunarlander_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + exp_name=f'data_ez_ctree_0129/lunarlander/ez_rr{reanalyze_ratio}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( env_id='LunarLander-v2', continuous=False, @@ -42,6 +52,7 @@ env_type='not_board_games', game_segment_length=200, update_per_collect=update_per_collect, + replay_ratio=replay_ratio, batch_size=batch_size, optim_type='Adam', lr_piecewise_constant_decay=False, diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py index 3f4241685..0d8e1585d 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py @@ -16,7 +16,7 @@ # ============================================================== lunarlander_gumbel_muzero_config = dict( - exp_name=f'data_mz_ctree/lunarlander_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_muzero/lunarlander_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( env_id='LunarLander-v2', continuous=False, diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py index 19b7bcd86..9c9c2b277 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py @@ -16,7 +16,7 @@ # ============================================================== lunarlander_muzero_config = dict( - exp_name=f'data_mz_ctree/lunarlander_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_muzero/lunarlander_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( env_id='LunarLander-v2', continuous=False, diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_multi-seed_config.py 
b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_multi-seed_config.py new file mode 100644 index 000000000..f13011418 --- /dev/null +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_multi-seed_config.py @@ -0,0 +1,91 @@ +from easydict import EasyDict +import torch +torch.cuda.set_device(5) +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 200 +batch_size = 256 +max_env_step = int(5e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +lunarlander_muzero_config = dict( + exp_name=f'data_muzero/lunarlander_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name='LunarLander-v2', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=8, + action_space_size=4, + model_type='mlp', + lstm_hidden_size=256, + latent_state_dim=256, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +lunarlander_muzero_config = EasyDict(lunarlander_muzero_config) +main_config = lunarlander_muzero_config + +lunarlander_muzero_create_config = dict( + env=dict( + type='lunarlander', + import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + collector=dict( + type='episode_muzero', + get_train_sample=True, + import_names=['lzero.worker.muzero_collector'], + ) +) +lunarlander_muzero_create_config = EasyDict(lunarlander_muzero_create_config) +create_config = lunarlander_muzero_create_config + +if __name__ == "__main__": + # Define a list of seeds for multiple runs + seeds = [0, 1] # You can add more seed values here + + for seed in seeds: + # Update exp_name to include the current seed + main_config.exp_name = f'data_muzero_0128/lunarlander_muzero_ns{main_config.policy.num_simulations}_upc{main_config.policy.update_per_collect}_rr{main_config.policy.reanalyze_ratio}_seed{seed}' + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-ez_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-ez_config.py new file mode 100644 index 000000000..1e180390d --- /dev/null +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-ez_config.py @@ -0,0 +1,111 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +replay_ratio = 0.25 +batch_size = 256 +max_env_step = int(1e6) +use_priority = False + +reuse_search = True +collect_with_pure_policy = True +buffer_reanalyze_freq = 1 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +lunarlander_muzero_config = dict( + exp_name=f'data_rezero-ez/lunarlander_rezero-ez_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0', + env=dict( + env_name='LunarLander-v2', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + # NOTE: not save ckpt + learn=dict( + learner=dict( + train_iterations=1000000000, + dataloader=dict( + num_workers=0, + ), + log_policy=True, + hook=dict( + load_ckpt_before_run='', + log_show_after_iter=1000, + save_ckpt_after_iter=1000000, + save_ckpt_after_run=True, + ), + cfg_type='BaseLearnerDict', + ), + ), + model=dict( + observation_shape=8, + action_space_size=4, + model_type='mlp', + lstm_hidden_size=256, + latent_state_dim=256, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, + 
grad_clip_value=0.5, + num_simulations=num_simulations, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + reanalyze_noise=True, + use_priority=use_priority, + # ============= The key different params for ReZero ============= + reuse_search=reuse_search, + collect_with_pure_policy=collect_with_pure_policy, + buffer_reanalyze_freq=buffer_reanalyze_freq, + ), +) +lunarlander_muzero_config = EasyDict(lunarlander_muzero_config) +main_config = lunarlander_muzero_config + +lunarlander_muzero_create_config = dict( + env=dict( + type='lunarlander', + import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), +) +lunarlander_muzero_create_config = EasyDict(lunarlander_muzero_create_config) +create_config = lunarlander_muzero_create_config + +if __name__ == "__main__": + # Define a list of seeds for multiple runs + seeds = [0] # You can add more seed values here + for seed in seeds: + # Update exp_name to include the current seed + main_config.exp_name = f'data_rezero_ez/lunarlander_rezero-ez_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed{seed}' + from lzero.entry import train_rezero + train_rezero([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-mz_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-mz_config.py new file mode 100644 index 000000000..1e99bc146 --- /dev/null +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_rezero-mz_config.py @@ -0,0 +1,94 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +replay_ratio = 0.25 +batch_size = 256 +max_env_step = int(1e6) +use_priority = False + +reuse_search = True +collect_with_pure_policy = True +buffer_reanalyze_freq = 1 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +lunarlander_muzero_config = dict( + exp_name=f'data_rezero-mz/lunarlander_rezero-mz_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0', + env=dict( + env_name='LunarLander-v2', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=8, + action_space_size=4, + model_type='mlp', + lstm_hidden_size=256, + latent_state_dim=256, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=0, # NOTE: for rezero, reanalyze_ratio should be 0. 
+ n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + reanalyze_noise=True, + use_priority=use_priority, + # ============= The key different params for ReZero ============= + reuse_search=reuse_search, + collect_with_pure_policy=collect_with_pure_policy, + buffer_reanalyze_freq=buffer_reanalyze_freq, + ), +) +lunarlander_muzero_config = EasyDict(lunarlander_muzero_config) +main_config = lunarlander_muzero_config + +lunarlander_muzero_create_config = dict( + env=dict( + type='lunarlander', + import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), +) +lunarlander_muzero_create_config = EasyDict(lunarlander_muzero_create_config) +create_config = lunarlander_muzero_create_config + +if __name__ == "__main__": + # Define a list of seeds for multiple runs + seeds = [0] # You can add more seed values here + for seed in seeds: + # Update exp_name to include the current seed + main_config.exp_name = f'data_rezero_mz/lunarlander_rezero-mz_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed{seed}' + from lzero.entry import train_rezero + train_rezero([main_config, create_config], seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/bsuite/config/bsuite_muzero_config.py b/zoo/bsuite/config/bsuite_muzero_config.py index 5c40a204b..9e5682df7 100644 --- a/zoo/bsuite/config/bsuite_muzero_config.py +++ b/zoo/bsuite/config/bsuite_muzero_config.py @@ -38,7 +38,7 @@ # ============================================================== bsuite_muzero_config = dict( - exp_name=f'data_mz_ctree/bsuite_{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + exp_name=f'data_muzero/bsuite_{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( env_id=env_id, stop_value=int(1e6), diff --git a/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py index 9f292575a..1db5629ee 100644 --- a/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py @@ -16,7 +16,7 @@ # ============================================================== cartpole_gumbel_muzero_config = dict( - exp_name=f'data_mz_ctree/cartpole_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_muzero/cartpole_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( env_id='CartPole-v0', continuous=False, diff --git a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py index 56ee32acf..1918ca408 100644 --- a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py @@ -16,7 +16,7 @@ # ============================================================== cartpole_muzero_config = dict( - exp_name=f'data_mz_ctree/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_muzero/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( env_id='CartPole-v0', continuous=False, diff --git 
a/zoo/classic_control/cartpole/config/cartpole_rezero_mz_config.py b/zoo/classic_control/cartpole/config/cartpole_rezero_mz_config.py new file mode 100644 index 000000000..7405547f3 --- /dev/null +++ b/zoo/classic_control/cartpole/config/cartpole_rezero_mz_config.py @@ -0,0 +1,87 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(1e5) +use_priority = False + +reuse_search = True +collect_with_pure_policy = False +buffer_reanalyze_freq = 1 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cartpole_muzero_config = dict( + exp_name=f'data_rezero-mz/cartpole_rezero-mz_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0', + env=dict( + env_name='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=4, + action_space_size=2, + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, + num_simulations=num_simulations, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + reanalyze_noise=True, + use_priority=use_priority, + # ============= The key different params for ReZero ============= + reuse_search=reuse_search, + collect_with_pure_policy=collect_with_pure_policy, + buffer_reanalyze_freq=buffer_reanalyze_freq, + ), +) + +cartpole_muzero_config = EasyDict(cartpole_muzero_config) +main_config = cartpole_muzero_config + +cartpole_muzero_create_config = dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), +) +cartpole_muzero_create_config = EasyDict(cartpole_muzero_create_config) +create_config = cartpole_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_rezero + train_rezero([main_config, create_config], seed=0, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py index f2c0749b6..bed3a192f 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py @@ -20,7 +20,7 @@ # ============================================================== pendulum_disc_gumbel_muzero_config = dict( - 
exp_name=f'data_mz_ctree/pendulum_disc_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + exp_name=f'data_muzero/pendulum_disc_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( env_id='Pendulum-v1', continuous=False, diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index 45f1c3e05..5b7204499 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -21,7 +21,7 @@ # ============================================================== atari_muzero_config = dict( - exp_name=f'data_mz_ctree/game_2048_npct-{num_of_possible_chance_tile}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_seed0', + exp_name=f'data_muzero/game_2048_npct-{num_of_possible_chance_tile}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_seed0', env=dict( stop_value=int(1e6), env_id=env_id, diff --git a/zoo/memory/config/memory_muzero_config.py b/zoo/memory/config/memory_muzero_config.py index 6b204d4f7..0477111f0 100644 --- a/zoo/memory/config/memory_muzero_config.py +++ b/zoo/memory/config/memory_muzero_config.py @@ -34,7 +34,7 @@ # ============================================================== memory_muzero_config = dict( - exp_name=f'data_mz_ctree/{env_id}_memlen-{memory_length}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_' + exp_name=f'data_muzero/{env_id}_memlen-{memory_length}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_' f'collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}' f'_pelw{policy_entropy_loss_weight}_seed{seed}', env=dict( diff --git a/zoo/minigrid/config/minigrid_muzero_config.py b/zoo/minigrid/config/minigrid_muzero_config.py index 304d0860c..b66b311ee 100644 --- a/zoo/minigrid/config/minigrid_muzero_config.py +++ b/zoo/minigrid/config/minigrid_muzero_config.py @@ -25,7 +25,7 @@ # ============================================================== minigrid_muzero_config = dict( - exp_name=f'data_mz_ctree/{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_' + exp_name=f'data_muzero/{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_' f'collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}_pelw{policy_entropy_loss_weight}_seed{seed}', env=dict( stop_value=int(1e6),