Torch sequence_mask now works for tensors on different devices (ray-p…
janblumenkamp authored Apr 15, 2020
1 parent c174049 commit 8e43968
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions rllib/utils/torch_ops.py
```diff
@@ -51,8 +51,8 @@ def sequence_mask(lengths, maxlen, dtype=None):
     if maxlen is None:
         maxlen = lengths.max()

-    mask = ~(torch.ones((len(lengths), maxlen)).cumsum(dim=1).t() > lengths). \
-        t()
+    mask = ~(torch.ones((len(lengths), maxlen)).to(
+        lengths.device).cumsum(dim=1).t() > lengths).t()
     mask.type(dtype or torch.bool)

     return mask
```
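The fix adds `.to(lengths.device)`: `torch.ones(...)` allocates on the CPU, so the subsequent `> lengths` comparison raised a device-mismatch error whenever `lengths` lived on the GPU. A minimal sketch of the patched function, with the cumsum-based masking trick spelled out (the diff's own variable names are kept; note the committed code calls `mask.type(...)` without assigning the result, so this sketch returns it instead to make `dtype` take effect):

```python
import torch

def sequence_mask(lengths, maxlen=None, dtype=None):
    """Return a [batch, maxlen] mask: True where timestep t < lengths[b]."""
    if maxlen is None:
        maxlen = lengths.max()
    # cumsum over a row of ones yields [1, 2, ..., maxlen]; transposing and
    # comparing against `lengths` broadcasts per-batch, so entry (t, b) is
    # (t + 1) > lengths[b]. Negating gives t < lengths[b].
    # .to(lengths.device) is the fix: without it the ones tensor stays on
    # CPU and the comparison fails for GPU-resident `lengths`.
    mask = ~(torch.ones((len(lengths), maxlen)).to(
        lengths.device).cumsum(dim=1).t() > lengths).t()
    return mask.type(dtype or torch.bool)

lengths = torch.tensor([1, 3, 2])
print(sequence_mask(lengths, maxlen=4))
# tensor([[ True, False, False, False],
#         [ True,  True,  True, False],
#         [ True,  True, False, False]])
```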
