polish(pu): polish redundant data squeeze operations (#177)
puyuan1996 authored Dec 26, 2023
1 parent 95e94b9 commit 6af174b
Showing 2 changed files with 10 additions and 10 deletions.
18 changes: 9 additions & 9 deletions lzero/policy/sampled_efficientzero.py
@@ -332,7 +332,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:

# shape: (batch_size, num_unroll_steps, action_dim)
# NOTE: .float(), in continuous action space.
- action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float().unsqueeze(-1)
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float()
data_list = [
mask_batch,
target_value_prefix.astype('float32'),
@@ -343,8 +343,8 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
# ==============================================================
# sampled related core code
# ==============================================================
- # shape: (batch_size, num_unroll_steps+1, num_of_sampled_actions, action_dim, 1), e.g. (4, 6, 5, 1, 1)
- child_sampled_actions_batch = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device).unsqueeze(-1)
+ # shape: (batch_size, num_unroll_steps+1, num_of_sampled_actions, action_dim), e.g. (4, 6, 5, 1)
+ child_sampled_actions_batch = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device)

target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1)
target_value = target_value.view(self._cfg.batch_size, -1)
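
A minimal sketch (not part of the commit) of why the removed pair of operations was redundant: _forward_learn appended a trailing singleton dimension to child_sampled_actions_batch (and to action_batch) that the policy-loss helpers only ever stripped again, so dropping both sides leaves every downstream shape unchanged. The tensor sizes below are illustrative.

import torch

# Illustrative sizes: batch_size=4, num_unroll_steps+1=6,
# num_of_sampled_actions=20, action_dim=2.
child_sampled_actions_batch = torch.randn(4, 6, 20, 2)
unroll_step = 3

# Old path: append a trailing singleton dim, then strip it again when slicing one step.
old = child_sampled_actions_batch.unsqueeze(-1)[:, unroll_step].squeeze(-1)

# New path: use the tensor as-is.
new = child_sampled_actions_batch[:, unroll_step]

assert old.shape == new.shape == (4, 20, 2)
assert torch.equal(old, new)  # the unsqueeze/squeeze round trip was a no-op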
@@ -625,9 +625,9 @@ def _calculate_policy_loss_cont(
# Set target_policy_entropy to 0 if all rows are masked
target_policy_entropy = 0

- # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size,
- # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2)
- target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1)
+ # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim) -> (batch_size,
+ # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2) -> (4, 20, 2)
+ target_sampled_actions = child_sampled_actions_batch[:, unroll_step]

policy_entropy = dist.entropy().mean()
policy_entropy_loss = -dist.entropy()
@@ -724,9 +724,9 @@ def _calculate_policy_loss_disc(
target_dist = Categorical(target_normalized_visit_count_masked)
target_policy_entropy = target_dist.entropy().mean()

- # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim, 1) -> (batch_size,
- # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2, 1) -> (4, 20, 2)
- target_sampled_actions = child_sampled_actions_batch[:, unroll_step].squeeze(-1)
+ # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim) -> (batch_size,
+ # num_of_sampled_actions, action_dim) e.g. (4, 6, 20, 2) -> (4, 20, 2)
+ target_sampled_actions = child_sampled_actions_batch[:, unroll_step]

policy_entropy = dist.entropy().mean()
policy_entropy_loss = -dist.entropy()
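
A hedged sketch for the continuous branch (the actual body of _calculate_policy_loss_cont is not shown in this diff, so the factored Gaussian policy head, mu, sigma, and the Independent/Normal construction below are assumptions for illustration): after indexing one unroll step, target_sampled_actions already carries event shape (action_dim,), which torch.distributions can score directly, so no trailing squeeze is needed.

import torch
from torch.distributions import Independent, Normal

# Hypothetical per-step policy distribution: batch_size=4, action_dim=2.
mu, sigma = torch.zeros(4, 2), torch.ones(4, 2)
dist = Independent(Normal(mu, sigma), 1)

# One unroll step: (batch_size, num_of_sampled_actions, action_dim) = (4, 20, 2).
target_sampled_actions = torch.randn(4, 20, 2)

# Score one sampled action per batch element; the slice has event shape (2,) already.
log_prob = dist.log_prob(target_sampled_actions[:, 0])
assert log_prob.shape == (4,)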
2 changes: 1 addition & 1 deletion (second changed file)
@@ -37,7 +37,7 @@
continuous_action_space=continuous_action_space,
num_of_sampled_actions=K,
sigma_type='conditioned',
- model_type='mlp',
+ model_type='mlp',
lstm_hidden_size=256,
latent_state_dim=256,
res_connection_in_dynamics=True,