Commit 609c72e

[BugFix] Various device fix (#558)

1 parent a243cdd commit 609c72e

8 files changed, +49 -38 lines changed

test/test_exploration.py

Lines changed: 2 additions & 2 deletions
@@ -116,10 +116,10 @@ def test_additivegaussian_sd(
         net,
         in_keys=["observation"],
         out_keys=["loc", "scale"],
-        spec=CompositeSpec(action=action_spec) if spec_origin == "policy" else None,
+        spec=None,
     )
     policy = ProbabilisticActor(
-        spec=action_spec if spec_origin is not None else None,
+        spec=CompositeSpec(action=action_spec) if spec_origin is not None else None,
         module=module,
         dist_param_keys=["loc", "scale"],
         distribution_class=TanhNormal,
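
The convention this fixes: ProbabilisticActor now expects its spec wrapped in a CompositeSpec keyed by the out_key, rather than a bare leaf spec. A minimal sketch of the wrapping, assuming NdBoundedTensorSpec as the leaf type (the concrete spec class is not shown in this diff):

# Hypothetical sketch; NdBoundedTensorSpec and its arguments are assumed, not taken from this commit.
from torchrl.data import CompositeSpec, NdBoundedTensorSpec

action_spec = NdBoundedTensorSpec(-1.0, 1.0, shape=(4,))  # assumed leaf spec
spec = CompositeSpec(action=action_spec)  # keyed by the actor's "action" out_key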

test/test_helpers.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
     DiscreteActionVecMockEnv,
 )
 from torchrl.envs.libs.gym import _has_gym
+from torchrl.envs.transforms.transforms import _has_tv
 from torchrl.envs.utils import set_exploration_mode
 from torchrl.trainers.helpers import transformed_env_constructor
 from torchrl.trainers.helpers.envs import EnvConfig
@@ -55,6 +56,7 @@ def _assert_keys_match(td, expeceted_keys):


 @pytest.mark.skipif(not _has_gym, reason="No gym library found")
+@pytest.mark.skipif(not _has_tv, reason="No torchvision library found")
 @pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("noisy", [tuple(), ("noisy=True",)])

test/test_tensordictmodules.py

Lines changed: 3 additions & 1 deletion
@@ -106,7 +106,9 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys)
         else:
             raise NotImplementedError
         spec = (
-            CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None
+            CompositeSpec(out=spec, **{out_key: None for out_key in out_keys})
+            if spec is not None
+            else None
         )

         kwargs = {"distribution_class": TanhNormal}

torchrl/collectors/collectors.py

Lines changed: 3 additions & 1 deletion
@@ -355,7 +355,9 @@ def __init__(
         # otherwise, we perform a small number of steps with the policy to
         # determine the relevant keys with which to pre-populate _tensordict_out.
         # See #505 for additional context.
-        self._tensordict_out = env.rollout(3, policy)
+        self._tensordict_out = self.env.rollout(
+            3, self.policy, auto_cast_to_device=True
+        )
         if env.batch_size:
             self._tensordict_out = self._tensordict_out[..., :1]
         else:
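
The warm-up rollout now runs through self.env and self.policy with auto_cast_to_device=True, so tensordicts are moved to the policy's device before each policy call and back to the env's device before each step; this is what fixes the crash when the two live on different devices. A rough sketch of that casting loop (not the actual rollout implementation; the device arguments are assumptions):

# Rough sketch of auto-casting during rollout; not the real env.rollout code.
def rollout_step(env, policy, tensordict, policy_device, env_device):
    tensordict = tensordict.to(policy_device)  # move inputs to the policy's device
    tensordict = policy(tensordict)  # run the policy where its weights live
    tensordict = tensordict.to(env_device)  # move the action back for the env
    return env.step(tensordict)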

torchrl/data/tensor_specs.py

Lines changed: 15 additions & 3 deletions
@@ -7,6 +7,7 @@

 import abc
 import os
+from copy import deepcopy
 from dataclasses import dataclass
 from textwrap import indent
 from typing import (
@@ -1066,6 +1067,11 @@ class CompositeSpec(TensorSpec):

     domain: str = "composite"

+    @classmethod
+    def __new__(cls, *args, **kwargs):
+        cls._device = torch.device("cpu")
+        return super().__new__(cls)
+
     def __init__(self, **kwargs):
         self._specs = kwargs
         if len(kwargs):
@@ -1205,11 +1211,11 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
                 "Only device casting is allowed with specs of type CompositeSpec."
             )

-        for value in self.values():
+        self.device = torch.device(dest)
+        for key, value in list(self.items()):
             if value is None:
                 continue
-            value.to(dest)
-        self.device = torch.device(dest)
+            self[key] = value.to(dest)
         return self

     def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
@@ -1230,3 +1236,9 @@ def __eq__(self, other):
             and self._device == other._device
             and self._specs == other._specs
         )
+
+    def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None:
+        for key, item in dict_or_spec.items():
+            if isinstance(item, TensorSpec) and item.device != self.device:
+                item = deepcopy(item).to(self.device)
+            self[key] = item
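
Two behavioral changes are worth noting: to(dest) now records the target device before casting and re-assigns each cast leaf through __setitem__ (the old loop discarded the result of value.to(dest)), and the new update() deep-copies any incoming spec whose device disagrees with the composite's, leaving the donor untouched. A hedged usage sketch, with the leaf spec type assumed:

# Usage sketch; NdBoundedTensorSpec is an assumed leaf spec type.
import torch
from torchrl.data import CompositeSpec, NdBoundedTensorSpec

spec = CompositeSpec(action=NdBoundedTensorSpec(-1.0, 1.0, shape=(4,)))
spec = spec.to(torch.device("cpu"))  # device recorded first, then every leaf cast and re-assigned

donor = CompositeSpec(observation=None)
spec.update(donor)  # None entries pass through; device-mismatched specs are deep-copied, then cast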

torchrl/modules/tensordict_module/common.py

Lines changed: 13 additions & 11 deletions
@@ -196,12 +196,20 @@ def __init__(
                 "Consider using a CompositeSpec object or no spec at all."
             )
             spec = CompositeSpec(**{self.out_keys[0]: spec})
-        if spec and len(spec) < len(self.out_keys):
+        elif spec is None:
+            spec = CompositeSpec()
+
+        if set(spec.keys()) != set(self.out_keys):
             # then assume that all the non indicated specs are None
             for key in self.out_keys:
                 if key not in spec:
                     spec[key] = None

+        if set(spec.keys()) != set(self.out_keys):
+            raise RuntimeError(
+                f"spec keys and out_keys do not match, got: {spec.keys()} and {self.out_keys} respectively"
+            )
+
         self._spec = spec
         self.safe = safe
         if safe:
@@ -217,12 +225,6 @@ def __init__(

         self.module = module

-    def __setattr__(self, key: str, attribute: Any) -> None:
-        if key == "spec" and isinstance(attribute, TensorSpec):
-            self._spec = attribute
-            return
-        super().__setattr__(key, attribute)
-
     @property
     def is_functional(self):
         return isinstance(
@@ -231,14 +233,14 @@ def is_functional(self):
         )

     @property
-    def spec(self) -> TensorSpec:
+    def spec(self) -> CompositeSpec:
         return self._spec

     @spec.setter
-    def spec(self, spec: TensorSpec) -> None:
-        if not isinstance(spec, TensorSpec):
+    def spec(self, spec: CompositeSpec) -> None:
+        if not isinstance(spec, CompositeSpec):
             raise RuntimeError(
-                f"Trying to set an object of type {type(spec)} as a tensorspec."
+                f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance."
             )
         self._spec = spec
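
With the __setattr__ shortcut removed, assignments to .spec always route through the property setter, which now only accepts CompositeSpec instances. A small sketch of the resulting contract (module stands for any TensorDictModule instance, and the leaf spec type is assumed):

# Sketch of the stricter setter contract; `module` and the spec type are assumed.
from torchrl.data import CompositeSpec, NdBoundedTensorSpec

leaf = NdBoundedTensorSpec(-1.0, 1.0, shape=(4,))
try:
    module.spec = leaf  # bare leaf specs are now rejected
except RuntimeError:
    module.spec = CompositeSpec(action=leaf)  # wrap in a CompositeSpec instead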

torchrl/modules/tensordict_module/sequence.py

Lines changed: 9 additions & 19 deletions
@@ -24,7 +24,7 @@
 import torch
 from torch import Tensor, nn

-from torchrl.data import CompositeSpec, TensorSpec
+from torchrl.data import CompositeSpec
 from torchrl.data.tensordict.tensordict import (
     LazyStackedTensorDict,
     TensorDict,
@@ -137,12 +137,19 @@ def __init__(
     ):
         in_keys, out_keys = self._compute_in_and_out_keys(modules)

+        spec = CompositeSpec()
+        for module in modules:
+            if isinstance(module, TensorDictModule) or hasattr(module, "spec"):
+                spec.update(module.spec)
+            else:
+                spec.update(CompositeSpec(**{key: None for key in module.out_keys}))
         super().__init__(
-            spec=None,
+            spec=spec,
             module=nn.ModuleList(list(modules)),
             in_keys=in_keys,
             out_keys=out_keys,
         )
+
         self.partial_tolerant = partial_tolerant

     def _compute_in_and_out_keys(self, modules: List[TensorDictModule]) -> Tuple[List]:
@@ -369,23 +376,6 @@ def __setitem__(self, index: int, tensordict_module: TensorDictModule) -> None:
     def __delitem__(self, index: Union[int, slice]) -> None:
         self.module.__delitem__(idx=index)

-    @property
-    def spec(self):
-        kwargs = {}
-        for layer in self.module:
-            out_key = layer.out_keys[0]
-            spec = layer.spec
-            if spec is not None and not isinstance(spec, TensorSpec):
-                raise RuntimeError(
-                    f"TensorDictSequential.spec requires all specs to be valid TensorSpec objects. Got "
-                    f"{type(layer.spec)}"
-                )
-            if isinstance(spec, CompositeSpec):
-                kwargs.update(spec._specs)
-            else:
-                kwargs[out_key] = spec
-        return CompositeSpec(**kwargs)
-
     def make_functional_with_buffers(self, clone: bool = True, native: bool = False):
         """
         Transforms a stateful module in a functional module and returns its parameters and buffers.
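
Spec aggregation thus moves from a lazy property to construction time: the sequence builds a single CompositeSpec up front by update()-ing each submodule's spec, falling back to None entries for modules that only expose out_keys. A condensed sketch of that merge, written as a hypothetical helper mirroring the constructor loop added above:

# Hypothetical helper; mirrors the new constructor loop, not an actual torchrl function.
from torchrl.data import CompositeSpec

def merge_specs(modules):
    spec = CompositeSpec()
    for module in modules:
        if hasattr(module, "spec"):
            spec.update(module.spec)  # uses CompositeSpec.update from this same commit
        else:
            spec.update(CompositeSpec(**{key: None for key in module.out_keys}))
    return spec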

torchrl/trainers/loggers/wandb.py

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ def __init__(
             "name": exp_name,
             "dir": save_dir,
             "id": id,
-            "project": project,
+            "project": "torchrl-private",
+            "entity": "vmoens",
             "resume": "allow",
             **kwargs,
         }
