[BugFix] Various device fix #558

Merged: 10 commits, Oct 12, 2022
4 changes: 2 additions & 2 deletions test/test_exploration.py
@@ -116,10 +116,10 @@ def test_additivegaussian_sd(
net,
in_keys=["observation"],
out_keys=["loc", "scale"],
- spec=CompositeSpec(action=action_spec) if spec_origin == "policy" else None,
+ spec=None,
)
policy = ProbabilisticActor(
- spec=action_spec if spec_origin is not None else None,
+ spec=CompositeSpec(action=action_spec) if spec_origin is not None else None,
module=module,
dist_param_keys=["loc", "scale"],
distribution_class=TanhNormal,
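The two edits above move the spec to the outer wrapper: the inner module no longer receives it, and the `ProbabilisticActor` now expects a `CompositeSpec` keyed by the entry it writes (`action`) rather than a bare leaf spec. A minimal sketch of the new convention, using `NdBoundedTensorSpec` as an illustrative leaf spec (class name assumed from torchrl of this era):

```python
import torch
from torchrl.data import CompositeSpec, NdBoundedTensorSpec

# Leaf spec describing a 7-dimensional bounded action.
action_spec = NdBoundedTensorSpec(-torch.ones(7), torch.ones(7), (7,))

# New convention: wrap the leaf spec under the out_key the actor writes to.
actor_spec = CompositeSpec(action=action_spec)
```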
2 changes: 2 additions & 0 deletions test/test_helpers.py
@@ -24,6 +24,7 @@
DiscreteActionVecMockEnv,
)
from torchrl.envs.libs.gym import _has_gym
+ from torchrl.envs.transforms.transforms import _has_tv
from torchrl.envs.utils import set_exploration_mode
from torchrl.trainers.helpers import transformed_env_constructor
from torchrl.trainers.helpers.envs import EnvConfig
@@ -55,6 +56,7 @@ def _assert_keys_match(td, expeceted_keys):


@pytest.mark.skipif(not _has_gym, reason="No gym library found")
+ @pytest.mark.skipif(not _has_tv, reason="No torchvision library found")
@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("noisy", [tuple(), ("noisy=True",)])
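`_has_tv` is the same style of availability flag as `_has_gym` and `_has_hydra`: a module-level boolean probed at import time, used to skip tests when the dependency is absent. The real definition lives in `torchrl.envs.transforms.transforms`; a typical implementation of such a flag (an assumption, not a copy of the source) looks like:

```python
import importlib.util

# True when torchvision is importable; tests guarded by the skipif above
# are silently skipped otherwise.
_has_tv = importlib.util.find_spec("torchvision") is not None
```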
4 changes: 3 additions & 1 deletion test/test_tensordictmodules.py
@@ -106,7 +106,9 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys)
else:
raise NotImplementedError
spec = (
- CompositeSpec(out=spec, loc=None, scale=None) if spec is not None else None
+ CompositeSpec(out=spec, **{out_key: None for out_key in out_keys})
+ if spec is not None
+ else None
)

kwargs = {"distribution_class": TanhNormal}
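The rewrite replaces the hardcoded `loc=None, scale=None` placeholders with a comprehension over the parametrized `out_keys`, so the test keeps working whatever key names the parametrization picks. A standalone sketch (leaf spec class assumed):

```python
from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec

out_keys = ["loc_2", "scale_2"]  # any parametrization of the output keys
leaf = NdUnboundedContinuousTensorSpec(4)

# One None placeholder per out_key instead of hardcoded loc/scale entries.
spec = CompositeSpec(out=leaf, **{out_key: None for out_key in out_keys})
```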
4 changes: 3 additions & 1 deletion torchrl/collectors/collectors.py
@@ -355,7 +355,9 @@ def __init__(
# otherwise, we perform a small number of steps with the policy to
# determine the relevant keys with which to pre-populate _tensordict_out.
# See #505 for additional context.
- self._tensordict_out = env.rollout(3, policy)
+ self._tensordict_out = self.env.rollout(
+ 3, self.policy, auto_cast_to_device=True
+ )
if env.batch_size:
self._tensordict_out = self._tensordict_out[..., :1]
else:
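Two fixes in one statement: the warm-up rollout now uses the collector's own (possibly device-mapped) `self.env` and `self.policy` instead of the raw constructor arguments, and `auto_cast_to_device=True` lets the rollout succeed when the policy's parameters live on a different device than the env's tensordicts. A usage sketch of the flag outside the collector (env name, network shape, and device layout are illustrative):

```python
import torch
import torch.nn as nn
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import TensorDictModule

device = "cuda:0" if torch.cuda.is_available() else "cpu"
env = GymEnv("Pendulum-v1", device="cpu")   # env-side data stays on cpu
policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
).to(device)                                 # policy weights possibly elsewhere

# The tensordict is cast to the policy's device before each policy call and
# back to the env's device before each env.step().
td = env.rollout(3, policy, auto_cast_to_device=True)
```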
18 changes: 15 additions & 3 deletions torchrl/data/tensor_specs.py
@@ -7,6 +7,7 @@

import abc
import os
+ from copy import deepcopy
from dataclasses import dataclass
from textwrap import indent
from typing import (
@@ -1066,6 +1067,11 @@ class CompositeSpec(TensorSpec):

domain: str = "composite"

+ @classmethod
+ def __new__(cls, *args, **kwargs):
+ cls._device = torch.device("cpu")
+ return super().__new__(cls)

def __init__(self, **kwargs):
self._specs = kwargs
if len(kwargs):
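The `__new__` override gives every `CompositeSpec` a `cpu` device before `__init__` runs, so a spec built empty (or from keyword specs) always has a well-defined device to compare and cast against. A behavior sketch, assuming `CompositeSpec` exposes the device through a `device` attribute backed by `_device`:

```python
import torch
from torchrl.data import CompositeSpec

spec = CompositeSpec()                     # no child specs at all
assert spec.device == torch.device("cpu")  # device is still well-defined
```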
@@ -1205,11 +1211,11 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
"Only device casting is allowed with specs of type CompositeSpec."
)

- for value in self.values():
+ self.device = torch.device(dest)
+ for key, value in list(self.items()):
if value is None:
continue
- value.to(dest)
- self.device = torch.device(dest)
+ self[key] = value.to(dest)
return self

def to_numpy(self, val: TensorDict, safe: bool = True) -> dict:
@@ -1230,3 +1236,9 @@ def __eq__(self, other):
and self._device == other._device
and self._specs == other._specs
)

+ def update(self, dict_or_spec: Union[CompositeSpec, Dict[str, TensorSpec]]) -> None:
+ for key, item in dict_or_spec.items():
+ if isinstance(item, TensorSpec) and item.device != self.device:
+ item = deepcopy(item).to(self.device)
+ self[key] = item
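Together these changes make device handling explicit: `to(dest)` records the destination device first and re-assigns each child through `self[key] = value.to(dest)` rather than relying on in-place casting, and the new `update` deep-copies any incoming spec whose device disagrees, so the composite never holds mixed-device children. A sketch of the resulting behavior (leaf spec class assumed):

```python
from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec

spec = CompositeSpec(
    obs=NdUnboundedContinuousTensorSpec(3),
    reward=NdUnboundedContinuousTensorSpec(1),
)
spec = spec.to("cpu")  # sets spec.device, then re-assigns each child spec

# update() merges another mapping of specs, deep-copying (and casting) any
# entry whose device differs from spec.device.
spec.update(CompositeSpec(action=NdUnboundedContinuousTensorSpec(2)))
assert set(spec.keys()) == {"obs", "reward", "action"}
```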
24 changes: 13 additions & 11 deletions torchrl/modules/tensordict_module/common.py
@@ -196,12 +196,20 @@ def __init__(
"Consider using a CompositeSpec object or no spec at all."
)
spec = CompositeSpec(**{self.out_keys[0]: spec})
- if spec and len(spec) < len(self.out_keys):
+ elif spec is None:
+ spec = CompositeSpec()

+ if set(spec.keys()) != set(self.out_keys):
# then assume that all the non indicated specs are None
for key in self.out_keys:
if key not in spec:
spec[key] = None

+ if set(spec.keys()) != set(self.out_keys):
+ raise RuntimeError(
+ f"spec keys and out_keys do not match, got: {spec.keys()} and {self.out_keys} respectively"
+ )

self._spec = spec
self.safe = safe
if safe:
@@ -217,12 +225,6 @@ def __init__(

self.module = module

- def __setattr__(self, key: str, attribute: Any) -> None:
- if key == "spec" and isinstance(attribute, TensorSpec):
- self._spec = attribute
- return
- super().__setattr__(key, attribute)

@property
def is_functional(self):
return isinstance(
@@ -231,14 +233,14 @@ def is_functional(self):
)

@property
- def spec(self) -> TensorSpec:
+ def spec(self) -> CompositeSpec:
return self._spec

@spec.setter
- def spec(self, spec: TensorSpec) -> None:
- if not isinstance(spec, TensorSpec):
+ def spec(self, spec: CompositeSpec) -> None:
+ if not isinstance(spec, CompositeSpec):
raise RuntimeError(
f"Trying to set an object of type {type(spec)} as a tensorspec."
f"Trying to set an object of type {type(spec)} as a tensorspec but expected a CompositeSpec instance."
)
self._spec = spec

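The net effect is an invariant: `TensorDictModule._spec` is always a `CompositeSpec` with exactly one entry per `out_key` (filled with `None` when unknown), and the property setter rejects anything else, where the deleted `__setattr__` hook used to accept any `TensorSpec` silently. A sketch of the invariant (constructor keywords assumed from the diff's own usage):

```python
import torch.nn as nn
from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec
from torchrl.modules import TensorDictModule

module = TensorDictModule(
    module=nn.Linear(3, 4), in_keys=["obs"], out_keys=["out"]
)
assert isinstance(module.spec, CompositeSpec)  # always composite now
assert set(module.spec.keys()) == {"out"}      # one entry per out_key

module.spec = CompositeSpec(out=NdUnboundedContinuousTensorSpec(4))  # accepted
try:
    module.spec = NdUnboundedContinuousTensorSpec(4)  # bare leaf spec
except RuntimeError:
    pass                                              # rejected by the setter
```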
28 changes: 9 additions & 19 deletions torchrl/modules/tensordict_module/sequence.py
@@ -24,7 +24,7 @@
import torch
from torch import Tensor, nn

- from torchrl.data import CompositeSpec, TensorSpec
+ from torchrl.data import CompositeSpec
from torchrl.data.tensordict.tensordict import (
LazyStackedTensorDict,
TensorDict,
@@ -137,12 +137,19 @@
):
in_keys, out_keys = self._compute_in_and_out_keys(modules)

+ spec = CompositeSpec()
+ for module in modules:
+ if isinstance(module, TensorDictModule) or hasattr(module, "spec"):
+ spec.update(module.spec)
+ else:
+ spec.update(CompositeSpec(**{key: None for key in module.out_keys}))
super().__init__(
- spec=None,
+ spec=spec,
module=nn.ModuleList(list(modules)),
in_keys=in_keys,
out_keys=out_keys,
)

self.partial_tolerant = partial_tolerant

def _compute_in_and_out_keys(self, modules: List[TensorDictModule]) -> Tuple[List]:
@@ -369,23 +376,6 @@ def __setitem__(self, index: int, tensordict_module: TensorDictModule) -> None:
def __delitem__(self, index: Union[int, slice]) -> None:
self.module.__delitem__(idx=index)

- @property
- def spec(self):
- kwargs = {}
- for layer in self.module:
- out_key = layer.out_keys[0]
- spec = layer.spec
- if spec is not None and not isinstance(spec, TensorSpec):
- raise RuntimeError(
- f"TensorDictSequential.spec requires all specs to be valid TensorSpec objects. Got "
- f"{type(layer.spec)}"
- )
- if isinstance(spec, CompositeSpec):
- kwargs.update(spec._specs)
- else:
- kwargs[out_key] = spec
- return CompositeSpec(**kwargs)

def make_functional_with_buffers(self, clone: bool = True, native: bool = False):
"""
Transforms a stateful module in a functional module and returns its parameters and buffers.
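Instead of recomputing a `spec` property on every access (deleted above), the sequence now merges its children's specs once at construction time via the new `CompositeSpec.update`. A standalone toy version of that merge loop, mirroring the code added to `__init__`:

```python
from torchrl.data import CompositeSpec

def merge_child_specs(modules):
    # Modules that carry a spec contribute it; anything else contributes a
    # None placeholder per out_key, matching TensorDictModule's invariant.
    spec = CompositeSpec()
    for module in modules:
        if hasattr(module, "spec"):
            spec.update(module.spec)
        else:
            spec.update(CompositeSpec(**{key: None for key in module.out_keys}))
    return spec
```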
3 changes: 2 additions & 1 deletion torchrl/trainers/loggers/wandb.py
@@ -69,7 +69,8 @@ def __init__(
"name": exp_name,
"dir": save_dir,
"id": id,
"project": project,
"project": "torchrl-private",
"entity": "vmoens",
"resume": "allow",
**kwargs,
}
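Note that this hunk replaces the caller-supplied `project` with a hardcoded project and adds a fixed `entity`, so the constructor's `project` argument is effectively ignored here. For reference, a usage sketch of the logger (method name assumed; adjust to the installed torchrl version):

```python
from torchrl.trainers.loggers.wandb import WandbLogger

# With the hunk above applied, runs land in "torchrl-private" under the
# "vmoens" entity regardless of the project passed in.
logger = WandbLogger(exp_name="demo-run", project="my-project")
logger.log_scalar("reward", 0.0, step=0)
```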