polish(pu): polish efficiency and performance on atari and DMC, add muzero_segment_collector (#292)

* polish(pu): polish quantize_state_hash and deepcopy

* fix(pu): fix np.array dtype bug in buffer

* polish(pu): use 0 deepcopy in kv_cache operation in collect/eval phase of unizero

* polish(pu): use custom deepcopy for kv_cache

* polish(pu): use value_array rather than value_list in compute_target_value

* polish(pu): optimize compute_target_policy_non_re

* polish(pu): optimize kv_caching update()

* polish(pu): kv_cache_dict no to_cpu

* polish(pu): optimize custom kv_cache copy

* polish(pu): kv_cache_dict no to_cpu

* feature(pu): add unizero ddp config

* fix(pu): fix unizero ddp

* sync code

* polish(pu): use kv_cache deepcopy only in recur_infer load

* sync code

* polish(pu): polish suz dmc config

* sync code

* polish(pu): use share_pool for kv_cache in recurrent_inference and use _copy rather than clone

* polish(pu): all kv_cache copy use predefined share_pool (a rough sketch of this idea follows this list)

* polish(pu): disable decoder_net and lpips in ddp config

* sync code

* feature(pu): add dmc save_replay_gif option

* sync code

* polish(pu): polish sampled muzero ctree

* test(pu): add sac cheetah config

* fix(pu): fix render_image in dmc_env

* fix(pu): fix reanalyze in sampled unizero

* polish(pu): polish policy projector

* feature(pu): add muzero_segment_collector.py

* polish(pu): use uniform prior in ucb_score of suz mcts

* fix(pu): fix self.action_mask_dict init bug

* test(pu): change clamp 0.9 -> 1

* polish(pu): polish suz

* fix(pu): fix muzero_segment_collector

* fix(pu): uz target-value obs also use aug when use_aug=True

* sync code

* fix(pu): fix last_game_segment bug in muzero_segment_collector.py

* fix(pu): return once one episode is done in muzero_segment_collector.py

* fix(pu): fix muzero_collector

* polish(pu): polish unizero config and polish sample from segments

* fix(pu): fix reanalyze in uz

* polish(pu): add batch config and bash

* polish(pu): polish uz configs

* feature(pu): add unizero buffer_reanalyze variant

* fix(pu): fix uz reanalyze_buffer

* polish(pu): polish configs

* feature(pu): add atari_muzero_segment_config

* fix(pu): fix sampled_unizero reanalyze_policy

* polish(pu): polish configs

* polish(pu): polish suz configs

* polish(pu): polish configs

* fix(pu): fix root value in suz buffer

* fix(pu): fix suz ctree

* polish(pu): polish uz related configs, segment collector, train_entry

* polish(pu): polish unizero world_model

* polish(pu): polish reanalyze in buffer

* fix(pu): fix entry import and nparray object bug in buffer

* polish(pu): polish configs

* polish(pu): polish configs

* polish(pu): fix collector, polish configs

* fix(pu): fix truncation segment sample in buffer

* fix(pu): fix segment sample for uz in buffer

* fix(pu): use origin buffer

* fix(pu): fix value bug V8

* sync code

* fix(pu): fix target action when calculating bootstrap value in unizero

* fix(pu): fix target-action in sampled_unizero buffer

* polish(pu): delete wrongly added files

* polish(pu): polish entry/buffer/ctree, and fix index+1 bug in compute_target_reward_value

* polish(pu): polish buffer and config

* polish(pu): rename train_xxx_reanalyze to train_xxx_segment

* polish(pu): polish world_model

* polish(pu): polish entry comments

* fix(pu): fix reward shape bug in dmc

* fix(pu): polish sample_orig_reanalyze_batch and fix child_visits bug in sample_orig_data

* polish(pu): polish comments in _sample_orig_reanalyze_batch

* fix(pu): add pad_action_lst in muzero_collector

* polish(pu): polish dmc suz configs

* fix(pu): fix reanalyzed_root_sampled_action in suz buffer

* fix(pu): fix logp calculation in mcts expand, use clamp_limit for sampled actions, use 1e5 as total train steps in cos_lr_decay, polish cont policy loss

* polish(pu): sample init position from the whole segment

* fix(pu): fix empirical_distribution_type comparison bug and sample half of the sampled actions from a flattened Gaussian

* polish(pu): polish config and reward shape

* polish(pu): polish memory config

* polish(pu): polish config

* polish(pu): polish config and comments

* polish(pu): polish comments

* polish(pu): polish comments and config

* polish(pu): delete unused config

* polish(pu): polish comments and docstring in unizero
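Several of the kv_cache entries above describe replacing deepcopy-based cache handling with in-place copies into a predefined shared pool. A rough, hypothetical sketch of that general idea (the class and method names below are illustrative, not the repository's actual API):

    import torch

    # Hypothetical shared pool: key/value buffers are preallocated once and reused,
    # so storing a cache is an in-place copy_ instead of a Python deepcopy.
    class SharedKVPool:
        def __init__(self, pool_size: int, cache_shape: tuple, device: str = 'cpu'):
            self.keys = torch.zeros(pool_size, *cache_shape, device=device)
            self.values = torch.zeros(pool_size, *cache_shape, device=device)
            self.free = list(range(pool_size))

        def store(self, key: torch.Tensor, value: torch.Tensor) -> int:
            idx = self.free.pop()
            self.keys[idx].copy_(key)      # in-place copy, no new allocation
            self.values[idx].copy_(value)
            return idx

        def load(self, idx: int):
            # Return views into the pool; callers copy them back into the model's cache.
            return self.keys[idx], self.values[idx]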

---------

Co-authored-by: PaParaZz1 <niuyazhe314@outlook.com>
Co-authored-by: dyyoungg <yangdeyu@sensetime.com>
Co-authored-by: jiayilee65 <jiayilee65@163.com>
4 people authored Nov 14, 2024
1 parent d27f29a commit dd7a5eb
Showing 185 changed files with 4,001 additions and 1,268 deletions.
2 changes: 1 addition & 1 deletion docs/source/tutorials/config/config.md
@@ -32,7 +32,7 @@ The `main_config` dictionary contains the main parameter settings for running th
 - `update_per_collect`: The number of updates after each data collection.
 - `batch_size`: The batch size sampled during the update.
 - `optim_type`: Optimizer type.
-- `lr_piecewise_constant_decay`: Whether to use piecewise constant learning rate decay.
+- `piecewise_decay_lr_scheduler`: Whether to use piecewise constant learning rate decay.
 - `learning_rate`: Initial learning rate.
 - `num_simulations`: The number of simulations used in the MCTS algorithm.
 - `reanalyze_ratio`: Reanalysis coefficient, controlling the probability of reanalysis.
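For orientation, a minimal sketch of how the renamed flag sits in a LightZero-style policy config after this change; the surrounding values are illustrative only, not settings prescribed by this commit:

    from easydict import EasyDict

    # Illustrative fragment; only `piecewise_decay_lr_scheduler`
    # (formerly `lr_piecewise_constant_decay`) is the point here.
    policy_cfg = EasyDict(dict(
        update_per_collect=100,
        batch_size=256,
        optim_type='Adam',
        piecewise_decay_lr_scheduler=False,
        learning_rate=0.003,
        num_simulations=25,
        reanalyze_ratio=0.,
    ))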
2 changes: 1 addition & 1 deletion docs/source/tutorials/config/config_zh.md
@@ -32,7 +32,7 @@
 - `update_per_collect`: 每次数据收集后更新的次数。
 - `batch_size`: 更新时采样的批量大小。
 - `optim_type`: 优化器类型。
-- `lr_piecewise_constant_decay`: 是否使用分段常数学习率衰减。
+- `piecewise_decay_lr_scheduler`: 是否使用分段常数学习率衰减。
 - `learning_rate`: 初始学习率。
 - `num_simulations`: MCTS算法中使用的模拟次数。
 - `reanalyze_ratio`: 重分析系数,控制进行重分析的概率。
2 changes: 1 addition & 1 deletion lzero/agent/config/alphazero/gomoku_play_with_bot.py
@@ -65,7 +65,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 grad_clip_value=0.5,
 value_weight=1.0,
2 changes: 1 addition & 1 deletion lzero/agent/config/alphazero/tictactoe_play_with_bot.py
@@ -61,7 +61,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 grad_clip_value=0.5,
 value_weight=1.0,
(file path not shown)
@@ -55,7 +55,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='SGD',
-lr_piecewise_constant_decay=True,
+piecewise_decay_lr_scheduler=True,
 learning_rate=0.2,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
2 changes: 1 addition & 1 deletion lzero/agent/config/efficientzero/gym_cartpole_v0.py
@@ -43,7 +43,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
2 changes: 1 addition & 1 deletion lzero/agent/config/efficientzero/gym_lunarlander_v2.py
@@ -45,7 +45,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 grad_clip_value=0.5,
 num_simulations=num_simulations,
(file path not shown)
@@ -54,7 +54,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='SGD',
-lr_piecewise_constant_decay=True,
+piecewise_decay_lr_scheduler=True,
 learning_rate=0.2,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
2 changes: 1 addition & 1 deletion lzero/agent/config/efficientzero/gym_pendulum_v1.py
@@ -43,7 +43,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
2 changes: 1 addition & 1 deletion lzero/agent/config/efficientzero/gym_pongnoframeskip_v4.py
@@ -54,7 +54,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='SGD',
-lr_piecewise_constant_decay=True,
+piecewise_decay_lr_scheduler=True,
 learning_rate=0.2,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
2 changes: 1 addition & 1 deletion lzero/agent/config/gumbel_muzero/gomoku_play_with_bot.py
@@ -55,7 +55,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 grad_clip_value=0.5,
 num_simulations=num_simulations,
2 changes: 1 addition & 1 deletion lzero/agent/config/gumbel_muzero/gym_cartpole_v0.py
@@ -46,7 +46,7 @@
 batch_size=batch_size,
 optim_type='Adam',
 max_num_considered_actions=2,
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 ssl_loss_weight=2,  # NOTE: default is 0.
 num_simulations=num_simulations,
(file path not shown)
@@ -49,7 +49,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 grad_clip_value=0.5,
 num_simulations=num_simulations,
2 changes: 1 addition & 1 deletion lzero/agent/config/muzero/gomoku_play_with_bot.py
@@ -55,7 +55,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 grad_clip_value=0.5,
 num_simulations=num_simulations,
2 changes: 1 addition & 1 deletion lzero/agent/config/muzero/gym_breakoutnoframeskip_v4.py
@@ -58,7 +58,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='SGD',
-lr_piecewise_constant_decay=True,
+piecewise_decay_lr_scheduler=True,
 learning_rate=0.2,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
2 changes: 1 addition & 1 deletion lzero/agent/config/muzero/gym_cartpole_v0.py
@@ -45,7 +45,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 ssl_loss_weight=2,  # NOTE: default is 0.
 num_simulations=num_simulations,
2 changes: 1 addition & 1 deletion lzero/agent/config/muzero/gym_lunarlander_v2.py
@@ -46,7 +46,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 ssl_loss_weight=2,  # NOTE: default is 0.
 grad_clip_value=0.5,
2 changes: 1 addition & 1 deletion lzero/agent/config/muzero/gym_mspacmannoframeskip_v4.py
@@ -56,7 +56,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='SGD',
-lr_piecewise_constant_decay=True,
+piecewise_decay_lr_scheduler=True,
 learning_rate=0.2,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
2 changes: 1 addition & 1 deletion lzero/agent/config/muzero/gym_pendulum_v1.py
@@ -44,7 +44,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 ssl_loss_weight=2,  # NOTE: default is 0.
 num_simulations=num_simulations,
2 changes: 1 addition & 1 deletion lzero/agent/config/muzero/gym_pongnoframeskip_v4.py
@@ -56,7 +56,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='SGD',
-lr_piecewise_constant_decay=True,
+piecewise_decay_lr_scheduler=True,
 learning_rate=0.2,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
2 changes: 1 addition & 1 deletion lzero/agent/config/muzero/tictactoe_play_with_bot.py
@@ -50,7 +50,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 grad_clip_value=0.5,
 num_simulations=num_simulations,
(file path not shown)
@@ -75,7 +75,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 value_weight=1.0,
 entropy_weight=0.0,
(file path not shown)
@@ -63,7 +63,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 grad_clip_value=0.5,
 value_weight=1.0,
(file path not shown)
@@ -47,7 +47,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='SGD',
-lr_piecewise_constant_decay=True,
+piecewise_decay_lr_scheduler=True,
 learning_rate=0.2,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
(file path not shown)
@@ -47,7 +47,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
(file path not shown)
@@ -50,14 +50,14 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 grad_clip_value=0.5,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
 random_collect_episode_num=0,
 # NOTE: for continuous gaussian policy, we use the policy_entropy_loss as in the original Sampled MuZero paper.
-policy_entropy_loss_weight=5e-3,
+policy_entropy_weight=5e-3,
 n_episode=n_episode,
 eval_freq=int(1e3),
 replay_buffer_size=int(1e6),  # the size/capacity of replay_buffer, in the terms of transitions.
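The renamed `policy_entropy_weight` is the coefficient on the policy-entropy bonus referenced in the NOTE above. A schematic of how such a term typically enters a continuous-action policy loss (toy tensors, not the repository's actual loss code):

    import torch
    from torch.distributions import Independent, Normal

    policy_entropy_weight = 5e-3  # value used in the sampled configs in this commit

    # Toy Gaussian policy over a 1-dim continuous action, batch of 4.
    mu, sigma = torch.zeros(4, 1), torch.ones(4, 1)
    dist = Independent(Normal(mu, sigma), 1)

    # Placeholder policy target term (e.g. negative log-likelihood of target actions).
    target_actions = torch.randn(4, 1)
    policy_loss = -dist.log_prob(target_actions).mean()

    # Subtracting the weighted entropy rewards more stochastic (exploratory) policies.
    policy_loss = policy_loss - policy_entropy_weight * dist.entropy().mean()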
(file path not shown)
@@ -47,7 +47,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='SGD',
-lr_piecewise_constant_decay=True,
+piecewise_decay_lr_scheduler=True,
 learning_rate=0.2,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
4 changes: 2 additions & 2 deletions lzero/agent/config/sampled_efficientzero/gym_pendulum_v1.py
@@ -47,10 +47,10 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='Adam',
-lr_piecewise_constant_decay=False,
+piecewise_decay_lr_scheduler=False,
 learning_rate=0.003,
 # NOTE: for continuous gaussian policy, we use the policy_entropy_loss as in the original Sampled MuZero paper.
-policy_entropy_loss_weight=5e-3,
+policy_entropy_weight=5e-3,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
 n_episode=n_episode,
(file path not shown)
@@ -47,7 +47,7 @@
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 optim_type='SGD',
-lr_piecewise_constant_decay=True,
+piecewise_decay_lr_scheduler=True,
 learning_rate=0.2,
 num_simulations=num_simulations,
 reanalyze_ratio=reanalyze_ratio,
3 changes: 2 additions & 1 deletion lzero/entry/__init__.py
@@ -3,8 +3,9 @@
 from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
 from .train_alphazero import train_alphazero
 from .train_muzero import train_muzero
-from .train_muzero_with_gym_env import train_muzero_with_gym_env
+from .train_muzero_segment import train_muzero_segment
+from .train_muzero_with_gym_env import train_muzero_with_gym_env
 from .train_muzero_with_reward_model import train_muzero_with_reward_model
 from .train_rezero import train_rezero
 from .train_unizero import train_unizero
 from .train_unizero_segment import train_unizero_segment
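With these exports, the segment-based entry can be launched like the existing `train_muzero` entry. A hedged usage sketch, assuming `train_muzero_segment` keeps the same call pattern ([main_config, create_config] plus seed and environment-step budget) and that the config module path below matches the atari_muzero_segment_config added in this commit:

    from lzero.entry import train_muzero_segment
    # Hypothetical config import; point this at the segment-style config you actually use.
    from zoo.atari.config.atari_muzero_segment_config import main_config, create_config

    if __name__ == "__main__":
        train_muzero_segment([main_config, create_config], seed=0, max_env_step=int(1e6))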
7 changes: 5 additions & 2 deletions lzero/entry/train_muzero.py
@@ -18,7 +18,7 @@
 from lzero.policy.random_policy import LightZeroRandomPolicy
 from lzero.worker import MuZeroCollector as Collector
 from lzero.worker import MuZeroEvaluator as Evaluator
-from .utils import random_collect, initialize_zeros_batch
+from .utils import random_collect


 def train_muzero(
@@ -175,8 +175,11 @@ def train_muzero(
 new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
 if cfg.policy.update_per_collect is None:
 # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
+# The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
+# On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
+collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
 update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+
 # save returned new_data collected by the collector
 replay_buffer.push_game_segments(new_data)
 # remove the oldest data if the replay buffer is full.
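The behavioral change in the second hunk is that padded tail transitions no longer inflate `update_per_collect`. A self-contained illustration with made-up numbers (assuming game_segment_length=400 and full segments padded by unroll_steps + td_steps):

    game_segment_length = 400
    replay_ratio = 0.25
    # Stand-ins for len(game_segment): two full (padded) segments and a short final one.
    segment_lengths = [410, 410, 230]

    # Old behavior: padding counted as collected transitions.
    old_count = sum(segment_lengths)                                        # 1050
    # New behavior: each segment is clamped to game_segment_length.
    new_count = sum(min(l, game_segment_length) for l in segment_lengths)   # 1030

    update_per_collect = int(new_count * replay_ratio)
    print(old_count, new_count, update_per_collect)  # 1050 1030 257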
