File tree Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Original file line number Diff line number Diff line change @@ -35,11 +35,25 @@ def split_trajectories(rollout_tensordict: _TensorDict) -> _TensorDict:
35
35
ndim = len (rollout_tensordict .batch_size )
36
36
splits = traj_ids .view (- 1 )
37
37
splits = [(splits == i ).sum ().item () for i in splits .unique_consecutive ()]
38
+ # if all splits are identical then we can skip this function
39
+ if len (set (splits )) == 1 :
40
+ rollout_tensordict .set (
41
+ "mask" ,
42
+ torch .ones (
43
+ rollout_tensordict .shape ,
44
+ device = rollout_tensordict .device ,
45
+ dtype = torch .bool ,
46
+ ),
47
+ )
48
+ if rollout_tensordict .ndimension () == 1 :
49
+ rollout_tensordict = rollout_tensordict .unsqueeze (0 ).to_tensordict ()
50
+ return rollout_tensordict
38
51
out_splits = {
39
52
key : _d .contiguous ().view (- 1 , * _d .shape [ndim :]).split (splits , 0 )
40
53
for key , _d in rollout_tensordict .items ()
41
54
# if key not in ("step_count", "traj_ids")
42
55
}
56
+
43
57
# select complete rollouts
44
58
dones = out_splits ["done" ]
45
59
valid_ids = list (range (len (dones )))
You can’t perform that action at this time.
0 commit comments