@@ -926,7 +926,9 @@ def __repr__(self):
926
926
)
927
927
928
928
@classmethod
929
- def _find_start_stop_traj (cls , * , trajectory = None , end = None , at_capacity : bool ):
929
+ def _find_start_stop_traj (
930
+ cls , * , trajectory = None , end = None , at_capacity : bool , cursor = None
931
+ ):
930
932
if trajectory is not None :
931
933
# slower
932
934
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
@@ -954,12 +956,28 @@ def _find_start_stop_traj(cls, *, trajectory=None, end=None, at_capacity: bool):
954
956
dim = 0 ,
955
957
value = 1 ,
956
958
)
957
- elif not end . any ( 0 ). all () :
958
- # we must have at least one end by traj to delimitate trajectories
959
+ else :
960
+ # we must have at least one end by traj to individuate trajectories
959
961
# so if no end can be found we set it manually
960
- mask = ~ end .any (0 , True )
961
- mask = torch .cat ([torch .zeros_like (end [:- 1 ]), mask ])
962
- end = torch .masked_fill (mask , end , 1 )
962
+ if cursor is not None :
963
+ if isinstance (cursor , torch .Tensor ):
964
+ cursor = cursor [- 1 ].item ()
965
+ elif isinstance (cursor , range ):
966
+ cursor = cursor [- 1 ]
967
+ if not _is_int (cursor ):
968
+ raise RuntimeError (
969
+ "cursor should be an integer or a 1d tensor or a range."
970
+ )
971
+ end = torch .index_fill (
972
+ end ,
973
+ index = torch .tensor (cursor , device = end .device , dtype = torch .long ),
974
+ dim = 0 ,
975
+ value = 1 ,
976
+ )
977
+ if not end .any (0 ).all ():
978
+ mask = ~ end .any (0 , True )
979
+ mask = torch .cat ([torch .zeros_like (end [:- 1 ]), mask ])
980
+ end = torch .masked_fill (mask , end , 1 )
963
981
ndim = end .ndim
964
982
if ndim == 0 :
965
983
raise RuntimeError (
@@ -994,7 +1012,7 @@ def _end_to_start_stop(end, length):
994
1012
# In this case we have only one start and stop has already been set
995
1013
pass
996
1014
lengths = stop_idx [:, 0 ] - start_idx [:, 0 ] + 1
997
- lengths [lengths < 0 ] = lengths [lengths < 0 ] + length
1015
+ lengths [lengths <= 0 ] = lengths [lengths <= 0 ] + length
998
1016
return start_idx , stop_idx , lengths
999
1017
1000
1018
def _start_to_end (self , st : torch .Tensor , length : int ):
@@ -1072,7 +1090,9 @@ def _get_stop_and_length(self, storage, fallback=True):
1072
1090
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
1073
1091
)
1074
1092
vals = self ._find_start_stop_traj (
1075
- end = done .squeeze ()[: len (storage )], at_capacity = storage ._is_full
1093
+ end = done .squeeze ()[: len (storage )],
1094
+ at_capacity = storage ._is_full ,
1095
+ cursor = getattr (storage , "_last_cursor" , None ),
1076
1096
)
1077
1097
if self .cache_values :
1078
1098
self ._cache ["stop-and-length" ] = vals
@@ -1270,7 +1290,6 @@ def _get_index(
1270
1290
],
1271
1291
1 ,
1272
1292
)
1273
-
1274
1293
index = self ._tensor_slices_from_startend (seq_length , starts , storage_length )
1275
1294
if self .truncated_key is not None :
1276
1295
truncated_key = self .truncated_key
0 commit comments