Skip to content

Commit 01e00b5

Browse files
committed
Feature: no-op for split_trajectory when there is no need for it (#216)
1 parent 87ea43d commit 01e00b5

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

torchrl/collectors/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,25 @@ def split_trajectories(rollout_tensordict: _TensorDict) -> _TensorDict:
3535
ndim = len(rollout_tensordict.batch_size)
3636
splits = traj_ids.view(-1)
3737
splits = [(splits == i).sum().item() for i in splits.unique_consecutive()]
38+
# if all splits are identical then we can skip this function
39+
if len(set(splits)) == 1:
40+
rollout_tensordict.set(
41+
"mask",
42+
torch.ones(
43+
rollout_tensordict.shape,
44+
device=rollout_tensordict.device,
45+
dtype=torch.bool,
46+
),
47+
)
48+
if rollout_tensordict.ndimension() == 1:
49+
rollout_tensordict = rollout_tensordict.unsqueeze(0).to_tensordict()
50+
return rollout_tensordict
3851
out_splits = {
3952
key: _d.contiguous().view(-1, *_d.shape[ndim:]).split(splits, 0)
4053
for key, _d in rollout_tensordict.items()
4154
# if key not in ("step_count", "traj_ids")
4255
}
56+
4357
# select complete rollouts
4458
dones = out_splits["done"]
4559
valid_ids = list(range(len(dones)))

0 commit comments

Comments
 (0)