Skip to content

Commit 645fbba

Browse files
author
Vincent Moens
committed
init
1 parent 332499a commit 645fbba

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

torchrl/data/replay_buffers/samplers.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from torchrl._extension import EXTENSION_WARNING
2424

25-
from torchrl._utils import _replace_last, logger
25+
from torchrl._utils import _replace_last, implement_for, logger
2626
from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage
2727
from torchrl.data.replay_buffers.utils import _is_int
2828

@@ -1681,6 +1681,29 @@ def mark_update(
16811681
) -> None:
16821682
return PrioritizedSampler.mark_update(self, index, storage=storage)
16831683

1684+
@implement_for("torch", "2.4")
1685+
def _padded_indices(self, shapes, arange) -> torch.Tensor:
1686+
# this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g.
1687+
# tensor([[ 0, 1, 2, 3, 4],
1688+
# [-1, -1, 5, 6, 7],
1689+
# [-1, 8, 9, 10, 11]])
1690+
# where the -1 items on the left are padded values
1691+
st, off = torch._nested_compute_contiguous_strides_offsets(shapes.flip(0))
1692+
nt = torch._nested_view_from_buffer(
1693+
arange.flip(0).contiguous(), shapes.flip(0), st, off
1694+
)
1695+
pad = nt.to_padded_tensor(-1).flip(-1).flip(0)
1696+
return pad
1697+
1698+
@implement_for("torch", None, "2.4")
1699+
def _padded_indices(self, shapes, arange) -> torch.Tensor:
1700+
arange = arange.flip(0).split(shapes.flip(0).squeeze().unbind())
1701+
return (
1702+
torch.nn.utils.rnn.pad_sequence(arange, batch_first=True, padding_value=-1)
1703+
.flip(-1)
1704+
.flip(0)
1705+
)
1706+
16841707
def _preceding_stop_idx(self, storage, lengths, seq_length):
16851708
preceding_stop_idx = self._cache.get("preceding_stop_idx")
16861709
if preceding_stop_idx is not None:
@@ -1698,16 +1721,7 @@ def _preceding_stop_idx(self, storage, lengths, seq_length):
16981721
all_but_starts[starts] = False
16991722
arange = arange[all_but_starts]
17001723
shapes = shapes - 1
1701-
# this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g.
1702-
# tensor([[ 0, 1, 2, 3, 4],
1703-
# [-1, -1, 5, 6, 7],
1704-
# [-1, 8, 9, 10, 11]])
1705-
# where the -1 items on the left are padded values
1706-
st, off = torch._nested_compute_contiguous_strides_offsets(shapes.flip(0))
1707-
nt = torch._nested_view_from_buffer(
1708-
arange.flip(0).contiguous(), shapes.flip(0).contiguous(), st, off
1709-
)
1710-
pad = nt.to_padded_tensor(-1).flip(-1).flip(0).contiguous()
1724+
pad = self._padded_indices(shapes, arange)
17111725
_, span_right = self.span[0], self.span[1]
17121726
if span_right and isinstance(span_right, bool):
17131727
preceding_stop_idx = pad[:, -1:]

0 commit comments

Comments
 (0)