Skip to content

Commit

Permalink
[rllib] Fix incorrect sequence length for rnn (ray-project#23830)
Browse files Browse the repository at this point in the history
Update the torch policy to find the seq_lens using state_batches instead of input_dict. This helps handle the complex inputs to the model when the inbuilt preprocessing API is disabled.
  • Loading branch information
kinalmehta authored Apr 12, 2022
1 parent 4cb6205 commit 758e758
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions rllib/policy/torch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,9 @@ def compute_actions_from_input_dict(
# Calculate RNN sequence lengths.
seq_lens = (
torch.tensor(
[1] * len(input_dict["obs"]),
[1] * len(state_batches[0]),
dtype=torch.long,
device=input_dict["obs"].device,
device=state_batches[0].device,
)
if state_batches
else None
Expand Down

0 comments on commit 758e758

Please sign in to comment.