Skip to content

Commit 949aa1f

Browse files
author
Ervin T
authored
[bug-fix] Fix non-LSTM SeparateActorCritic (#4306)
1 parent 17bacbb commit 949aa1f

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

ml-agents/mlagents/trainers/tests/torch/test_networks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,9 @@ def test_actor_critic(ac_type, lstm):
188188
)
189189
else:
190190
sample_obs = torch.ones((1, obs_size))
191-
memories = None
191+
memories = torch.tensor([])
192+
# memories isn't always set to None, the network should be able to
193+
# deal with that.
192194
# Test critic pass
193195
value_out = actor.critic_pass([sample_obs], [], memories=memories)
194196
for stream in stream_names:

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def critic_pass(
428428
vis_inputs: List[torch.Tensor],
429429
memories: Optional[torch.Tensor] = None,
430430
) -> Dict[str, torch.Tensor]:
431-
if memories is not None:
431+
if self.use_lstm:
432432
# Use only the back half of memories for critic
433433
_, critic_mem = torch.split(memories, self.half_mem_size, -1)
434434
else:
@@ -446,7 +446,7 @@ def get_dist_and_value(
446446
memories: Optional[torch.Tensor] = None,
447447
sequence_length: int = 1,
448448
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
449-
if memories is not None:
449+
if self.use_lstm:
450450
# Use only the back half of memories for critic and actor
451451
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1)
452452
else:

0 commit comments

Comments
 (0)