@@ -193,7 +193,7 @@ def forward(
193
193
vis_inputs : List [torch .Tensor ],
194
194
masks : Optional [torch .Tensor ] = None ,
195
195
memories : Optional [torch .Tensor ] = None ,
196
- ) -> Tuple [torch .Tensor , torch . Tensor , int , int , int , int ]:
196
+ ) -> Tuple [torch .Tensor , int , int , int , int ]:
197
197
"""
198
198
Forward pass of the Actor for inference. This is required for export to ONNX, and
199
199
the inputs and outputs of this method should not be changed without a respective change
@@ -325,20 +325,19 @@ def forward(
325
325
vis_inputs : List [torch .Tensor ],
326
326
masks : Optional [torch .Tensor ] = None ,
327
327
memories : Optional [torch .Tensor ] = None ,
328
- ) -> Tuple [torch .Tensor , torch . Tensor , int , int , int , int ]:
328
+ ) -> Tuple [torch .Tensor , int , int , int , int ]:
329
329
"""
330
330
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
331
331
"""
332
332
dists , _ = self .get_dists (vec_inputs , vis_inputs , masks , memories , 1 )
333
333
action_list = self .sample_action (dists )
334
334
sampled_actions = torch .stack (action_list , dim = - 1 )
335
335
if self .act_type == ActionType .CONTINUOUS :
336
- log_probs = dists [ 0 ]. log_prob ( sampled_actions )
336
+ action_out = sampled_actions
337
337
else :
338
- log_probs = dists [0 ].all_log_prob ()
338
+ action_out = dists [0 ].all_log_prob ()
339
339
return (
340
- sampled_actions ,
341
- log_probs ,
340
+ action_out ,
342
341
self .version_number ,
343
342
torch .Tensor ([self .network_body .memory_size ]),
344
343
self .is_continuous_int ,
0 commit comments