1 change: 1 addition & 0 deletions lzero/entry/__init__.py
@@ -9,5 +9,6 @@
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_with_reward_model import train_unizero_with_reward_model
from .train_unizero_segment import train_unizero_segment
from .utils import *
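
With this export in place, the new entry point is importable next to the existing UniZero entries; a minimal illustration (not part of the diff):

from lzero.entry import train_unizero, train_unizero_with_reward_model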
287 changes: 287 additions & 0 deletions lzero/entry/train_unizero_with_reward_model.py
@@ -0,0 +1,287 @@
import logging
import os
from functools import partial
from typing import Tuple, Optional

import torch
import wandb
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed, get_rank, get_world_size
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector
from lzero.reward_model.rnd_reward_model import RNDRewardModel
from .utils import random_collect, calculate_update_per_collect
import torch.distributed as dist


def train_unizero_with_reward_model(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
"""
Overview:
This function serves as the training entry point for UniZero, as proposed in our paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models",
combined here with an RND-based intrinsic reward model for exploration.
UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms,
particularly in environments that require capturing long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.

Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Configuration in dictionary format.
``Tuple[dict, dict]`` indicates [user_config, create_cfg].
- seed (:obj:`int`): Random seed for reproducibility.
- model (:obj:`Optional[torch.nn.Module]`): Instance of a PyTorch model.
- model_path (:obj:`Optional[str]`): Path to the pretrained model, which should
point to the checkpoint file of the pretrained model. An absolute path is recommended.
In LightZero, the path typically resembles ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations during training.
- max_env_step (:obj:`Optional[int]`): Maximum number of environment interaction steps to collect.

Returns:
- policy (:obj:`Policy`): The converged policy after training.
"""

cfg, create_cfg = input_cfg

# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero_with_reward_model only supports the following algorithms: 'unizero', 'sampled_unizero'"
assert cfg.policy.use_rnd_model, "cfg.policy.use_rnd_model must be True to use RND reward model"
logging.info(f"Using policy type: {create_cfg.policy.type} with RND Reward Model.")

# Import the appropriate GameBuffer class based on the policy type
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])

# Check for GPU availability and set the device accordingly
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
logging.info(f"Device set to: {cfg.policy.device}")

# Compile the configuration file
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

# Create environment manager
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

# Initialize environment and random seed
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

# Initialize wandb if specified
if cfg.policy.use_wandb:
logging.info("Initializing wandb...")
wandb.init(
project="LightZero",
config=cfg,
sync_tensorboard=False,
monitor_gym=False,
save_code=True,
)
logging.info("wandb initialization completed!")

# Create policy
logging.info("Creating policy...")
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
logging.info("Policy created successfully!")

# Load pretrained model if specified
if model_path is not None:
logging.info(f"Loading pretrained model from {model_path}...")
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
logging.info("Pretrained model loaded successfully!")

# Create core components for training
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
replay_buffer = GameBuffer(cfg.policy)
collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
policy_config=cfg.policy)
evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=cfg.policy)

# ==============================================================
# New in this entry: initialize the RND reward model.
# RNDRewardModel requires the representation network from the policy model (used as the predictor)
# and the target representation network (used as the fixed target).
# For UniZero, the tokenizer plays the role of the representation network.
# ==============================================================
reward_model = RNDRewardModel(
config=cfg.reward_model,
device=policy.collect_mode.get_attribute('device'),
tb_logger=tb_logger,
exp_name=cfg.exp_name,
representation_network=policy._learn_model.representation_network,
target_representation_network=policy._target_model_for_intrinsic_reward.representation_network,
use_momentum_representation_network=cfg.policy.use_momentum_representation_network,
bp_update_sync=cfg.policy.bp_update_sync,
multi_gpu=cfg.policy.multi_gpu,
)

# Execute the learner's before_run hook
learner.call_hook('before_run')

if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

# Randomly collect data if specified
if cfg.policy.random_collect_episode_num > 0:
logging.info("Collecting random data...")
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
logging.info("Random data collection completed!")

batch_size = policy._cfg.batch_size

if cfg.policy.multi_gpu:
# Get current world size and rank
world_size = get_world_size()
rank = get_rank()
else:
world_size = 1
rank = 0

while True:
# Log memory usage of the replay buffer
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
# Set temperature parameter for data collection
collect_kwargs = {
'temperature': visit_count_temperature(
cfg.policy.manual_temperature_decay,
cfg.policy.fixed_temperature_value,
cfg.policy.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
),
'epsilon': 0.0 # Default epsilon value
}
# Configure epsilon-greedy exploration
if cfg.policy.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=cfg.policy.eps.start,
end=cfg.policy.eps.end,
decay=cfg.policy.eps.decay,
type_=cfg.policy.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

# Evaluate policy performance
if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
logging.info(f"Training iteration {learner.train_iter}: Starting evaluation...")
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
logging.info(f"Training iteration {learner.train_iter}: Evaluation completed, stop condition: {stop}, current reward: {reward}")
if stop:
logging.info("Stopping condition met, training ends!")
break

# Collect new data
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
logging.info(f"Rank {rank}, Training iteration {learner.train_iter}: New data collection completed!")

# Determine updates per collection
update_per_collect = cfg.policy.update_per_collect
if update_per_collect is None:
update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
replay_buffer.remove_oldest_data_to_fit()

# ****** reward_model related code ******
# collect data for reward_model training
try:
reward_model.collect_data(new_data)
except Exception as e:
logging.exception(f'Rank {rank}: reward_model.collect_data failed {e}')
raise
# update reward_model
if reward_model.cfg.input_type == 'latent_state':
local_items_count = len(reward_model.train_latent_state)
elif reward_model.cfg.input_type in ['obs', 'obs_latent_state']:
local_items_count = len(reward_model.train_obs)
else:
raise ValueError(f"Unsupported reward_model input_type: {reward_model.cfg.input_type}")
local_data_tensor = torch.tensor([local_items_count], dtype=torch.long, device=cfg.policy.device)

if world_size > 1:
try:
dist.all_reduce(local_data_tensor, op=dist.ReduceOp.MIN)
except Exception as e:
logging.error(f'Rank {rank}: all_reduce of reward model data count failed, error: {e}')

if local_data_tensor.item() >= reward_model.cfg.batch_size:
try:
reward_model.train_with_data()
except Exception as e:
logging.error(f'Rank {rank}: reward_model.train_with_data failed, error:{e}')

reward_model.clear_old_data()

if world_size > 1:
try:
dist.barrier()
except Exception as e:
logging.error(f'Rank {rank}: Synchronization barrier failed, error: {e}')
break

# Check if there is sufficient data for training
if collector.envstep > cfg.policy.train_start_after_envsteps:
if cfg.policy.sample_type == 'episode':
data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
else:
data_sufficient = replay_buffer.get_num_of_transitions() > batch_size

if not data_sufficient:
logging.warning(
f'Rank {rank}: The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continuing to collect data ...'
)
continue

# Execute multiple training rounds
for i in range(update_per_collect):
train_data = replay_buffer.sample(batch_size, policy)
if replay_buffer._cfg.reanalyze_ratio > 0 and i % 20 == 0:
policy.recompute_pos_emb_diff_and_clear_cache()

# update train_data reward using the augmented reward
try:
train_data_augmented = reward_model.estimate(train_data)
except Exception as e:
logging.exception(f'Rank {rank}: reward_model.estimate failed, error: {e}')
raise
if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

train_data_augmented.append(learner.train_iter)
# train_data.append(learner.train_iter)

log_vars = learner.train(train_data_augmented, collector.envstep)
# log_vars = learner.train(train_data, collector.envstep)
if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

policy.recompute_pos_emb_diff_and_clear_cache()

# Check stopping criteria
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
logging.info("Stopping condition met, training ends!")
break

learner.call_hook('after_run')
if cfg.policy.use_wandb:
wandb.finish()
logging.info("===== Training Completed =====")
return policy
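
For context, a minimal usage sketch of this entry point (not part of the PR). It assumes an existing UniZero config pair, here called main_config / create_config (hypothetical names), and only sets the fields this file actually reads:

from lzero.entry import train_unizero_with_reward_model

# main_config, create_config = ...  # e.g. reuse an existing UniZero config pair
main_config.policy.use_rnd_model = True               # required by the assert at the top of this entry
main_config.reward_model.input_type = 'latent_state'  # or 'obs' / 'obs_latent_state'
main_config.reward_model.batch_size = 256             # the RND model only trains once this much data is buffered

if __name__ == "__main__":
    train_unizero_with_reward_model((main_config, create_config), seed=0, max_env_step=int(1e6))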
4 changes: 2 additions & 2 deletions lzero/mcts/buffer/game_buffer.py
@@ -159,14 +159,14 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item()
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
pos_in_game_segment = np.random.choice(max(len(game_segment.action_segment) - 1, 1), 1).item()
else:
# For environments with a fixed action space (e.g., Atari),
# we can safely sample from the entire game segment range.
if pos_in_game_segment >= self._cfg.game_segment_length:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
pos_in_game_segment = np.random.choice(max(len(game_segment.action_segment) - 1, 1), 1).item()

pos_in_game_segment_list.append(pos_in_game_segment)
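
The two changed lines above guard against an empty sampling range; a small illustration of the failure mode they avoid (values are hypothetical):

import numpy as np

action_segment = [0]                          # a game segment with a single recorded action
n = len(action_segment) - 1                   # == 0
# np.random.choice(n, 1)                      # would raise: "a must be greater than 0 unless no samples are taken"
pos = np.random.choice(max(n, 1), 1).item()   # with the guard, this falls back to position 0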

5 changes: 3 additions & 2 deletions lzero/model/common.py
@@ -447,8 +447,9 @@ def encode(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
batch_idx = torch.arange(B, device=last_hidden.device)

selected = last_hidden[batch_idx, positions] # [B, H]

latent = self.embedding_head(selected.to(self.embedding_head[0].weight.dtype))
weight = self.embedding_head[0].weight
selected = selected.to(dtype=weight.dtype, device=weight.device)
latent = self.embedding_head(selected)
return latent

def decode(self, embeddings: torch.Tensor, max_length: int = 512) -> str:
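
The new lines in `encode` align the selected hidden states with the embedding head's parameter dtype and device before the projection; a self-contained sketch of the same pattern (toy shapes and dtypes, not LightZero code):

import torch
import torch.nn as nn

embedding_head = nn.Sequential(nn.Linear(16, 8))      # float32 parameters (could be fp16/bf16 in practice)
selected = torch.randn(4, 16, dtype=torch.float64)    # activations in another dtype would break the matmul

weight = embedding_head[0].weight
selected = selected.to(dtype=weight.dtype, device=weight.device)  # match the head's dtype and device
latent = embedding_head(selected)                     # now runs without a dtype/device mismatch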