Skip to content

Commit

Permalink
Copy initial state of an RNN to a CPU before converting it to a NumPy…
Browse files Browse the repository at this point in the history
… array (ray-project#8097)
  • Loading branch information
iamhatesz authored Apr 26, 2020
1 parent b506f87 commit b508166
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion rllib/policy/torch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,9 @@ def num_state_tensors(self):

@override(Policy)
def get_initial_state(self):
return [s.numpy() for s in self.model.get_initial_state()]
return [
s.cpu().detach().numpy() for s in self.model.get_initial_state()
]

def extra_grad_process(self, optimizer, loss):
"""Called after each optimizer.zero_grad() + loss.backward() call.
Expand Down

0 comments on commit b508166

Please sign in to comment.