[Feature] EnvBase.auto_specs_ #2601

Merged 6 commits on Nov 25, 2024

14 changes: 9 additions & 5 deletions test/mocking_classes.py
@@ -1038,11 +1038,13 @@ def _step(
         tensordict: TensorDictBase,
     ) -> TensorDictBase:
         action = tensordict.get(self.action_key)
+        try:
+            device = self.full_action_spec[self.action_key].device
+        except KeyError:
+            device = self.device
         self.count += action.to(
             dtype=torch.int,
-            device=self.full_action_spec[self.action_key].device
-            if self.device is None
-            else self.device,
+            device=device if self.device is None else self.device,
         )
         tensordict = TensorDict(
             source={
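
The try/except above is presumably needed because auto_specs_ first swaps the env's specs for empty Composites (see the new test in test/test_env.py below); the keyed lookup then raises a KeyError. A minimal sketch of that failure mode:

from torchrl.data import Composite

spec = Composite(shape=())
try:
    spec["action"].device  # empty Composite: no "action" entry, raises KeyError
except KeyError:
    device = None  # fall back to the env's own device instead
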
@@ -1275,8 +1277,10 @@ def __init__(
             max_steps = torch.tensor(5)
         if start_val is None:
             start_val = torch.zeros((), dtype=torch.int32)
-        if not max_steps.shape == self.batch_size:
-            raise RuntimeError("batch_size and max_steps shape must match.")
+        if max_steps.shape != self.batch_size:
+            raise RuntimeError(
+                f"batch_size and max_steps shape must match. Got self.batch_size={self.batch_size} and max_steps.shape={max_steps.shape}."
+            )
 
         self.max_steps = max_steps
 
28 changes: 28 additions & 0 deletions test/test_env.py
@@ -3526,6 +3526,34 @@ def test_single_env_spec():
     assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))
 
 
+def test_auto_spec():
+    env = CountingEnv()
+    td = env.reset()
+
+    policy = lambda td, action_spec=env.full_action_spec.clone(): td.update(
+        action_spec.rand()
+    )
+
+    env.full_observation_spec = Composite(
+        shape=env.full_observation_spec.shape, device=env.full_observation_spec.device
+    )
+    env.full_action_spec = Composite(
+        shape=env.full_action_spec.shape, device=env.full_action_spec.device
+    )
+    env.full_reward_spec = Composite(
+        shape=env.full_reward_spec.shape, device=env.full_reward_spec.device
+    )
+    env.full_done_spec = Composite(
+        shape=env.full_done_spec.shape, device=env.full_done_spec.device
+    )
+    env.full_state_spec = Composite(
+        shape=env.full_state_spec.shape, device=env.full_state_spec.device
+    )
+    env._action_keys = ["action"]
+    env.auto_specs_(policy, tensordict=td.copy())
+    env.check_env_specs(tensordict=td.copy())
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
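
The test above wipes every spec down to an empty Composite, then lets auto_specs_ re-infer them from a policy rollout. A rough sketch of the inference idea, assuming Unbounded leaf specs and ignoring device handling (this is not TorchRL's actual implementation; infer_spec_from_td is a hypothetical helper for illustration):

import torch
from torchrl.data import Composite, Unbounded

def infer_spec_from_td(td):
    # Mirror each tensor leaf of the tensordict with an Unbounded spec
    # of the same shape and dtype.
    return Composite(
        {
            key: Unbounded(shape=val.shape, dtype=val.dtype)
            for key, val in td.items()
            if isinstance(val, torch.Tensor)
        },
        shape=td.batch_size,
    )

The real method presumably also splits the inferred entries into action/observation/reward/done groups using the env's key sets; note the test pins env._action_keys = ["action"] before the call.
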
9 changes: 9 additions & 0 deletions test/test_specs.py
@@ -412,6 +412,15 @@ def test_getitem(self, shape, is_complete, device, dtype):
         with pytest.raises(KeyError):
             _ = ts["UNK"]
 
+    def test_setitem_newshape(self, shape, is_complete, device, dtype):
+        ts = self._composite_spec(shape, is_complete, device, dtype)
+        new_spec = ts.clone()
+        new_spec.shape = torch.Size(())
+        new_spec.clear_device_()
+        ts["new_spec"] = new_spec
+        assert ts["new_spec"].shape == ts.shape
+        assert ts["new_spec"].device == ts.device
+
     def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype):
         ts = self._composite_spec(shape, is_complete, device, dtype)
         for key in {"shape", "device", "dtype", "space"}:
25 changes: 20 additions & 5 deletions torchrl/data/tensor_specs.py
@@ -4372,11 +4372,20 @@ def set(self, name, spec):
         if spec is not None:
             shape = spec.shape
             if shape[: self.ndim] != self.shape:
-                raise ValueError(
-                    "The shape of the spec and the Composite mismatch: the first "
-                    f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
-                    f"Composite.shape={self.shape}."
-                )
+                if (
+                    isinstance(spec, Composite)
+                    and spec.ndim < self.ndim
+                    and self.shape[: spec.ndim] == spec.shape
+                ):
+                    # Try to set the composite shape
+                    spec = spec.clone()
+                    spec.shape = self.shape
+                else:
+                    raise ValueError(
+                        "The shape of the spec and the Composite mismatch: the first "
+                        f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
+                        f"Composite.shape={self.shape}."
+                    )
         self._specs[name] = spec
 
     def __init__(
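
A quick usage sketch of what the new branch permits (hypothetical example, mirroring the behavior that test_setitem_newshape checks): a sub-Composite with fewer, matching leading dimensions is now expanded to the parent's shape instead of raising:

from torchrl.data import Composite

outer = Composite(shape=(2, 3))
inner = Composite(shape=(2,))  # fewer dims; leading dim matches outer's

outer["sub"] = inner  # previously a ValueError; now a clone is reshaped
assert outer["sub"].shape == (2, 3)  # expanded to the parent's shape
assert inner.shape == (2,)  # the provided spec itself is untouched
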
@@ -4448,6 +4457,8 @@ def clear_device_(self):
         """Clears the device of the Composite."""
         self._device = None
         for spec in self._specs.values():
+            if spec is None:
+                continue
             spec.clear_device_()
         return self

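Presumably this guards against the None placeholders that a Composite can hold for not-yet-defined entries; a minimal sketch of the case being fixed (assuming None entries are accepted at construction):

from torchrl.data import Composite

spec = Composite(action=None, shape=())  # None marks an entry with no spec yet
spec.clear_device_()  # now skips the None entry instead of raising on NoneType
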
@@ -4530,6 +4541,10 @@ def __setitem__(self, key, value):
             and value.device != self.device
         ):
             if isinstance(value, Composite) and value.device is None:
+                # We make a clone not to mess up the spec that was provided.
+                # in set() we do the same for shape - these two ops should be grouped.
+                # we don't care about the overhead of cloning twice though because in theory
+                # we don't set specs often.
                 value = value.clone().to(self.device)
             else:
                 raise RuntimeError(
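
Together with the shape handling in set(), this means a device-less sub-Composite is adapted on assignment while the caller's object stays pristine; a small sketch (hypothetical example):

import torch
from torchrl.data import Composite

parent = Composite(shape=(), device="cpu")
child = Composite(shape=())  # device=None

parent["sub"] = child  # a clone is cast to the parent's device
assert parent["sub"].device == torch.device("cpu")
assert child.device is None  # the provided spec was not moved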