Skip to content

Commit 5f82601

Browse files
author
Vincent Moens
authored
[BugFix] better device consistency in EGreedy (#1867)
1 parent 0672359 commit 5f82601

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

torchrl/modules/tensordict_module/exploration.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
149149

150150
out = action_tensordict.get(action_key)
151151
eps = self.eps.item()
152-
cond = (
153-
torch.rand(action_tensordict.shape, device=action_tensordict.device)
154-
< eps
155-
).to(out.dtype)
152+
cond = torch.rand(action_tensordict.shape, device=out.device) < eps
156153
cond = expand_as_right(cond, out)
157154
spec = self.spec
158155
if spec is not None:
@@ -177,7 +174,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
177174
f"Action mask key {self.action_mask_key} not found in {tensordict}."
178175
)
179176
spec.update_mask(action_mask)
180-
out = cond * spec.rand().to(out.device) + (1 - cond) * out
177+
out = torch.where(cond, spec.rand().to(out.device), out)
181178
else:
182179
raise RuntimeError("spec must be provided to the exploration wrapper.")
183180
action_tensordict.set(action_key, out)

0 commit comments

Comments
 (0)