@@ -35,11 +35,29 @@ def split_trajectories(rollout_tensordict: _TensorDict) -> _TensorDict:
35
35
ndim = len (rollout_tensordict .batch_size )
36
36
splits = traj_ids .view (- 1 )
37
37
splits = [(splits == i ).sum ().item () for i in splits .unique_consecutive ()]
38
- out_splits = {
39
- key : _d .contiguous ().view (- 1 , * _d .shape [ndim :]).split (splits , 0 )
40
- for key , _d in rollout_tensordict .items ()
41
- # if key not in ("step_count", "traj_ids")
42
- }
38
+ # if all splits are identical then we can skip this function
39
+ if len (set (splits )) == 1 :
40
+ rollout_tensordict .set (
41
+ "mask" ,
42
+ torch .ones (
43
+ rollout_tensordict .shape ,
44
+ device = rollout_tensordict .device ,
45
+ dtype = torch .bool ,
46
+ ),
47
+ )
48
+ if rollout_tensordict .ndimension () == 1 :
49
+ rollout_tensordict = rollout_tensordict .unsqueeze (0 ).to_tensordict ()
50
+ return rollout_tensordict
51
+ try :
52
+ out_splits = {
53
+ key : _d .contiguous ().view (- 1 , * _d .shape [ndim :]).split (splits , 0 )
54
+ for key , _d in rollout_tensordict .items ()
55
+ # if key not in ("step_count", "traj_ids")
56
+ }
57
+ except RuntimeError as err :
58
+ torch .save ({"td" : rollout_tensordict , "err" : err }, "dump.t" )
59
+ raise err
60
+
43
61
# select complete rollouts
44
62
dones = out_splits ["done" ]
45
63
valid_ids = list (range (len (dones )))
0 commit comments