22
22
23
23
from torchrl ._extension import EXTENSION_WARNING
24
24
25
- from torchrl ._utils import _replace_last , logger
25
+ from torchrl ._utils import _replace_last , implement_for , logger
26
26
from torchrl .data .replay_buffers .storages import Storage , StorageEnsemble , TensorStorage
27
27
from torchrl .data .replay_buffers .utils import _is_int
28
28
@@ -1681,6 +1681,29 @@ def mark_update(
1681
1681
) -> None :
1682
1682
return PrioritizedSampler .mark_update (self , index , storage = storage )
1683
1683
1684
+ @implement_for ("torch" , "2.4" )
1685
+ def _padded_indices (self , shapes , arange ) -> torch .Tensor :
1686
+ # this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g.
1687
+ # tensor([[ 0, 1, 2, 3, 4],
1688
+ # [-1, -1, 5, 6, 7],
1689
+ # [-1, 8, 9, 10, 11]])
1690
+ # where the -1 items on the left are padded values
1691
+ st , off = torch ._nested_compute_contiguous_strides_offsets (shapes .flip (0 ))
1692
+ nt = torch ._nested_view_from_buffer (
1693
+ arange .flip (0 ).contiguous (), shapes .flip (0 ), st , off
1694
+ )
1695
+ pad = nt .to_padded_tensor (- 1 ).flip (- 1 ).flip (0 )
1696
+ return pad
1697
+
1698
+ @implement_for ("torch" , None , "2.4" )
1699
+ def _padded_indices (self , shapes , arange ) -> torch .Tensor :
1700
+ arange = arange .flip (0 ).split (shapes .flip (0 ).squeeze ().unbind ())
1701
+ return (
1702
+ torch .nn .utils .rnn .pad_sequence (arange , batch_first = True , padding_value = - 1 )
1703
+ .flip (- 1 )
1704
+ .flip (0 )
1705
+ )
1706
+
1684
1707
def _preceding_stop_idx (self , storage , lengths , seq_length ):
1685
1708
preceding_stop_idx = self ._cache .get ("preceding_stop_idx" )
1686
1709
if preceding_stop_idx is not None :
@@ -1698,16 +1721,7 @@ def _preceding_stop_idx(self, storage, lengths, seq_length):
1698
1721
all_but_starts [starts ] = False
1699
1722
arange = arange [all_but_starts ]
1700
1723
shapes = shapes - 1
1701
- # this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g.
1702
- # tensor([[ 0, 1, 2, 3, 4],
1703
- # [-1, -1, 5, 6, 7],
1704
- # [-1, 8, 9, 10, 11]])
1705
- # where the -1 items on the left are padded values
1706
- st , off = torch ._nested_compute_contiguous_strides_offsets (shapes .flip (0 ))
1707
- nt = torch ._nested_view_from_buffer (
1708
- arange .flip (0 ).contiguous (), shapes .flip (0 ).contiguous (), st , off
1709
- )
1710
- pad = nt .to_padded_tensor (- 1 ).flip (- 1 ).flip (0 ).contiguous ()
1724
+ pad = self ._padded_indices (shapes , arange )
1711
1725
_ , span_right = self .span [0 ], self .span [1 ]
1712
1726
if span_right and isinstance (span_right , bool ):
1713
1727
preceding_stop_idx = pad [:, - 1 :]
0 commit comments