@@ -756,6 +756,16 @@ def contains(self, item):
756756 """
757757 return self .is_in (item )
758758
759+ @abc .abstractmethod
760+ def enumerate (self ):
761+ """Returns all the samples that can be obtained from the TensorSpec.
762+
763+ The samples will be stacked along the first dimension.
764+
765+ This method is only implemented for discrete specs.
766+ """
767+ ...
768+
759769 def project (self , val : torch .Tensor ) -> torch .Tensor :
760770 """If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic.
761771
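As an annotation on this hunk (not part of the diff): a minimal sketch of the contract the abstract method establishes, assuming the specs are imported from `torchrl.data` as elsewhere in the library.

import torch
from torchrl.data import DiscreteTensorSpec, UnboundedContinuousTensorSpec

spec = DiscreteTensorSpec(3)      # categorical spec over {0, 1, 2}
samples = spec.enumerate()        # every possible sample, stacked along dim 0
assert samples.shape[0] == 3

cont = UnboundedContinuousTensorSpec(shape=(2,))
try:
    cont.enumerate()              # non-discrete specs raise instead of enumerating
except NotImplementedError as err:
    print(err)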
@@ -1152,6 +1162,11 @@ def __eq__(self, other):
                 return False
         return True
 
+    def enumerate(self):
+        return torch.stack(
+            [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
+        )
+
     def __len__(self):
         return self.shape[0]
 
@@ -1601,6 +1616,13 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val
 
+    def enumerate(self):
+        return (
+            torch.eye(self.n, dtype=self.dtype, device=self.device)
+            .expand(*self.shape, self.n)
+            .permute(-2, *range(self.ndimension() - 1), -1)
+        )
+
     def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
         if not isinstance(index, torch.Tensor):
             raise ValueError(
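Illustrative sketch of the one-hot implementation above (values assume an explicit integer dtype; not part of the diff): the identity matrix of size n is broadcast over the batch dimensions and the enumeration axis is moved to the front.

import torch
from torchrl.data import OneHotDiscreteTensorSpec

spec = OneHotDiscreteTensorSpec(n=3, shape=(2, 3), dtype=torch.long)  # batch of 2 one-hot vectors of size 3
samples = spec.enumerate()
print(samples.shape)  # torch.Size([3, 2, 3]): n samples stacked along dim 0
print(samples[0])     # tensor([[1, 0, 0], [1, 0, 0]]): first outcome, repeated over the batch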
@@ -1832,6 +1854,11 @@ def __init__(
             domain=domain,
         )
 
+    def enumerate(self):
+        raise NotImplementedError(
+            f"enumerate is not implemented for spec of class {type(self).__name__}."
+        )
+
     def __eq__(self, other):
         return (
             type(other) == type(self)
@@ -2107,6 +2134,9 @@ def __init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )
 
+    def enumerate(self):
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensorSpec:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
@@ -2273,6 +2303,9 @@ def is_in(self, val: torch.Tensor) -> bool:
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)
 
+    def enumerate(self):
+        raise NotImplementedError("enumerate cannot be called with continuous specs.")
+
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
             shape = shape[0]
@@ -2361,8 +2394,6 @@ class UnboundedDiscreteTensorSpec(TensorSpec):
             (should be an integer dtype such as long, uint8 etc.)
     """
 
-    # SPEC_HANDLED_FUNCTIONS = {}
-
     def __init__(
         self,
         shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
@@ -2409,6 +2440,9 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
             return self
         return self.__class__(shape=self.shape, device=dest_device, dtype=dest_dtype)
 
+    def enumerate(self):
+        raise NotImplementedError("Cannot enumerate an unbounded tensor spec.")
+
     def clone(self) -> UnboundedDiscreteTensorSpec:
         return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
 
@@ -2553,8 +2587,6 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):
 
     """
 
-    # SPEC_HANDLED_FUNCTIONS = {}
-
     def __init__(
         self,
         nvec: Sequence[int],
@@ -2586,6 +2618,18 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self):
+        nvec = self.nvec
+        enum_disc = self.to_categorical_spec().enumerate()
+        enums = torch.cat(
+            [
+                torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
+                for nv, enum_unb in zip(nvec, enum_disc.unbind(-1))
+            ],
+            -1,
+        )
+        return enums
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 
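Sketch of what this produces (illustrative values, with an explicit integer dtype; not part of the diff): each enumerated row concatenates one one-hot block per entry of nvec, so the number of rows is the product of nvec.

import torch
from torchrl.data import MultiOneHotDiscreteTensorSpec

spec = MultiOneHotDiscreteTensorSpec(nvec=[2, 3], dtype=torch.long)
samples = spec.enumerate()
print(samples.shape)  # torch.Size([6, 5]): 2 * 3 combinations, each a 2-block plus a 3-block
print(samples[0])     # tensor([1, 0, 1, 0, 0])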
@@ -2975,6 +3019,12 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self):
+        arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
+        if self.ndim:
+            arange = arange.view(-1, *(1,) * self.ndim)
+        return arange.expand(self.n, *self.shape)
+
     @property
     def n(self):
         return self.space.n
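Sketch of the batched behaviour of the categorical implementation above (illustrative, not part of the diff): the arange over n is reshaped so it broadcasts against the spec's batch shape, giving one copy of the batch per category.

import torch
from torchrl.data import DiscreteTensorSpec

spec = DiscreteTensorSpec(n=3, shape=(2,))  # a batch of 2 categorical entries with 3 categories
samples = spec.enumerate()
print(samples.shape)  # torch.Size([3, 2])
print(samples)        # tensor([[0, 0], [1, 1], [2, 2]])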
@@ -3428,6 +3478,29 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton
 
+    def enumerate(self):
+        if self.mask is not None:
+            raise RuntimeError(
+                "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
+            )
+        if self.nvec._base.ndim == 1:
+            nvec = self.nvec._base
+        else:
+            # we have to use unique() to isolate the nvec
+            nvec = self.nvec.view(-1, self.nvec.shape[-1]).unique(dim=0).squeeze(0)
+            if nvec.ndim > 1:
+                raise ValueError(
+                    f"Cannot call enumerate on heterogeneous nvecs: unique nvecs={nvec}."
+                )
+        arange = torch.meshgrid(
+            *[torch.arange(n, device=self.device, dtype=self.dtype) for n in nvec],
+            indexing="ij",
+        )
+        arange = torch.stack([arange_.reshape(-1) for arange_ in arange], dim=-1)
+        arange = arange.view(arange.shape[0], *(1,) * (self.ndim - 1), self.shape[-1])
+        arange = arange.expand(arange.shape[0], *self.shape)
+        return arange
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 
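In effect the meshgrid above enumerates the Cartesian product of the per-dimension ranges. A small illustrative sketch (not part of the diff):

import torch
from torchrl.data import MultiDiscreteTensorSpec

spec = MultiDiscreteTensorSpec(nvec=[2, 3])
samples = spec.enumerate()
print(samples.shape)  # torch.Size([6, 2]): every (a, b) pair with a in {0, 1} and b in {0, 1, 2}
print(samples[:3])    # tensor([[0, 0], [0, 1], [0, 2]])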
@@ -3646,6 +3719,8 @@ def to_one_hot(
 
     def to_one_hot_spec(self) -> MultiOneHotDiscreteTensorSpec:
         """Converts the spec to the equivalent one-hot spec."""
+        if self.ndim > 1:
+            return torch.stack([spec.to_one_hot_spec() for spec in self.unbind(0)])
         nvec = [_space.n for _space in self.space]
         return MultiOneHotDiscreteTensorSpec(
             nvec,
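Hedged sketch of why the early return in this hunk matters (the expected shape assumes homogeneous rows; not part of the diff): a multi-dimensional categorical spec is now converted row by row and restacked instead of failing on the extra batch dimension.

import torch
from torchrl.data import MultiDiscreteTensorSpec

spec = MultiDiscreteTensorSpec(nvec=[3, 4, 5], shape=(2, 3))  # 2 rows of 3 categorical entries
one_hot = spec.to_one_hot_spec()
print(one_hot.shape)  # expected torch.Size([2, 12]): each row becomes a 3 + 4 + 5 one-hot vector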
@@ -4297,6 +4372,33 @@ def clone(self) -> CompositeSpec:
             shape=self.shape,
         )
 
+    def enumerate(self):
+        # We are going to use meshgrid to create samples of all the subspecs in here
+        # but first let's get rid of the batch size, we'll put it back later
+        self_without_batch = self
+        while self_without_batch.ndim:
+            self_without_batch = self_without_batch[0]
+        samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        if samples:
+            idx_rep = torch.meshgrid(
+                *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij"
+            )
+            idx_rep = tuple(idx.reshape(-1) for idx in idx_rep)
+            samples = {
+                key: sample[idx]
+                for ((key, sample), idx) in zip(samples.items(), idx_rep)
+            }
+            samples = TensorDict(
+                samples, batch_size=idx_rep[0].shape[:1], device=self.device
+            )
+            # Expand
+            if self.ndim:
+                samples = samples.reshape(-1, *(1,) * self.ndim)
+                samples = samples.expand(samples.shape[0], *self.shape)
+        else:
+            samples = TensorDict(batch_size=self.shape, device=self.device)
+        return samples
+
     def empty(self):
         """Create a spec like self, but with no entries."""
         try:
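Sketch of the composite case (illustrative; the order of the joint assignments depends on key insertion order; not part of the diff): the enumerations of the leaf specs are combined through the Cartesian product of their indices and packed into a TensorDict.

import torch
from torchrl.data import CompositeSpec, DiscreteTensorSpec

spec = CompositeSpec(
    action=DiscreteTensorSpec(2),
    mode=DiscreteTensorSpec(3),
)
samples = spec.enumerate()
print(samples.shape)      # torch.Size([6]): 2 * 3 joint assignments
print(samples["action"])  # e.g. tensor([0, 0, 0, 1, 1, 1])
print(samples["mode"])    # e.g. tensor([0, 1, 2, 0, 1, 2])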
@@ -4547,6 +4649,12 @@ def update(self, dict) -> None:
             self[key] = item
         return self
 
+    def enumerate(self):
+        dim = self.stack_dim
+        return LazyStackedTensorDict.maybe_dense_stack(
+            [spec.enumerate() for spec in self._specs], dim + 1
+        )
+
     def __eq__(self, other):
         if not isinstance(other, LazyStackedCompositeSpec):
             return False
@@ -4842,7 +4950,7 @@ def rand(self, shape=None) -> TensorDictBase:
 
 # for SPEC_CLASS in [BinaryDiscreteTensorSpec, BoundedTensorSpec, DiscreteTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec]:
 @TensorSpec.implements_for_spec(torch.stack)
-def _stack_specs(list_of_spec, dim, out=None):
+def _stack_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "
@@ -4879,7 +4987,7 @@ def _stack_specs(list_of_spec, dim, out=None):
48794987
 
 @CompositeSpec.implements_for_spec(torch.stack)
-def _stack_composite_specs(list_of_spec, dim=0, out=None):
+def _stack_composite_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "