Commit db9e2af
Author: Vincent Moens

[Feature] TensorSpec.enumerate()

ghstack-source-id: 47c3c22
Pull Request resolved: #2354

1 parent: 6deedec

2 files changed: +161 −6 lines

test/test_specs.py
47 additions & 0 deletions
@@ -3740,6 +3740,53 @@ def test_device_ordinal():
     assert spec.device == torch.device("cuda:0")
 
 
+class TestSpecEnumerate:
+    def test_discrete(self):
+        spec = DiscreteTensorSpec(n=5, shape=(3,))
+        assert (
+            spec.enumerate()
+            == torch.tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
+        ).all()
+
+    def test_one_hot(self):
+        spec = OneHotDiscreteTensorSpec(n=5, shape=(2, 5))
+        assert (
+            spec.enumerate()
+            == torch.tensor(
+                [
+                    [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0]],
+                    [[0, 1, 0, 0, 0], [0, 1, 0, 0, 0]],
+                    [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
+                    [[0, 0, 0, 1, 0], [0, 0, 0, 1, 0]],
+                    [[0, 0, 0, 0, 1], [0, 0, 0, 0, 1]],
+                ],
+                dtype=torch.bool,
+            )
+        ).all()
+
+    def test_multi_discrete(self):
+        spec = MultiDiscreteTensorSpec([3, 4, 5], shape=(2, 3))
+        enum = spec.enumerate()
+        assert enum.shape == torch.Size([60, 2, 3])
+
+    def test_multi_onehot(self):
+        spec = MultiOneHotDiscreteTensorSpec([3, 4, 5], shape=(2, 12))
+        enum = spec.enumerate()
+        assert enum.shape == torch.Size([60, 2, 12])
+
+    def test_composite(self):
+        c = CompositeSpec(
+            {
+                "a": OneHotDiscreteTensorSpec(n=5, shape=(3, 5)),
+                ("b", "c"): DiscreteTensorSpec(n=4, shape=(3,)),
+            },
+            shape=[3],
+        )
+        c_enum = c.enumerate()
+        assert c_enum.shape == torch.Size((20, 3))
+        assert c_enum["b"].shape == torch.Size((20, 3))
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
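Note the leading (enumeration) dimension in each assertion: the single discrete and one-hot specs yield 5 entries (one per category), the multi-discrete and multi-one-hot specs yield 60 = 3 * 4 * 5 joint outcomes, and the composite yields 20 = 5 * 4 combinations of its two leaves.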

torchrl/data/tensor_specs.py
114 additions & 6 deletions
@@ -756,6 +756,16 @@ def contains(self, item):
         """
         return self.is_in(item)
 
+    @abc.abstractmethod
+    def enumerate(self):
+        """Returns all the samples that can be obtained from the TensorSpec.
+
+        The samples will be stacked along the first dimension.
+
+        This method is only implemented for discrete specs.
+        """
+        ...
+
     def project(self, val: torch.Tensor) -> torch.Tensor:
         """If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic.
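In practice the new method is called directly on a spec instance. A minimal sketch, assuming a TorchRL build that includes this commit and the spec names used in the tests above:

import torch
from torchrl.data import DiscreteTensorSpec

spec = DiscreteTensorSpec(n=4, shape=(2,))
samples = spec.enumerate()                 # every possible sample, stacked along dim 0
assert samples.shape == torch.Size([4, 2])
assert spec.is_in(samples[0])              # each stacked entry is a valid sample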
@@ -1152,6 +1162,11 @@ def __eq__(self, other):
                 return False
         return True
 
+    def enumerate(self):
+        return torch.stack(
+            [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
+        )
+
     def __len__(self):
         return self.shape[0]
 
@@ -1601,6 +1616,13 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val
 
+    def enumerate(self):
+        return (
+            torch.eye(self.n, dtype=self.dtype, device=self.device)
+            .expand(*self.shape, self.n)
+            .permute(-2, *range(self.ndimension() - 1), -1)
+        )
+
     def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
         if not isinstance(index, torch.Tensor):
             raise ValueError(
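The one-hot implementation builds an identity matrix with one row per category, broadcasts it over the spec's leading shape, and permutes the enumeration axis to the front. A rough shape walk-through, under the assumption n=5 and spec shape (2, 5) as in test_one_hot above:

import torch

n = 5
eye = torch.eye(n, dtype=torch.bool)   # (5, 5): row k is the k-th one-hot vector
enum = eye.expand(2, n, n)             # (2, 5, 5): repeated over the leading dim of the spec shape
enum = enum.permute(1, 0, 2)           # (5, 2, 5): enumeration axis first, as asserted in the test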
@@ -1832,6 +1854,11 @@ def __init__(
             domain=domain,
         )
 
+    def enumerate(self):
+        raise NotImplementedError(
+            f"enumerate is not implemented for spec of class {type(self).__name__}."
+        )
+
     def __eq__(self, other):
         return (
             type(other) == type(self)
@@ -2107,6 +2134,9 @@ def __init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )
 
+    def enumerate(self):
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
@@ -2273,6 +2303,9 @@ def is_in(self, val: torch.Tensor) -> bool:
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)
 
+    def enumerate(self):
+        raise NotImplementedError("enumerate cannot be called with continuous specs.")
+
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
             shape = shape[0]
@@ -2361,8 +2394,6 @@ class UnboundedDiscreteTensorSpec(TensorSpec):
     (should be an integer dtype such as long, uint8 etc.)
     """
 
-    # SPEC_HANDLED_FUNCTIONS = {}
-
     def __init__(
         self,
         shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
@@ -2409,6 +2440,9 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
             return self
         return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype)
 
+    def enumerate(self):
+        raise NotImplementedError("Cannot enumerate an unbounded tensor spec.")
+
     def clone(self) -> UnboundedDiscreteTensorSpec:
         return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
 
@@ -2553,8 +2587,6 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):
 
     """
 
-    # SPEC_HANDLED_FUNCTIONS = {}
-
     def __init__(
         self,
         nvec: Sequence[int],
@@ -2586,6 +2618,18 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self):
+        nvec = self.nvec
+        enum_disc = self.to_categorical_spec().enumerate()
+        enums = torch.cat(
+            [
+                torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
+                for nv, enum_unb in zip(nvec, enum_disc.unbind(-1))
+            ],
+            -1,
+        )
+        return enums
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 
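MultiOneHotDiscreteTensorSpec.enumerate delegates to the categorical spec and re-encodes each categorical column as a one-hot block before concatenating along the last dimension. A small illustration of that re-encoding step, with a hand-written stand-in for the categorical enumeration (the nvec values match the tests above):

import torch
from torch.nn.functional import one_hot

nvec = [3, 4, 5]
# stand-in for to_categorical_spec().enumerate(): one column per entry of nvec
enum_disc = torch.tensor([[0, 1, 2], [2, 3, 4]])
enums = torch.cat(
    [one_hot(col, nv) for nv, col in zip(nvec, enum_disc.unbind(-1))], dim=-1
)
assert enums.shape == torch.Size([2, 3 + 4 + 5])   # concatenated one-hot blocks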
@@ -2975,6 +3019,12 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self):
+        arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
+        if self.ndim:
+            arange = arange.view(-1, *(1,) * self.ndim)
+        return arange.expand(self.n, *self.shape)
+
     @property
     def n(self):
         return self.space.n
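For the plain categorical spec the enumeration is an arange over the category count, broadcast across the spec's shape. Roughly, for the n=5, shape (3,) case of test_discrete above:

import torch

n, shape = 5, (3,)
arange = torch.arange(n).view(-1, 1)   # (5, 1)
enum = arange.expand(n, *shape)        # (5, 3): row k is [k, k, k]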
@@ -3428,6 +3478,29 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton
 
+    def enumerate(self):
+        if self.mask is not None:
+            raise RuntimeError(
+                "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
+            )
+        if self.nvec._base.ndim == 1:
+            nvec = self.nvec._base
+        else:
+            # we have to use unique() to isolate the nvec
+            nvec = self.nvec.view(-1, self.nvec.shape[-1]).unique(dim=0).squeeze(0)
+            if nvec.ndim > 1:
+                raise ValueError(
+                    f"Cannot call enumerate on heterogeneous nvecs: unique nvecs={nvec}."
+                )
+        arange = torch.meshgrid(
+            *[torch.arange(n, device=self.device, dtype=self.dtype) for n in nvec],
+            indexing="ij",
+        )
+        arange = torch.stack([arange_.reshape(-1) for arange_ in arange], dim=-1)
+        arange = arange.view(arange.shape[0], *(1,) * (self.ndim - 1), self.shape[-1])
+        arange = arange.expand(arange.shape[0], *self.shape)
+        return arange
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
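The multi-discrete variant enumerates the cartesian product of the per-dimension ranges via torch.meshgrid, which is where the 60 = 3 * 4 * 5 rows asserted in test_multi_discrete come from. A stripped-down sketch of that core step (mask and batch-shape handling omitted):

import torch

nvec = [3, 4, 5]
grids = torch.meshgrid(*[torch.arange(n) for n in nvec], indexing="ij")
combos = torch.stack([g.reshape(-1) for g in grids], dim=-1)
assert combos.shape == torch.Size([60, 3])   # one row per joint outcome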
@@ -3646,6 +3719,8 @@ def to_one_hot(
 
     def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec:
         """Converts the spec to the equivalent one-hot spec."""
+        if self.ndim > 1:
+            return torch.stack([spec.to_one_hot_spec() for spec in self.unbind(0)])
         nvec = [_space.n for _space in self.space]
         return MultiOneHotDiscreteTensorSpec(
             nvec,
@@ -4297,6 +4372,33 @@ def clone(self) -> CompositeSpec:
             shape=self.shape,
         )
 
+    def enumerate(self):
+        # We are going to use meshgrid to create samples of all the subspecs in here
+        # but first let's get rid of the batch size, we'll put it back later
+        self_without_batch = self
+        while self_without_batch.ndim:
+            self_without_batch = self_without_batch[0]
+        samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        if samples:
+            idx_rep = torch.meshgrid(
+                *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij"
+            )
+            idx_rep = tuple(idx.reshape(-1) for idx in idx_rep)
+            samples = {
+                key: sample[idx]
+                for ((key, sample), idx) in zip(samples.items(), idx_rep)
+            }
+            samples = TensorDict(
+                samples, batch_size=idx_rep[0].shape[:1], device=self.device
+            )
+            # Expand
+            if self.ndim:
+                samples = samples.reshape(-1, *(1,) * self.ndim)
+                samples = samples.expand(samples.shape[0], *self.shape)
+        else:
+            samples = TensorDict(batch_size=self.shape, device=self.device)
+        return samples
+
     def empty(self):
         """Create a spec like self, but with no entries."""
         try:
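CompositeSpec.enumerate enumerates each leaf spec separately and uses torch.meshgrid over the per-leaf indices to build every combination, returning a TensorDict whose leading dimension is the product of the leaf cardinalities (20 = 5 * 4 in test_composite above). A minimal sketch of that combination step, with two hypothetical, already-enumerated leaves:

import torch

# hypothetical leaf enumerations: 5 one-hot values and 4 categorical values
samples = {"a": torch.eye(5, dtype=torch.bool), ("b", "c"): torch.arange(4)}
idx = torch.meshgrid(
    *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij"
)
idx = tuple(i.reshape(-1) for i in idx)
combined = {k: v[i] for (k, v), i in zip(samples.items(), idx)}
assert combined["a"].shape == torch.Size([20, 5])       # 20 = 5 * 4 combinations
assert combined[("b", "c")].shape == torch.Size([20])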
@@ -4547,6 +4649,12 @@ def update(self, dict) -> None:
             self[key] = item
         return self
 
+    def enumerate(self):
+        dim = self.stack_dim
+        return LazyStackedTensorDict.maybe_dense_stack(
+            [spec.enumerate() for spec in self._specs], dim + 1
+        )
+
     def __eq__(self, other):
         if not isinstance(other, LazyStackedCompositeSpec):
             return False
@@ -4842,7 +4950,7 @@ def rand(self, shape=None) -> TensorDictBase:
 
 # for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]:
 @TensorSpec.implements_for_spec(torch.stack)
-def _stack_specs(list_of_spec, dim, out=None):
+def _stack_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "
@@ -4879,7 +4987,7 @@ def _stack_specs(list_of_spec, dim, out=None):
 
 
 @CompositeSpec.implements_for_spec(torch.stack)
-def _stack_composite_specs(list_of_spec, dim, out=None):
+def _stack_composite_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "
