feature(xcy): add ReZero algo. and related configs (#238)
* feature(xcy): print tree and reuse the node

* polish(xcy): change test file

* feature(xcy): add node graph and reuse to test

* feature(xcy): add my algorithm

* feature(xcy): add my algorithm and the config file

* feature(xcy): add reuse and search_and_save

* feature(xcy): add breakout configs

* feature(xcy): add pong config

* feature(xcy): add buffer time log

* feature(xcy): add big batch code

* feature(xcy): add big batch and speed test

* polish(xcy): polish speed test

* feature(xcy): add speed and memory log

* feature(xcy): change to final framework

* feature(xcy): add configs and reanalyze freq

* Added Untitled Diagram.drawio

* feature(xcy): add MCTS collect and change cnode

* feature(pu): add rezero configs for lunarlander/connect4/gomoku

* feature(xcy): add qbert and upndown

* feature(xcy): add reez

* polish(xcy): little typos

* polish(xcy): requirements

* polish(pu): polish configs

* polish(xcy): test ez ratio1

* feature(xcy): return to gym atari

* polish(xcy): change configs

* polish(xcy): change config

* polish(xcy): change config

* polish(pu): polish configs

* feature(xcy): fix ma and add nrma

* polish(xcy): add config

* polish(xcy): polish data process for reuse search

* feature(xcy): refactor the code

* feature(xcy): back to gym

* add some config

* feature(xcy): refactor policy collect and reuse

* polish(pu): polish game_buffer_rezero and configs

* polish(pu): polish search_with_reuse and policy

* polish(pu): polish muzero_collector

* polish(pu): polish rezero configs

* fix(pu): fix muzero_collector.py

* polish(pu): polish train_rezero.py

* polish(pu): polish buffer_reanalyze_freq

* polish(pu): polish padding action comments

* polish(pu): rename model_update_ratio to replay_ratio

* polish(pu): polish comments

* polish(pu): polish rezero configs

* polish(pu): add rezero in readme algo/env tables

---------

Co-authored-by: HarryXuancy <lhsxxcy@126.com>
Co-authored-by: HarryXuancy <52876902+HarryXuancy@users.noreply.github.com>
Co-authored-by: jiayilee65 <jiayilee65@163.com>
4 people authored Jun 28, 2024
1 parent 61e8960 commit 87de8f9
Showing 82 changed files with 3,363 additions and 313 deletions.
36 changes: 19 additions & 17 deletions README.md
@@ -122,23 +122,25 @@ LightZero is a library with a [PyTorch](https://pytorch.org/) implementation of

The environments and algorithms currently supported by LightZero are shown in the table below:

| Env./Algo.    | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero |
|---------------|-----------|--------|---------------|-----------------------|---------------|-------------------|---------|
| TicTacToe     | ✔         | ✔      | 🔒            | 🔒                    | ✔             | 🔒                | ✔       |
| Gomoku        | ✔         | ✔      | 🔒            | 🔒                    | ✔             | 🔒                | ✔       |
| Connect4      | ✔         | ✔      | 🔒            | 🔒                    | 🔒            | 🔒                | ✔       |
| 2048          | ---       | ✔      | 🔒            | 🔒                    | 🔒            | ✔                 | ✔       |
| Chess         | 🔒        | 🔒     | 🔒            | 🔒                    | 🔒            | 🔒                | 🔒      |
| Go            | 🔒        | 🔒     | 🔒            | 🔒                    | 🔒            | 🔒                | 🔒      |
| CartPole      | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       |
| Pendulum      | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | 🔒      |
| LunarLander   | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       |
| BipedalWalker | ---       | ✔      | ✔             | ✔                     | ✔             | 🔒                | 🔒      |
| Atari         | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       |
| MuJoCo        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | 🔒      |
| MiniGrid      | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       |
| Bsuite        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       |
| Memory        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       |

| Env./Algo.    | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | ReZero |
|---------------|-----------|--------|---------------|-----------------------|---------------|-------------------|---------|--------|
| TicTacToe     | ✔         | ✔      | 🔒            | 🔒                    | ✔             | 🔒                | ✔       | 🔒     |
| Gomoku        | ✔         | ✔      | 🔒            | 🔒                    | ✔             | 🔒                | ✔       | ✔      |
| Connect4      | ✔         | ✔      | 🔒            | 🔒                    | 🔒            | 🔒                | ✔       | ✔      |
| 2048          | ---       | ✔      | 🔒            | 🔒                    | 🔒            | ✔                 | ✔       | 🔒     |
| Chess         | 🔒        | 🔒     | 🔒            | 🔒                    | 🔒            | 🔒                | 🔒      | 🔒     |
| Go            | 🔒        | 🔒     | 🔒            | 🔒                    | 🔒            | 🔒                | 🔒      | 🔒     |
| CartPole      | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       | ✔      |
| Pendulum      | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | 🔒      | 🔒     |
| LunarLander   | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       | 🔒     |
| BipedalWalker | ---       | ✔      | ✔             | ✔                     | ✔             | 🔒                | 🔒      | 🔒     |
| Atari         | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       | ✔      |
| MuJoCo        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | 🔒      | 🔒     |
| MiniGrid      | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       | 🔒     |
| Bsuite        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       | 🔒     |
| Memory        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       | 🔒     |


<sup>(1): "✔" means that the corresponding item is finished and well-tested.</sup>

34 changes: 17 additions & 17 deletions README.zh.md
@@ -110,23 +110,23 @@ LightZero is an MCTS algorithm library implemented with [PyTorch](https://pytorch.org/),

The environments and algorithms currently supported by LightZero are shown in the table below:

| Env./Algo.    | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero |
|---------------|-----------|--------|---------------|-----------------------|---------------|-------------------|---------|
| TicTacToe     | ✔         | ✔      | 🔒            | 🔒                    | ✔             | 🔒                | ✔       |
| Gomoku        | ✔         | ✔      | 🔒            | 🔒                    | ✔             | 🔒                | ✔       |
| Connect4      | ✔         | ✔      | 🔒            | 🔒                    | 🔒            | 🔒                | ✔       |
| 2048          | ---       | ✔      | 🔒            | 🔒                    | 🔒            | ✔                 | ✔       |
| Chess         | 🔒        | 🔒     | 🔒            | 🔒                    | 🔒            | 🔒                | 🔒      |
| Go            | 🔒        | 🔒     | 🔒            | 🔒                    | 🔒            | 🔒                | 🔒      |
| CartPole      | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       |
| Pendulum      | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | 🔒      |
| LunarLander   | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       |
| BipedalWalker | ---       | ✔      | ✔             | ✔                     | ✔             | 🔒                | 🔒      |
| Atari         | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       |
| MuJoCo        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | 🔒      |
| MiniGrid      | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       |
| Bsuite        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       |
| Memory        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       |
| Env./Algo.    | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | ReZero |
|---------------|-----------|--------|---------------|-----------------------|---------------|-------------------|---------|--------|
| TicTacToe     | ✔         | ✔      | 🔒            | 🔒                    | ✔             | 🔒                | ✔       | 🔒     |
| Gomoku        | ✔         | ✔      | 🔒            | 🔒                    | ✔             | 🔒                | ✔       | ✔      |
| Connect4      | ✔         | ✔      | 🔒            | 🔒                    | 🔒            | 🔒                | ✔       | ✔      |
| 2048          | ---       | ✔      | 🔒            | 🔒                    | 🔒            | ✔                 | ✔       | 🔒     |
| Chess         | 🔒        | 🔒     | 🔒            | 🔒                    | 🔒            | 🔒                | 🔒      | 🔒     |
| Go            | 🔒        | 🔒     | 🔒            | 🔒                    | 🔒            | 🔒                | 🔒      | 🔒     |
| CartPole      | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       | ✔      |
| Pendulum      | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | 🔒      | 🔒     |
| LunarLander   | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       | 🔒     |
| BipedalWalker | ---       | ✔      | ✔             | ✔                     | ✔             | 🔒                | 🔒      | 🔒     |
| Atari         | ---       | ✔      | ✔             | ✔                     | ✔             | ✔                 | ✔       | ✔      |
| MuJoCo        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | 🔒      | 🔒     |
| MiniGrid      | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       | 🔒     |
| Bsuite        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       | 🔒     |
| Memory        | ---       | ✔      | ✔             | ✔                     | 🔒            | 🔒                | ✔       | 🔒     |

<sup>(1): "✔" means that the corresponding item is finished and well-tested.</sup>

4 changes: 2 additions & 2 deletions lzero/agent/alphazero.py
@@ -198,9 +198,9 @@ def train(
new_data = sum(new_data, [])

if self.cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = len(new_data)
update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)

# Learn policy from collected data
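
Editor's note: the hunk above (and the analogous hunks in the other agent and entry files below) only renames `model_update_ratio` to `replay_ratio`; the fallback logic is unchanged. A minimal, self-contained sketch of that fallback, with purely illustrative numbers, is:

```python
# Minimal sketch of the update_per_collect fallback shown in the diff above.
# The numbers are illustrative, not taken from any LightZero config.
from typing import Optional


def resolve_update_per_collect(
    update_per_collect: Optional[int],
    collected_transitions_num: int,
    replay_ratio: float,
) -> int:
    """Return the number of gradient steps to run after one collection phase."""
    if update_per_collect is None:
        # e.g. 400 collected transitions * replay_ratio 0.25 -> 100 updates
        return int(collected_transitions_num * replay_ratio)
    return update_per_collect


# Example: 400 new transitions with a replay ratio of 0.25 yields 100 updates.
assert resolve_update_per_collect(None, 400, 0.25) == 100
# An explicit update_per_collect overrides the ratio.
assert resolve_update_per_collect(50, 400, 0.25) == 50
```
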
4 changes: 2 additions & 2 deletions lzero/agent/efficientzero.py
@@ -228,9 +228,9 @@ def train(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if self.cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
4 changes: 2 additions & 2 deletions lzero/agent/gumbel_muzero.py
@@ -228,9 +228,9 @@ def train(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if self.cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
4 changes: 2 additions & 2 deletions lzero/agent/muzero.py
@@ -228,9 +228,9 @@ def train(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if self.cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
4 changes: 2 additions & 2 deletions lzero/agent/sampled_alphazero.py
@@ -198,9 +198,9 @@ def train(
new_data = sum(new_data, [])

if self.cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = len(new_data)
update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)

# Learn policy from collected data
4 changes: 2 additions & 2 deletions lzero/agent/sampled_efficientzero.py
@@ -228,9 +228,9 @@ def train(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if self.cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
3 changes: 2 additions & 1 deletion lzero/entry/__init__.py
@@ -4,4 +4,5 @@
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .eval_muzero import eval_muzero
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_rezero import train_rezero
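
Editor's note: a hedged usage sketch for the newly exported `train_rezero` entry point. The config module path, the `buffer_reanalyze_freq` placement and value, and the `max_env_step` keyword are assumptions modeled on the existing `train_muzero` convention, not verified against the files added in this commit.

```python
# Hedged usage sketch for the new entry point. The config module name and the
# keyword arguments besides the (main_config, create_config) pair are assumptions
# following the train_muzero convention.
from lzero.entry import train_rezero

if __name__ == "__main__":
    # Hypothetical config module; the ReZero configs added in this PR live under
    # zoo/, but their exact file names are not reproduced here.
    from zoo.atari.config.pong_rezero_config import main_config, create_config  # hypothetical path

    # Assumed knob for ReZero's periodic buffer reanalyze; the field name comes from
    # the commit messages ("polish buffer_reanalyze_freq"), the value is illustrative.
    main_config.policy.buffer_reanalyze_freq = 1

    train_rezero([main_config, create_config], seed=0, max_env_step=int(1e6))
```
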
4 changes: 2 additions & 2 deletions lzero/entry/train_alphazero.py
@@ -119,9 +119,9 @@ def train_alphazero(
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
new_data = sum(new_data, [])
if cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = len(new_data)
update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)

# Learn policy from collected data
10 changes: 5 additions & 5 deletions lzero/entry/train_muzero.py
@@ -8,12 +8,12 @@
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed, get_rank
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed, get_rank
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
@@ -69,7 +69,6 @@ def train_muzero(
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)

collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

@@ -138,6 +137,7 @@ def train_muzero(

while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# set temperature for visit count distributions according to the train_iter,
# please refer to Appendix D in MuZero paper for details.
@@ -172,9 +172,9 @@ def train_muzero(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
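
Editor's note: `log_buffer_run_time` is imported alongside `log_buffer_memory_usage` and called once per training-loop iteration (see the earlier hunk in this file). Its implementation is not part of the shown hunks; the sketch below is only a guess at the general shape such a TensorBoard timing logger could take, and the buffer attribute names it reads are hypothetical.

```python
# Illustrative sketch of a buffer run-time logger; the attribute names read
# from the buffer are hypothetical and may differ from the actual ReZero buffer.
from tensorboardX import SummaryWriter


def log_buffer_run_time_sketch(train_iter: int, buffer, writer: SummaryWriter) -> None:
    """Write buffer timing statistics to TensorBoard, if the buffer tracks them."""
    # Guard with getattr so the sketch degrades gracefully when a field is absent.
    for tag, attr in [
        ('buffer/sample_time', 'sample_time'),
        ('buffer/reanalyze_time', 'reanalyze_time'),
        ('buffer/origin_search_time', 'origin_search_time'),
    ]:
        value = getattr(buffer, attr, None)
        if value is not None:
            writer.add_scalar(tag, value, train_iter)
```
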
4 changes: 2 additions & 2 deletions lzero/entry/train_muzero_with_gym_env.py
@@ -136,9 +136,9 @@ def train_muzero_with_gym_env(
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if cfg.policy.update_per_collect is None:
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
# update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.