Skip to content

Commit b8be282

Browse files
authored
Feature: no-op for split_trajectory when there is no need for it (#216)
1 parent 87ea43d commit b8be282

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

torchrl/collectors/utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,29 @@ def split_trajectories(rollout_tensordict: _TensorDict) -> _TensorDict:
3535
ndim = len(rollout_tensordict.batch_size)
3636
splits = traj_ids.view(-1)
3737
splits = [(splits == i).sum().item() for i in splits.unique_consecutive()]
38-
out_splits = {
39-
key: _d.contiguous().view(-1, *_d.shape[ndim:]).split(splits, 0)
40-
for key, _d in rollout_tensordict.items()
41-
# if key not in ("step_count", "traj_ids")
42-
}
38+
# if all splits are identical then we can skip this function
39+
if len(set(splits)) == 1:
40+
rollout_tensordict.set(
41+
"mask",
42+
torch.ones(
43+
rollout_tensordict.shape,
44+
device=rollout_tensordict.device,
45+
dtype=torch.bool,
46+
),
47+
)
48+
if rollout_tensordict.ndimension() == 1:
49+
rollout_tensordict = rollout_tensordict.unsqueeze(0).to_tensordict()
50+
return rollout_tensordict
51+
try:
52+
out_splits = {
53+
key: _d.contiguous().view(-1, *_d.shape[ndim:]).split(splits, 0)
54+
for key, _d in rollout_tensordict.items()
55+
# if key not in ("step_count", "traj_ids")
56+
}
57+
except RuntimeError as err:
58+
torch.save({"td": rollout_tensordict, "err": err}, "dump.t")
59+
raise err
60+
4361
# select complete rollouts
4462
dones = out_splits["done"]
4563
valid_ids = list(range(len(dones)))

0 commit comments

Comments
 (0)