
Commit

feature(pu): add Sampled MuZero/UniZero, DMC env and related configs (#260)

* feature(pu): add dmc2gym and related configs

* fix(pu): fix sampled_efficientzero_model for dmc2gym

* sync code

* polish(pu): polish sez config for dmc2gym and lunarlander

* feature(pu): add Sampled MuZero

* feature(pu): add lunarlander sampled muzero config

* feature(pu): add sampled unizero and its pendulum config

* feature(pu): add sampled unizero's lunarlander and bipedalwalker config

* sync code

* polish(pu): polish sampled muzero

* fix(pu): fix min_max_stats bug in ctree_sampled_muzero

* fix(pu): fix min_max_stats bug in ctree_sampled_muzero

* polish(pu): polish sampled related configs

* feature(pu): add dmc2gym sampled related configs

* fix(pu): fix dmc2gym suz config

* fix(pu): use LN in sampled unizero

* polish(pu): use sim_norm in act_embedding in continuous action space

* fix(pu): fix expand bug in policy_loss of sampled unizero

* polish(pu): polish sampled unizero lunarlander configs

* fix(pu): fix sampled unizero action .long() bug in continuous action space

* polish(pu): polish dmc state sampled-unizero configs

* fix(pu): fix label_policy in sampled unizero

* polish(pu): polish sampled related env/configs/policy/model/buffer

* polish(pu): update readme

---------

Co-authored-by: dyyoungg <yangdeyu@sensetime.com>
Co-authored-by: jiayilee65 <jiayilee65@163.com>
3 people authored Aug 18, 2024
1 parent 0064381 commit 8300a52
Showing 58 changed files with 8,072 additions and 308 deletions.
41 changes: 21 additions & 20 deletions README.md
@@ -28,7 +28,7 @@
[![GitHub license](https://img.shields.io/github/license/opendilab/LightZero)](https://github.com/opendilab/LightZero/blob/master/LICENSE)
[![discord badge](https://dcbadge.vercel.app/api/server/dkZS2JF56X?style=flat)](https://discord.gg/dkZS2JF56X)

Updated on 2024.07.12 LightZero-v0.1.0
Updated on 2024.08.18 LightZero-v0.1.0

English | [简体中文(Simplified Chinese)](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Documentation](https://opendilab.github.io/LightZero) | [LightZero Paper](https://arxiv.org/abs/2310.08348) | [🔥UniZero Paper](https://arxiv.org/abs/2406.10667) | [🔥ReZero Paper](https://arxiv.org/abs/2404.16364)

@@ -127,25 +127,26 @@ LightZero is a library with a [PyTorch](https://pytorch.org/) implementation of
The environments and algorithms currently supported by LightZero are shown in the table below:


| Env./Algo.             | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | ReZero |
|------------------------|-----------|--------|---------------|-----------------------|---------------|-------------------|---------|--------|
| TicTacToe              | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | ✔ | 🔒 |
| Gomoku                 | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | ✔ | ✔ |
| Connect4               | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | ✔ | ✔ |
| 2048                   | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ | ✔ | 🔒 |
| Chess                  | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| Go                     | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| CartPole               | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ |
| Pendulum               | --- | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 |
| LunarLander            | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 |
| BipedalWalker          | --- | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 | 🔒 |
| Atari                  | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ |
| MuJoCo                 | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
| MiniGrid               | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
| Bsuite                 | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
| Memory                 | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
| SumToThree (billiards) | --- | 🔒 | 🔒 | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
| MetaDrive              | --- | 🔒 | 🔒 | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
| Env./Algo.             | AlphaZero | MuZero | Sampled MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | Sampled UniZero | ReZero |
|------------------------|-----------|--------|----------------|---------------|-----------------------|---------------|-------------------|---------|-----------------|--------|
| TicTacToe              | ✔ | ✔ | 🔒 | 🔒 | 🔒 | ✔ | 🔒 | ✔ | 🔒 | 🔒 |
| Gomoku                 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | ✔ | 🔒 | ✔ | 🔒 | ✔ |
| Connect4               | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | ✔ | 🔒 | ✔ |
| 2048                   | --- | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | ✔ | ✔ | 🔒 | 🔒 |
| Chess                  | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| Go                     | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| CartPole               | --- | ✔ | 🔒 | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | ✔ |
| Pendulum               | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | ✔ | 🔒 |
| LunarLander            | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 |
| BipedalWalker          | --- | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
| Atari                  | --- | ✔ | 🔒 | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | ✔ |
| DeepMind Control       | --- | --- | ✔ | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ | 🔒 |
| MuJoCo                 | --- | ✔ | 🔒 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| MiniGrid               | --- | ✔ | 🔒 | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |
| Bsuite                 | --- | ✔ | 🔒 | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |
| Memory                 | --- | ✔ | 🔒 | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |
| SumToThree (billiards) | --- | 🔒 | 🔒 | 🔒 | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| MetaDrive              | --- | 🔒 | 🔒 | 🔒 | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |


<sup>(1): "✔" means that the corresponding item is finished and well-tested.</sup>
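
The new DeepMind Control row corresponds to the dmc2gym-based environment added in this PR. For rough context only, the upstream `dmc2gym` package (not the LightZero env wrapper added here, whose config keys are not shown in this diff) exposes DMC tasks through a Gym-style interface along these lines:

```python
# Illustration only: direct use of the upstream dmc2gym package, assumed to be
# installed; the DMC env and configs added in this commit wrap something similar.
import dmc2gym

env = dmc2gym.make(
    domain_name='cartpole',
    task_name='swingup',
    seed=0,
    from_pixels=False,  # state observations; set True for pixel observations
)
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
env.close()
```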
39 changes: 20 additions & 19 deletions README.zh.md
@@ -27,7 +27,7 @@
[![Contributors](https://img.shields.io/github/contributors/opendilab/LightZero)](https://github.com/opendilab/LightZero/graphs/contributors)
[![GitHub license](https://img.shields.io/github/license/opendilab/LightZero)](https://github.com/opendilab/LightZero/blob/master/LICENSE)

Last updated on 2024.07.12 LightZero-v0.1.0
Last updated on 2024.08.18 LightZero-v0.1.0

[English](https://github.com/opendilab/LightZero/blob/main/README.md) | Simplified Chinese | [Documentation](https://opendilab.github.io/LightZero) | [LightZero Paper](https://arxiv.org/abs/2310.08348) | [🔥UniZero Paper](https://arxiv.org/abs/2406.10667) | [🔥ReZero Paper](https://arxiv.org/abs/2404.16364)

@@ -112,24 +112,25 @@ LightZero is an MCTS algorithm library implemented in [PyTorch](https://pytorch.org/),

The environments and algorithms currently supported by LightZero are shown in the table below:

| Env./Algo.             | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | ReZero |
|------------------------|-----------|--------|---------------|-----------------------|---------------|-------------------|---------|--------|
| TicTacToe              | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | ✔ | 🔒 |
| Gomoku                 | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | ✔ | ✔ |
| Connect4               | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | ✔ | ✔ |
| 2048                   | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ | ✔ | 🔒 |
| Chess                  | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| Go                     | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| CartPole               | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ |
| Pendulum               | --- | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 |
| LunarLander            | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 |
| BipedalWalker          | --- | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 | 🔒 |
| Atari                  | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ |
| MuJoCo                 | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
| MiniGrid               | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
| Bsuite                 | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
| Memory                 | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
| SumToThree (billiards) | --- | 🔒 | 🔒 | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
| Env./Algo.             | AlphaZero | MuZero | Sampled MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | Sampled UniZero | ReZero |
|------------------------|-----------|--------|----------------|---------------|-----------------------|---------------|-------------------|---------|-----------------|--------|
| TicTacToe              | ✔ | ✔ | 🔒 | 🔒 | 🔒 | ✔ | 🔒 | ✔ | 🔒 | 🔒 |
| Gomoku                 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | ✔ | 🔒 | ✔ | 🔒 | ✔ |
| Connect4               | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | ✔ | 🔒 | ✔ |
| 2048                   | --- | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | ✔ | ✔ | 🔒 | 🔒 |
| Chess                  | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| Go                     | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| CartPole               | --- | ✔ | 🔒 | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | ✔ |
| Pendulum               | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | ✔ | 🔒 |
| LunarLander            | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 |
| BipedalWalker          | --- | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
| Atari                  | --- | ✔ | 🔒 | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | ✔ |
| DeepMind Control       | --- | --- | ✔ | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ | 🔒 |
| MuJoCo                 | --- | ✔ | 🔒 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
| MiniGrid               | --- | ✔ | 🔒 | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |
| Bsuite                 | --- | ✔ | 🔒 | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |
| Memory                 | --- | ✔ | 🔒 | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |
| SumToThree (billiards) | --- | 🔒 | 🔒 | 🔒 | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |

<sup>(1): "✔" means that the corresponding item is finished and well-tested.</sup>

4 changes: 3 additions & 1 deletion lzero/entry/train_muzero.py
@@ -47,7 +47,7 @@ def train_muzero(
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'"

if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']:
Expand All @@ -56,6 +56,8 @@ def train_muzero(
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_efficientzero':
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_muzero':
from lzero.mcts import SampledMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gumbel_muzero':
from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'stochastic_muzero':
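
With 'sampled_muzero' accepted by the assertion and mapped to `SampledMuZeroGameBuffer`, a Sampled MuZero experiment is launched through the same entry as the other variants. A minimal, hypothetical launch sketch (the config import path below is illustrative, not a file name taken from this PR):

```python
# Hypothetical launch script; the config import path is an assumption — point it
# at one of the sampled-muzero configs added under zoo/ in this PR.
from lzero.entry import train_muzero
from zoo.box2d.lunarlander.config.lunarlander_cont_sampled_muzero_config import main_config, create_config

if __name__ == "__main__":
    # create_config.policy.type == 'sampled_muzero' makes train_muzero select
    # SampledMuZeroGameBuffer, as shown in the hunk above.
    train_muzero([main_config, create_config], seed=0, max_env_step=int(1e6))
```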
6 changes: 3 additions & 3 deletions lzero/entry/train_unizero.py
@@ -52,10 +52,10 @@ def train_unizero(
cfg, create_cfg = input_cfg

# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero'], "train_unizero entry now only supports the following algo.: 'unizero'"
assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"

# Import the correct GameBuffer class based on the policy type
game_buffer_classes = {'unizero': 'UniZeroGameBuffer'}
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}

GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])
@@ -107,7 +107,7 @@ def train_unizero(
batch_size = policy._cfg.batch_size

# TODO: for visualize
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
# stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)

while True:
# Log buffer memory usage
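
For readers unfamiliar with the `getattr(__import__(...))` idiom used above: it resolves the buffer class by name from `lzero.mcts` at runtime, so supporting 'sampled_unizero' only requires the new dictionary entry. A small equivalent sketch using `importlib` (illustrative, not code from the repository):

```python
import importlib

# Mirrors the mapping in train_unizero: policy type -> GameBuffer class name.
game_buffer_classes = {
    'unizero': 'UniZeroGameBuffer',
    'sampled_unizero': 'SampledUniZeroGameBuffer',
}

def resolve_game_buffer(policy_type: str) -> type:
    """Return the GameBuffer class exported by lzero.mcts for the given policy type."""
    class_name = game_buffer_classes[policy_type]
    return getattr(importlib.import_module('lzero.mcts'), class_name)

# Usage: GameBuffer = resolve_game_buffer(create_cfg.policy.type)
```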
2 changes: 2 additions & 0 deletions lzero/mcts/buffer/__init__.py
@@ -2,6 +2,8 @@
from .game_buffer_unizero import UniZeroGameBuffer
from .game_buffer_efficientzero import EfficientZeroGameBuffer
from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
from .game_buffer_sampled_muzero import SampledMuZeroGameBuffer
from .game_buffer_sampled_unizero import SampledUniZeroGameBuffer
from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer
from .game_buffer_rezero_mz import ReZeroMZGameBuffer
11 changes: 7 additions & 4 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -504,15 +504,18 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
target_values.append(value_list[value_index])
target_rewards.append(reward_list[current_index])
else:
target_values.append(0)
target_rewards.append(0.0)
target_values.append(np.array([0.]))
target_rewards.append(np.array([0.]))
value_index += 1

batch_rewards.append(target_rewards)
batch_target_values.append(target_values)

batch_rewards = np.asarray(batch_rewards, dtype=object)
batch_target_values = np.asarray(batch_target_values, dtype=object)
batch_rewards = np.asarray(batch_rewards)
batch_target_values = np.asarray(batch_target_values)
batch_rewards = np.squeeze(batch_rewards, axis=-1)
batch_target_values = np.squeeze(batch_target_values, axis=-1)

return batch_rewards, batch_target_values

# @profile
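
The change above replaces the scalar placeholders (`0`, `0.0`) with `np.array([0.])` and drops `dtype=object`. A standalone illustration of why, with made-up numbers: when every appended target is a length-1 array, `np.asarray` can stack the batch into a regular float array that `squeeze(-1)` then flattens, whereas mixing scalars with arrays prevents a regular numeric layout.

```python
import numpy as np

# Made-up targets: 2 samples, each with num_unroll_steps + 1 = 3 entries of shape (1,).
target_values = [
    [np.array([1.5]), np.array([0.7]), np.array([0.])],
    [np.array([2.0]), np.array([0.]),  np.array([0.])],
]
batch = np.asarray(target_values)   # shape (2, 3, 1), dtype float64
batch = np.squeeze(batch, axis=-1)  # shape (2, 3), ready for loss computation
print(batch.shape, batch.dtype)     # (2, 3) float64

# With the old mixed form, e.g. [np.array([1.5]), 0, 0.0], np.asarray cannot build a
# regular numeric array (recent NumPy raises unless dtype=object is given), which is
# why the previous code had to pass dtype=object.
```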
12 changes: 7 additions & 5 deletions lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
@@ -375,7 +375,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
target_values = []
target_value_prefixs = []

value_prefix = 0.0
value_prefix = np.array([0.])
base_index = state_index
for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
bootstrap_index = current_index + td_steps_list[value_index]
@@ -393,7 +393,7 @@

# reset every lstm_horizon_len
if horizon_id % self._cfg.lstm_horizon_len == 0:
value_prefix = 0.0
value_prefix = np.array([0.])
base_index = current_index
horizon_id += 1

@@ -405,16 +405,18 @@
] # * config.discount_factor ** (current_index - base_index)
target_value_prefixs.append(value_prefix)
else:
target_values.append(0)
target_values.append(np.array([0.]))
target_value_prefixs.append(value_prefix)

value_index += 1

batch_value_prefixs.append(target_value_prefixs)
batch_target_values.append(target_values)

batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object)
batch_target_values = np.asarray(batch_target_values, dtype=object)
batch_value_prefixs = np.asarray(batch_value_prefixs)
batch_target_values = np.asarray(batch_target_values)
batch_value_prefixs = np.squeeze(batch_value_prefixs, axis=-1)
batch_target_values = np.squeeze(batch_target_values, axis=-1)

return batch_value_prefixs, batch_target_values

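
The same scalar-to-array change is applied to `value_prefix` here, and the hunk also shows the windowed bookkeeping: the value prefix accumulates rewards within an LSTM horizon and is reset every `lstm_horizon_len` steps. A toy sketch of that reset pattern (reward values and lengths are made up; the real buffer additionally uses `base_index` to index per-step reward lists):

```python
import numpy as np

lstm_horizon_len = 5
rewards = np.random.rand(12)  # made-up per-step rewards

value_prefix = np.array([0.])
base_index = 0
horizon_id = 0
for current_index, reward in enumerate(rewards):
    # reset every lstm_horizon_len steps, as in the hunk above
    if horizon_id % lstm_horizon_len == 0:
        value_prefix = np.array([0.])
        base_index = current_index
    horizon_id += 1
    value_prefix = value_prefix + reward  # accumulate reward within the current window
```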

