Skip to content

Commit 47a1005

Browse files
author
Vincent Moens
authored
[BugFix] Fix slice sampler end computation at the cursor place (#2225)
1 parent f613eef commit 47a1005

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

torchrl/data/replay_buffers/samplers.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,9 @@ def __repr__(self):
926926
)
927927

928928
@classmethod
929-
def _find_start_stop_traj(cls, *, trajectory=None, end=None, at_capacity: bool):
929+
def _find_start_stop_traj(
930+
cls, *, trajectory=None, end=None, at_capacity: bool, cursor=None
931+
):
930932
if trajectory is not None:
931933
# slower
932934
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
@@ -954,12 +956,28 @@ def _find_start_stop_traj(cls, *, trajectory=None, end=None, at_capacity: bool):
954956
dim=0,
955957
value=1,
956958
)
957-
elif not end.any(0).all():
958-
# we must have at least one end by traj to delimitate trajectories
959+
else:
960+
# we must have at least one end by traj to individuate trajectories
959961
# so if no end can be found we set it manually
960-
mask = ~end.any(0, True)
961-
mask = torch.cat([torch.zeros_like(end[:-1]), mask])
962-
end = torch.masked_fill(mask, end, 1)
962+
if cursor is not None:
963+
if isinstance(cursor, torch.Tensor):
964+
cursor = cursor[-1].item()
965+
elif isinstance(cursor, range):
966+
cursor = cursor[-1]
967+
if not _is_int(cursor):
968+
raise RuntimeError(
969+
"cursor should be an integer or a 1d tensor or a range."
970+
)
971+
end = torch.index_fill(
972+
end,
973+
index=torch.tensor(cursor, device=end.device, dtype=torch.long),
974+
dim=0,
975+
value=1,
976+
)
977+
if not end.any(0).all():
978+
mask = ~end.any(0, True)
979+
mask = torch.cat([torch.zeros_like(end[:-1]), mask])
980+
end = torch.masked_fill(mask, end, 1)
963981
ndim = end.ndim
964982
if ndim == 0:
965983
raise RuntimeError(
@@ -994,7 +1012,7 @@ def _end_to_start_stop(end, length):
9941012
# In this case we have only one start and stop has already been set
9951013
pass
9961014
lengths = stop_idx[:, 0] - start_idx[:, 0] + 1
997-
lengths[lengths < 0] = lengths[lengths < 0] + length
1015+
lengths[lengths <= 0] = lengths[lengths <= 0] + length
9981016
return start_idx, stop_idx, lengths
9991017

10001018
def _start_to_end(self, st: torch.Tensor, length: int):
@@ -1072,7 +1090,9 @@ def _get_stop_and_length(self, storage, fallback=True):
10721090
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
10731091
)
10741092
vals = self._find_start_stop_traj(
1075-
end=done.squeeze()[: len(storage)], at_capacity=storage._is_full
1093+
end=done.squeeze()[: len(storage)],
1094+
at_capacity=storage._is_full,
1095+
cursor=getattr(storage, "_last_cursor", None),
10761096
)
10771097
if self.cache_values:
10781098
self._cache["stop-and-length"] = vals
@@ -1270,7 +1290,6 @@ def _get_index(
12701290
],
12711291
1,
12721292
)
1273-
12741293
index = self._tensor_slices_from_startend(seq_length, starts, storage_length)
12751294
if self.truncated_key is not None:
12761295
truncated_key = self.truncated_key

torchrl/envs/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2609,9 +2609,10 @@ def rollout(
26092609
for key in self.done_keys:
26102610
if _ends_with(key, "truncated"):
26112611
val = out_td.get(("next", key))
2612+
done = out_td.get(("next", _replace_last(key, "done")))
26122613
val[(slice(None),) * (out_td.ndim - 1) + (-1,)] = True
26132614
out_td.set(("next", key), val)
2614-
out_td.set(("next", _replace_last(key, "done")), val)
2615+
out_td.set(("next", _replace_last(key, "done")), val | done)
26152616
found_truncated = True
26162617
if not found_truncated:
26172618
raise RuntimeError(

torchrl/envs/transforms/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3048,7 +3048,7 @@ def unfold_done(done, N):
30483048
reset_unfold_list = [torch.zeros_like(reset_unfold_slice)]
30493049
for r in reversed(reset_unfold.unbind(-1)):
30503050
reset_unfold_list.append(r | reset_unfold_list[-1])
3051-
reset_unfold_slice = reset_unfold_list[-1]
3051+
# reset_unfold_slice = reset_unfold_list[-1]
30523052
reset_unfold = torch.stack(list(reversed(reset_unfold_list))[1:], -1)
30533053
reset = reset[prefix + (slice(self.N - 1, None),)]
30543054
reset[prefix + (0,)] = 1

0 commit comments

Comments
 (0)