8 changes: 6 additions & 2 deletions lzero/mcts/buffer/game_buffer.py
@@ -156,13 +156,17 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
# For some environments (e.g., Jericho), the action space size may be different.
# To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),
# we avoid sampling from the last `num_unroll_steps` steps of the game segment.
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item()
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item()
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
else:
# For environments with a fixed action space (e.g., Atari),
# we can safely sample from the entire game segment range.
if pos_in_game_segment >= self._cfg.game_segment_length:
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()

pos_in_game_segment_list.append(pos_in_game_segment)

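As a minimal, self-contained sketch of the re-sampling logic in the Jericho branch above (the helper name resample_position is hypothetical and not part of the PR): a position is re-drawn whenever fewer than num_unroll_steps + td_steps transitions remain after it, and then re-drawn again if it falls outside the recorded action_segment.

```python
import numpy as np

def resample_position(pos: int, game_segment_length: int, num_unroll_steps: int,
                      td_steps: int, action_segment_len: int) -> int:
    """Hypothetical helper mirroring the checks in `_sample_orig_data` above."""
    # Keep at least `num_unroll_steps + td_steps` future transitions available
    # for unrolling and TD bootstrapping.
    limit = game_segment_length - num_unroll_steps - td_steps
    if pos >= limit:
        pos = np.random.choice(limit, 1).item()
    # Also stay strictly inside the actions recorded in this (possibly short) segment.
    if pos >= action_segment_len - 1:
        pos = np.random.choice(action_segment_len - 1, 1).item()
    return pos

# Example: segment_length=400, unroll=10, td=5, but only 40 actions were recorded.
print(resample_position(398, 400, 10, 5, 40))  # always ends up < 39
```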
9 changes: 8 additions & 1 deletion lzero/mcts/buffer/game_buffer_muzero.py
@@ -279,7 +279,10 @@ def _prepare_reward_value_context(
td_steps_list, action_mask_segment, to_play_segment
"""
zero_obs = game_segment_list[0].zero_obs()
zero_manual = game_segment_list[0].zero_manual()

value_obs_list = []
value_manual_embeds_list = []
# the value is valid or not (out of game_segment)
value_mask = []
rewards_list = []
@@ -300,6 +303,7 @@
# o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps]
# t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14]
game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps)
game_manual_embeds = game_segment.get_unroll_manual(state_index + td_steps, self._cfg.num_unroll_steps)

rewards_list.append(game_segment.reward_segment)

@@ -321,15 +325,18 @@
end_index = beg_index + self._cfg.model.frame_stack_num
# the stacked obs in time t
obs = game_obs[beg_index:end_index]
manual_embeds = game_manual_embeds[beg_index:end_index]
else:
value_mask.append(0)
obs = zero_obs
manual_embeds = zero_manual

value_obs_list.append(obs)
value_manual_embeds_list.append(manual_embeds)

reward_value_context = [
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, root_values, game_segment_lens, td_steps_list,
action_mask_segment, to_play_segment
action_mask_segment, to_play_segment, value_manual_embeds_list
]
return reward_value_context

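A small shape sanity-check sketch (the frame-stack size of 4 and the 768-dimensional embedding are assumptions for illustration): the zero_manual fallback must match the shape of a real slice so that value_manual_embeds_list can later be stacked into one dense array.

```python
import numpy as np

# Assumed sizes, for illustration only.
frame_stack_num, manual_embed_dim = 4, 768

# Fallback used when the bootstrap index falls outside the segment (value_mask == 0).
zero_manual = [np.zeros((manual_embed_dim,), dtype=np.float32) for _ in range(frame_stack_num)]

# A "real" slice as returned by game_segment.get_unroll_manual(...) for a valid index.
manual_embeds = [np.random.randn(manual_embed_dim).astype(np.float32) for _ in range(frame_stack_num)]

# Both branches contribute entries of identical shape, so the collected list
# stacks cleanly via np.array(value_manual_embeds_list) on the consumer side.
assert np.array(zero_manual).shape == np.array(manual_embeds).shape == (frame_stack_num, manual_embed_dim)
```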
21 changes: 15 additions & 6 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -64,15 +64,15 @@ def sample(
policy._target_model.eval()

# obtain the current_batch and prepare target context
reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
reward_value_context, policy_re_context, policy_non_re_context, current_batch, batch_manual_embeds = self._make_batch(
batch_size, self._cfg.reanalyze_ratio
)

# current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]

# target reward, target value
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action
reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action
)

# target policy
@@ -92,7 +92,7 @@ def sample(
target_batch = [batch_rewards, batch_target_values, batch_target_policies]

# a batch contains the current_batch and the target_batch
train_data = [current_batch, target_batch]
train_data = [current_batch, target_batch, batch_manual_embeds]
return train_data

def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
@@ -120,6 +120,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
obs_list, action_list, mask_list = [], [], []
timestep_list = []
bootstrap_action_list = []
manual_embeds_list = []

# prepare the inputs of a batch
for i in range(batch_size):
@@ -156,6 +157,12 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
)
)
manual_embeds_list.append(
game_segment_list[i].get_unroll_manual(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
)
)

action_list.append(actions_tmp)

mask_list.append(mask_tmp)
@@ -214,7 +221,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
else:
policy_non_re_context = None

context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
manual_embeds_array = np.asarray(manual_embeds_list)
context = reward_value_context, policy_re_context, policy_non_re_context, current_batch, manual_embeds_array
return context

def reanalyze_buffer(
@@ -527,20 +535,21 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
- batch_target_values (:obj:'np.ndarray): batch of value estimation
"""
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, root_values, game_segment_lens, td_steps_list, action_mask_segment, \
to_play_segment = reward_value_context # noqa
to_play_segment, value_manual_embeds_list = reward_value_context # noqa
# transition_batch_size = game_segment_batch_size * (num_unroll_steps+1)
transition_batch_size = len(value_obs_list)

batch_target_values, batch_rewards = [], []
with torch.no_grad():
value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)
batch_manual = torch.from_numpy(np.array(value_manual_embeds_list))
network_output = []
batch_obs = torch.from_numpy(value_obs_list).to(self._cfg.device)

# =============== NOTE: The key difference with MuZero =================
# calculate the bootstrapped value and target value
# NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps
m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep)
m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, manual_embeds=batch_manual)
# ======================================================================

# if not in training, obtain the scalars of the value/reward
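A rough consumer-side sketch (the shapes, a frame stack of 1, and the learner-side code are assumptions; only the [current_batch, target_batch, batch_manual_embeds] layout comes from the hunks above): each sampled segment contributes a padded slice of manual embeddings, and the stacked array travels with train_data to the device next to the observation batch.

```python
import numpy as np
import torch

# Assumed sizes: 8 segments per batch, num_unroll_steps=5, frame_stack_num=1, 768-dim embeddings.
batch_size, num_unroll_steps, frame_stack_num, manual_dim = 8, 5, 1, 768

# What _make_batch collects: one padded slice per sampled segment.
manual_embeds_list = [
    np.zeros((frame_stack_num + num_unroll_steps, manual_dim), dtype=np.float32)
    for _ in range(batch_size)
]
batch_manual_embeds = np.asarray(manual_embeds_list)          # shape (8, 6, 768)

# What sample() returns and a learner would unpack (stub lists stand in for the real batches).
current_batch_stub, target_batch_stub = ["<current_batch fields>"], ["<target_batch fields>"]
train_data = [current_batch_stub, target_batch_stub, batch_manual_embeds]
current_batch, target_batch, manual_embeds = train_data

manual_tensor = torch.from_numpy(manual_embeds).to("cpu")     # e.g. self._cfg.device
print(manual_tensor.shape)                                    # torch.Size([8, 6, 768])
```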
35 changes: 34 additions & 1 deletion lzero/mcts/buffer/game_segment.py
@@ -59,6 +59,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
# image obs input, e.g. atari environments
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

self.manual_embed_dim = config.model.world_model_cfg.manual_embed_dim
self.obs_segment = []
self.action_segment = []
self.reward_segment = []
@@ -69,6 +70,7 @@
self.action_mask_segment = []
self.to_play_segment = []
self.timestep_segment = []
self.manual_embeds_segment = []

self.target_values = []
self.target_rewards = []
@@ -102,6 +104,23 @@ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool
stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
return stacked_obs

def get_unroll_manual(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
"""
Overview:
Get the stacked manual embeddings in the correct format: m[t, t + stack frames + num_unroll_steps].
Arguments:
- timestep (int): The time step.
- num_unroll_steps (int): The extra length of the manual embedding frames.
- padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory.
"""
stacked_manual_embeds = self.manual_embeds_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps]
if padding:
pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_manual_embeds)
if pad_len > 0:
pad_frames = np.array([stacked_manual_embeds[-1] for _ in range(pad_len)])
stacked_manual_embeds = np.concatenate((stacked_manual_embeds, pad_frames))
return stacked_manual_embeds

def zero_obs(self) -> List:
"""
Overview:
@@ -111,6 +130,15 @@ def zero_obs(self) -> List:
"""
return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)]

def zero_manual(self) -> List:
"""
Overview:
Return a list of zero-filled manual embedding frames, one per stacked frame.
Returns:
List[np.ndarray]: A list of zero arrays of shape (manual_embed_dim,).
"""
return [np.zeros((self.manual_embed_dim, ), dtype=np.float32) for _ in range(self.frame_stack_num)]

def get_obs(self) -> List:
"""
Overview:
@@ -138,6 +166,7 @@ def append(
to_play: int = -1,
timestep: int = 0,
chance: int = 0,
manual_embeds = None,
) -> None:
"""
Overview:
@@ -150,6 +179,7 @@
self.action_mask_segment.append(action_mask)
self.to_play_segment.append(to_play)
self.timestep_segment.append(timestep)
self.manual_embeds_segment.append(manual_embeds)

if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment.append(chance)
@@ -285,6 +315,7 @@ def game_segment_to_array(self) -> None:
self.obs_segment = np.array(self.obs_segment)
self.action_segment = np.array(self.action_segment)
self.reward_segment = np.array(self.reward_segment)
self.manual_embeds_segment = np.array(self.manual_embeds_segment)

# Check if all elements in self.child_visit_segment have the same length
if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment):
@@ -305,7 +336,7 @@ def game_segment_to_array(self) -> None:
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = np.array(self.chance_segment)

def reset(self, init_observations: np.ndarray) -> None:
def reset(self, init_observations: np.ndarray, init_manual_embeds = None) -> None:
"""
Overview:
Initialize the game segment using ``init_observations``,
@@ -323,6 +354,7 @@ def reset(self, init_observations: np.ndarray) -> None:
self.action_mask_segment = []
self.to_play_segment = []
self.timestep_segment = []
self.manual_embeds_segment = []

if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []
@@ -331,6 +363,7 @@

for observation in init_observations:
self.obs_segment.append(copy.deepcopy(observation))
self.manual_embeds_segment.append(copy.deepcopy(init_manual_embeds))

def is_full(self) -> bool:
"""
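A standalone illustration of the new slice-and-pad behaviour in get_unroll_manual (frame_stack_num=1, num_unroll_steps=5, a 10-step segment, and 768-dim embeddings are all assumptions): when the requested window runs past the end of the segment, the last stored embedding is repeated, exactly as get_unroll_obs repeats the last observation.

```python
import numpy as np

# Assumed sizes, for illustration only.
frame_stack_num, num_unroll_steps, manual_embed_dim = 1, 5, 768

# A toy segment of 10 manual embeddings; step t is filled with the value t.
manual_embeds_segment = [np.full((manual_embed_dim,), t, dtype=np.float32) for t in range(10)]

def get_unroll_manual(timestep: int, padding: bool = False) -> np.ndarray:
    """Standalone re-implementation of the slice-and-pad logic shown in the diff."""
    stacked = manual_embeds_segment[timestep:timestep + frame_stack_num + num_unroll_steps]
    if padding:
        pad_len = frame_stack_num + num_unroll_steps - len(stacked)
        if pad_len > 0:
            # Repeat the last available embedding, matching the observation padding.
            stacked = np.concatenate((stacked, np.array([stacked[-1]] * pad_len)))
    return np.asarray(stacked)

print(get_unroll_manual(2).shape)                # (6, 768): window fully inside the segment
print(get_unroll_manual(8, padding=True).shape)  # (6, 768): last 4 entries are copies of step 9
```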
10 changes: 9 additions & 1 deletion lzero/model/common.py
@@ -470,6 +470,8 @@ def __init__(
embedding_dim: int = 256,
group_size: int = 8,
final_norm_option_in_encoder: str = 'LayerNorm', # TODO
use_manual: bool = False,
manual_dim: int = 768
) -> None:
"""
Overview:
@@ -496,8 +498,10 @@ def __init__(
logging.info(f"Using norm type: {norm_type}")
logging.info(f"Using activation type: {activation}")

self.observation_shape = observation_shape
self.observation_shape = observation_shape
self.downsample = downsample
self.use_manual = use_manual

if self.downsample:
self.downsample_net = DownSample(
observation_shape,
@@ -533,6 +537,8 @@ def __init__(

elif self.observation_shape[1] in [84, 96]:
self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)
elif self.observation_shape[1] == 10:
self.last_linear = nn.Linear(64 * 10 * 10, self.embedding_dim, bias=False)

self.final_norm_option_in_encoder = final_norm_option_in_encoder
if self.final_norm_option_in_encoder == 'LayerNorm':
Expand All @@ -542,6 +548,8 @@ def __init__(
else:
raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")

if use_manual:
self.feature_merge_linear = nn.Linear(self.embedding_dim + manual_dim, self.embedding_dim)

Reviewer comment (Collaborator): Shouldn't the output of self.feature_merge_linear be followed by the same norm that is applied to the original obs_embeddings?

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
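The diff adds feature_merge_linear, but the forward path that uses it is outside this hunk. Below is a hedged sketch of how the merge could look: concatenate the encoder embedding with the manual embedding and project back to embedding_dim. This is also the spot where the reviewer's question applies, so the sketch shows re-applying the encoder's final norm as one option, not as the PR's actual behaviour.

```python
import torch
import torch.nn as nn

embedding_dim, manual_dim = 256, 768          # assumed sizes
feature_merge_linear = nn.Linear(embedding_dim + manual_dim, embedding_dim)
final_norm = nn.LayerNorm(embedding_dim, eps=1e-5)

obs_embeddings = torch.randn(32, embedding_dim)  # encoder output after its final norm
manual_embeds = torch.randn(32, manual_dim)      # precomputed manual/text embeddings

# Concatenate along the feature dimension and project back to the model's embedding size.
merged = feature_merge_linear(torch.cat([obs_embeddings, manual_embeds], dim=-1))

# Per the review comment, the merged features arguably should pass through the same
# normalization as the plain obs_embeddings; applied here purely as an illustration.
merged = final_norm(merged)
print(merged.shape)  # torch.Size([32, 256])
```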
9 changes: 5 additions & 4 deletions lzero/model/unizero_model.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List

import torch
import torch.nn as nn
@@ -127,7 +127,7 @@ def __init__(
norm_type=norm_type,
embedding_dim=world_model_cfg.embed_dim,
group_size=world_model_cfg.group_size,
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder,
)

# ====== for analysis ======
@@ -177,7 +177,7 @@ def __init__(
print('==' * 20)

def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None,
current_obs_batch: Optional[torch.Tensor] = None, start_pos: int = 0) -> MZNetworkOutput:
current_obs_batch: Optional[torch.Tensor] = None, start_pos: int = 0, manual_embeds: List[torch.Tensor] = None) -> MZNetworkOutput:
"""
Overview:
Initial inference of the UniZero model, which is the first step of the UniZero model.
@@ -205,7 +205,8 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torc
obs_act_dict = {
'obs': obs_batch,
'action': action_batch,
'current_obs': current_obs_batch
'current_obs': current_obs_batch,
'manual_embeds': manual_embeds
}

# Perform initial inference using the world model
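A call-site sketch (batch size and tensor shapes are assumptions; how the world model consumes 'manual_embeds' is outside this diff): initial_inference now simply threads the extra tensor into obs_act_dict, so callers such as the UniZero buffer can pass manual_embeds=batch_manual alongside the observations.

```python
from typing import Optional

import torch

# Assumed shapes: batch of 8 single-frame RGB observations and 768-dim manual embeddings.
obs_batch = torch.randn(8, 3, 64, 64)
action_batch = torch.zeros(8, 1, dtype=torch.long)
manual_embeds: Optional[torch.Tensor] = torch.randn(8, 768)

# Mirrors the dict assembled inside initial_inference(); the world model is expected to
# read 'manual_embeds' next to 'obs' / 'action' / 'current_obs' (not shown in this PR).
obs_act_dict = {
    'obs': obs_batch,
    'action': action_batch,
    'current_obs': None,
    'manual_embeds': manual_embeds,
}
print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in obs_act_dict.items()})
```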