@@ -982,22 +982,20 @@ def _find_start_stop_traj(
982
982
983
983
# faster
984
984
end = trajectory [:- 1 ] != trajectory [1 :]
985
- end = torch .cat ([end , trajectory [- 1 :] != trajectory [:1 ]], 0 )
985
+ if not at_capacity :
986
+ end = torch .cat ([end , torch .ones_like (end [:1 ])], 0 )
987
+ else :
988
+ end = torch .cat ([end , trajectory [- 1 :] != trajectory [:1 ]], 0 )
986
989
length = trajectory .shape [0 ]
987
990
else :
988
- # TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True
989
-
990
991
# We presume that not done at the end means that the traj spans across end and beginning of storage
991
992
length = end .shape [0 ]
993
+ if not at_capacity :
994
+ end = end .clone ()
995
+ end [length - 1 ] = True
996
+ ndim = end .ndim
992
997
993
- if not at_capacity :
994
- end = torch .index_fill (
995
- end ,
996
- index = torch .tensor (- 1 , device = end .device , dtype = torch .long ),
997
- dim = 0 ,
998
- value = 1 ,
999
- )
1000
- else :
998
+ if at_capacity :
1001
999
# we must have at least one end by traj to individuate trajectories
1002
1000
# so if no end can be found we set it manually
1003
1001
if cursor is not None :
@@ -1019,7 +1017,6 @@ def _find_start_stop_traj(
1019
1017
mask = ~ end .any (0 , True )
1020
1018
mask = torch .cat ([torch .zeros_like (end [:- 1 ]), mask ])
1021
1019
end = torch .masked_fill (mask , end , 1 )
1022
- ndim = end .ndim
1023
1020
if ndim == 0 :
1024
1021
raise RuntimeError (
1025
1022
"Expected the end-of-trajectory signal to be at least 1-dimensional."
@@ -1126,7 +1123,7 @@ def _get_stop_and_length(self, storage, fallback=True):
1126
1123
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
1127
1124
)
1128
1125
vals = self ._find_start_stop_traj (
1129
- trajectory = trajectory ,
1126
+ trajectory = trajectory . clone () ,
1130
1127
at_capacity = storage ._is_full ,
1131
1128
cursor = getattr (storage , "_last_cursor" , None ),
1132
1129
)
0 commit comments