@@ -5151,6 +5151,8 @@ def _reset(
5151
5151
step_count = tensordict .get (step_count_key , default = None )
5152
5152
if step_count is None :
5153
5153
step_count = self .container .observation_spec [step_count_key ].zero ()
5154
+ if step_count .device != reset .device :
5155
+ step_count = step_count .to (reset .device , non_blocking = True )
5154
5156
5155
5157
# zero the step count if reset is needed
5156
5158
step_count = torch .where (~ expand_as_right (reset , step_count ), step_count , 0 )
@@ -6413,7 +6415,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
6413
6415
raise ValueError (
6414
6416
self .SPEC_TYPE_ERROR .format (self .ACCEPTED_SPECS , type (action_spec ))
6415
6417
)
6416
- action_spec .update_mask (mask )
6418
+ action_spec .update_mask (mask . to ( action_spec . device ) )
6417
6419
return tensordict
6418
6420
6419
6421
def _reset (
@@ -6424,7 +6426,10 @@ def _reset(
6424
6426
raise ValueError (
6425
6427
self .SPEC_TYPE_ERROR .format (self .ACCEPTED_SPECS , type (action_spec ))
6426
6428
)
6427
- action_spec .update_mask (tensordict .get (self .in_keys [1 ], None ))
6429
+ mask = tensordict .get (self .in_keys [1 ], None )
6430
+ if mask is not None :
6431
+ mask = mask .to (action_spec .device )
6432
+ action_spec .update_mask (mask )
6428
6433
6429
6434
# TODO: Check that this makes sense
6430
6435
with _set_missing_tolerance (self , True ):
0 commit comments