Skip to content

Commit 55ec016

Browse files
author
Vincent Moens
authored
[BugFix] Fix device of container generated values in transforms (#1827)
1 parent 3f04131 commit 55ec016

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

torchrl/envs/transforms/transforms.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5151,6 +5151,8 @@ def _reset(
51515151
step_count = tensordict.get(step_count_key, default=None)
51525152
if step_count is None:
51535153
step_count = self.container.observation_spec[step_count_key].zero()
5154+
if step_count.device != reset.device:
5155+
step_count = step_count.to(reset.device, non_blocking=True)
51545156

51555157
# zero the step count if reset is needed
51565158
step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0)
@@ -6413,7 +6415,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
64136415
raise ValueError(
64146416
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
64156417
)
6416-
action_spec.update_mask(mask)
6418+
action_spec.update_mask(mask.to(action_spec.device))
64176419
return tensordict
64186420

64196421
def _reset(
@@ -6424,7 +6426,10 @@ def _reset(
64246426
raise ValueError(
64256427
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
64266428
)
6427-
action_spec.update_mask(tensordict.get(self.in_keys[1], None))
6429+
mask = tensordict.get(self.in_keys[1], None)
6430+
if mask is not None:
6431+
mask = mask.to(action_spec.device)
6432+
action_spec.update_mask(mask)
64286433

64296434
# TODO: Check that this makes sense
64306435
with _set_missing_tolerance(self, True):

0 commit comments

Comments
 (0)