Commit d171847

Author: Vincent Moens

[Refactor] Use empty_like in storage construction

ghstack-source-id: 28cd569
Pull Request resolved: #2455
1 parent ca3a595 · commit d171847

File tree

torchrl/data/replay_buffers/samplers.py
torchrl/data/replay_buffers/storages.py

2 files changed: +11 -16 lines

torchrl/data/replay_buffers/samplers.py

Lines changed: 10 additions & 13 deletions
@@ -982,22 +982,20 @@ def _find_start_stop_traj(
 
             # faster
             end = trajectory[:-1] != trajectory[1:]
-            end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
+            if not at_capacity:
+                end = torch.cat([end, torch.ones_like(end[:1])], 0)
+            else:
+                end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
             length = trajectory.shape[0]
         else:
-            # TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True
-
             # We presume that not done at the end means that the traj spans across end and beginning of storage
             length = end.shape[0]
+            if not at_capacity:
+                end = end.clone()
+                end[length - 1] = True
+        ndim = end.ndim
 
-        if not at_capacity:
-            end = torch.index_fill(
-                end,
-                index=torch.tensor(-1, device=end.device, dtype=torch.long),
-                dim=0,
-                value=1,
-            )
-        else:
+        if at_capacity:
             # we must have at least one end by traj to individuate trajectories
             # so if no end can be found we set it manually
             if cursor is not None:
@@ -1019,7 +1017,6 @@ def _find_start_stop_traj(
                 mask = ~end.any(0, True)
                 mask = torch.cat([torch.zeros_like(end[:-1]), mask])
                 end = torch.masked_fill(mask, end, 1)
-        ndim = end.ndim
         if ndim == 0:
             raise RuntimeError(
                 "Expected the end-of-trajectory signal to be at least 1-dimensional."
@@ -1126,7 +1123,7 @@ def _get_stop_and_length(self, storage, fallback=True):
                     "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
                 )
             vals = self._find_start_stop_traj(
-                trajectory=trajectory,
+                trajectory=trajectory.clone(),
                 at_capacity=storage._is_full,
                 cursor=getattr(storage, "_last_cursor", None),
            )
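
The new at_capacity branch above changes how the last storage slot is treated: while the buffer is not yet full, the most recently written element always closes a trajectory, whereas in a full (circular) buffer the last slot only closes a trajectory if its trajectory id differs from the one stored at index 0. A minimal sketch of that idea on a plain trajectory-id tensor (standalone illustration only, not the library function; find_traj_ends and the sample tensor are made up here):

import torch

def find_traj_ends(trajectory: torch.Tensor, at_capacity: bool) -> torch.Tensor:
    # A position is an "end" when the trajectory id differs from the next one.
    end = trajectory[:-1] != trajectory[1:]
    if not at_capacity:
        # Buffer not yet full: the last written element always closes a trajectory.
        end = torch.cat([end, torch.ones_like(end[:1])], 0)
    else:
        # Buffer full (circular): the last slot only closes a trajectory if its id
        # differs from the first slot, i.e. the trajectory does not wrap around.
        end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
    return end

traj = torch.tensor([2, 0, 0, 1, 1, 2])         # trajectory 2 wraps around the buffer
print(find_traj_ends(traj, at_capacity=False))  # last slot forced to True
print(find_traj_ends(traj, at_capacity=True))   # last slot False: traj 2 continues at index 0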

torchrl/data/replay_buffers/storages.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -901,9 +901,7 @@ def max_size_along_dim0(data_shape):
@@ -901,9 +901,7 @@ def max_size_along_dim0(data_shape):
 
         if is_tensor_collection(data):
             out = data.to(self.device)
-            out = out.expand(max_size_along_dim0(data.shape))
-            out = out.clone()
-            out = out.zero_()
+            out = torch.empty_like(out.expand(max_size_along_dim0(data.shape)))
         else:
             # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
             out = tree_map(
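
The storages.py change collapses the expand -> clone -> zero_ sequence into a single torch.empty_like call on the expanded view, so the full-capacity buffer is allocated without an extra zero-fill pass (its contents are overwritten as data is written anyway). A rough sketch of the same idea on a plain tensor (the actual code operates on a tensor collection inside the storage's _init; data and max_size below are made-up placeholders):

import torch

data = torch.randn(4, 3)   # one sample worth of data (placeholder)
max_size = 100             # hypothetical buffer capacity

# Before: expand to the full buffer shape, materialize with clone, then zero it.
old = data.expand(max_size, 4, 3).clone().zero_()

# After: empty_like on the expanded view allocates a buffer with the same
# shape/dtype/device but skips the zero-fill; contents stay uninitialized
# until real data is written.
new = torch.empty_like(data.expand(max_size, 4, 3))

assert old.shape == new.shape == (max_size, 4, 3)
assert (old == 0).all()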
