Skip to content

Commit e6eb502

Browse files
author
Ervin T
authored
Always export one Action tensor (#4388)
1 parent 0a8b5e0 commit e6eb502

File tree

2 files changed

+7
-16
lines changed

2 files changed

+7
-16
lines changed

ml-agents/mlagents/trainers/torch/model_serialization.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,16 @@ def __init__(self, policy):
3232
+ ["action_masks", "memories"]
3333
)
3434

35-
if self.policy.use_continuous_act:
36-
action_name = "action"
37-
action_prob_name = "action_probs"
38-
else:
39-
action_name = "action_unused"
40-
action_prob_name = "action"
41-
4235
self.output_names = [
43-
action_name,
44-
action_prob_name,
36+
"action",
4537
"version_number",
4638
"memory_size",
4739
"is_continuous_control",
4840
"action_output_shape",
4941
]
5042

5143
self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
52-
self.dynamic_axes.update({"action": {0: "batch"}, "action_probs": {0: "batch"}})
44+
self.dynamic_axes.update({"action": {0: "batch"}})
5345

5446
def export_policy_model(self, output_filepath: str) -> None:
5547
"""

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def forward(
193193
vis_inputs: List[torch.Tensor],
194194
masks: Optional[torch.Tensor] = None,
195195
memories: Optional[torch.Tensor] = None,
196-
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
196+
) -> Tuple[torch.Tensor, int, int, int, int]:
197197
"""
198198
Forward pass of the Actor for inference. This is required for export to ONNX, and
199199
the inputs and outputs of this method should not be changed without a respective change
@@ -325,20 +325,19 @@ def forward(
325325
vis_inputs: List[torch.Tensor],
326326
masks: Optional[torch.Tensor] = None,
327327
memories: Optional[torch.Tensor] = None,
328-
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
328+
) -> Tuple[torch.Tensor, int, int, int, int]:
329329
"""
330330
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
331331
"""
332332
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
333333
action_list = self.sample_action(dists)
334334
sampled_actions = torch.stack(action_list, dim=-1)
335335
if self.act_type == ActionType.CONTINUOUS:
336-
log_probs = dists[0].log_prob(sampled_actions)
336+
action_out = sampled_actions
337337
else:
338-
log_probs = dists[0].all_log_prob()
338+
action_out = dists[0].all_log_prob()
339339
return (
340-
sampled_actions,
341-
log_probs,
340+
action_out,
342341
self.version_number,
343342
torch.Tensor([self.network_body.memory_size]),
344343
self.is_continuous_int,

0 commit comments

Comments
 (0)