
Commit 3442de5

Author: Ervin T
Fix for discrete actions (#4181)
1 parent 2a22e17 commit 3442de5

File tree

4 files changed: 5 additions & 6 deletions

ml-agents/mlagents/trainers/distributions_torch.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def sample(self):
         return torch.multinomial(self.probs, 1)

     def pdf(self, value):
-        return torch.diag(self.probs.T[value.flatten()])
+        return torch.diag(self.probs.T[value.flatten().long()])

     def log_prob(self, value):
         return torch.log(self.pdf(value))

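For context on the .long() cast above: PyTorch only accepts integer (or bool) tensors as indices, so indexing the probability table with float-typed action values raises an IndexError. A minimal sketch of the behavior, with made-up probabilities and actions rather than code from this repository:

    import torch

    probs = torch.tensor([[0.1, 0.9], [0.7, 0.3]])  # [batch=2, num_actions=2]
    actions = torch.tensor([[1.0], [0.0]])          # float-typed, as they arrive from a buffer

    # probs.T[actions.flatten()] would raise:
    # IndexError: tensors used as indices must be long, byte or bool tensors
    per_sample_probs = torch.diag(probs.T[actions.flatten().long()])
    print(per_sample_probs)                         # tensor([0.9000, 0.7000])
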
ml-agents/mlagents/trainers/policy/policy.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def __init__(
         self.num_branches = len(self.brain.vector_action_space_size)
         self.previous_action_dict: Dict[str, np.array] = {}
         self.memory_dict: Dict[str, np.ndarray] = {}
-        self.normalize = trainer_settings
+        self.normalize = trainer_settings.network_settings.normalize
         self.use_recurrent = trainer_settings.network_settings.memory is not None
         self.model_path = trainer_settings.init_path

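The one-line fix above pulls the normalize flag out of the nested network settings instead of storing the whole settings object, which, being a non-empty object, would presumably read as truthy wherever the flag is checked. A simplified sketch of the relevant settings shape, using plain dataclasses as stand-ins for the real ml-agents settings classes:

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class NetworkSettings:      # stand-in for the real NetworkSettings
        normalize: bool = False
        memory: Optional[object] = None

    @dataclass
    class TrainerSettings:      # stand-in; the real class has many more fields
        network_settings: NetworkSettings = field(default_factory=NetworkSettings)
        init_path: Optional[str] = None

    trainer_settings = TrainerSettings(network_settings=NetworkSettings(normalize=True))
    normalize = trainer_settings.network_settings.normalize  # True: the actual boolean flag
    not_a_flag = trainer_settings                             # the old assignment: always truthy
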
ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 2 additions & 3 deletions
@@ -150,8 +150,7 @@ def sample_actions(self, vec_obs, vis_obs, masks=None, memories=None, seq_len=1)
 
         actions = self.actor_critic.sample_action(dists)
         log_probs, entropies = self.actor_critic.get_probs_and_entropy(actions, dists)
-        if self.act_type == "continuous":
-            actions.squeeze_(-1)
+        actions = torch.squeeze(actions)

         return actions, log_probs, entropies, value_heads, memories

@@ -250,7 +249,7 @@ def export_model(self, step=0):
         fake_vec_obs = [torch.zeros([1] + [self.brain.vector_observation_space_size])]
         fake_vis_obs = [torch.zeros([1] + [84, 84, 3])]
         fake_masks = torch.ones([1] + self.actor_critic.act_size)
-        fake_memories = torch.zeros([1] + [self.m_size])
+        # fake_memories = torch.zeros([1] + [self.m_size])
         export_path = "./model-" + str(step) + ".onnx"
         output_names = ["action", "action_probs"]
         input_names = ["vector_observation", "action_mask"]

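On the sample_actions change: the old code only squeezed the trailing dimension for continuous actions, but discrete sampling via torch.multinomial also leaves a trailing dimension of size 1, and torch.squeeze with no dim argument drops every size-1 dimension for either action type. A small shape illustration with made-up sizes, not repository code:

    import torch

    # Discrete branch: multinomial sampling adds a trailing dimension of size 1.
    probs = torch.full((4, 3), 1 / 3)        # [batch=4, num_actions=3]
    discrete = torch.multinomial(probs, 1)   # shape [4, 1]
    print(torch.squeeze(discrete).shape)     # torch.Size([4])

    # Continuous branch: a [batch, act_size, 1] sample collapses the same way.
    continuous = torch.randn(4, 2, 1)
    print(torch.squeeze(continuous).shape)   # torch.Size([4, 2])

Note that a dim-less squeeze would also drop a batch dimension of size 1, so this trades the per-action-type branching for that edge case.
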
ml-agents/mlagents/trainers/ppo/optimizer_torch.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         if self.policy.use_continuous_act:
             actions = torch.as_tensor(batch["actions"]).unsqueeze(-1)
         else:
-            actions = torch.as_tensor(batch["actions"])
+            actions = torch.as_tensor(batch["actions"], dtype=torch.long)

         memories = [
             torch.as_tensor(batch["memory"][i])

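Finally, the dtype=torch.long argument above builds the discrete-action tensor as integers at conversion time, which is what the categorical distribution's pdf indexing (the first file in this commit) ultimately needs; torch.as_tensor casts directly when given a dtype. A tiny sketch with a hypothetical buffer array:

    import numpy as np
    import torch

    batch_actions = np.array([[1.0], [0.0], [2.0]])              # hypothetical float buffer contents
    actions = torch.as_tensor(batch_actions, dtype=torch.long)   # integer action indices
    print(actions.dtype)                                          # torch.int64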