@@ -834,6 +834,16 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
         """
         return self.is_in(item)
 
+    @abc.abstractmethod
+    def enumerate(self):
+        """Returns all the samples that can be obtained from the TensorSpec.
+
+        The samples will be stacked along the first dimension.
+
+        This method is only implemented for discrete specs.
+        """
+        ...
+
     def project(
         self, val: torch.Tensor | TensorDictBase
     ) -> torch.Tensor | TensorDictBase:
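For reference (outside the patch), a hedged usage sketch of the new contract: each discrete spec returns its outcomes stacked along a new leading dimension, while the `Composite` variant added further down returns a `TensorDict`. This assumes the `Categorical` class exported by `torchrl.data`; the printed output is inferred from the code in this hunk, not verified.

    >>> from torchrl.data import Categorical
    >>> Categorical(3, shape=()).enumerate()   # one row per outcome
    tensor([0, 1, 2])
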
@@ -1271,6 +1281,11 @@ def __eq__(self, other):
             return False
         return True
 
+    def enumerate(self):
+        return torch.stack(
+            [spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
+        )
+
     def __len__(self):
         return self.shape[0]
 
@@ -1732,6 +1747,13 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
             return np.array(vals).reshape(tuple(val.shape))
         return val
 
+    def enumerate(self):
+        return (
+            torch.eye(self.n, dtype=self.dtype, device=self.device)
+            .expand(*self.shape, self.n)
+            .permute(-2, *range(self.ndimension() - 1), -1)
+        )
+
     def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
         if not isinstance(index, torch.Tensor):
             raise ValueError(
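A hedged sketch of what the eye/expand/permute chain above produces for a batched one-hot spec: the `n` outcomes are moved to the front by the final permute. Assumes the `OneHot` class from `torchrl.data`; the shape is inferred from the code, not verified.

    >>> from torchrl.data import OneHot
    >>> spec = OneHot(4, shape=(2, 4))   # batch of 2 slots, 4 categories each
    >>> spec.enumerate().shape
    torch.Size([4, 2, 4])
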
@@ -2056,6 +2078,11 @@ def __init__(
             domain=domain,
         )
 
+    def enumerate(self):
+        raise NotImplementedError(
+            f"enumerate is not implemented for spec of class {type(self).__name__}."
+        )
+
     def __eq__(self, other):
         return (
             type(other) == type(self)
@@ -2375,6 +2402,9 @@ def __init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )
 
+    def enumerate(self):
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
@@ -2611,6 +2641,9 @@ def is_in(self, val: torch.Tensor) -> bool:
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape)
 
+    def enumerate(self):
+        raise NotImplementedError("enumerate cannot be called with continuous specs.")
+
     def expand(self, *shape):
         if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
             shape = shape[0]
@@ -2775,6 +2808,18 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self):
+        nvec = self.nvec
+        enum_disc = self.to_categorical_spec().enumerate()
+        enums = torch.cat(
+            [
+                torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
+                for nv, enum_unb in zip(nvec, enum_disc.unbind(-1))
+            ],
+            -1,
+        )
+        return enums
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 
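A hedged sketch of the multi-one-hot case above: the categorical enumeration is one-hot encoded per group and concatenated on the last dim, so an nvec of `[2, 3]` should give 2 * 3 = 6 rows of width 2 + 3 = 5. Assumes `MultiOneHot` from `torchrl.data`; the shape is inferred from the code, not verified.

    >>> from torchrl.data import MultiOneHot
    >>> MultiOneHot([2, 3]).enumerate().shape
    torch.Size([6, 5])
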
@@ -3208,6 +3253,12 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def enumerate(self):
+        arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
+        if self.ndim:
+            arange = arange.view(-1, *(1,) * self.ndim)
+        return arange.expand(self.n, *self.shape)
+
     @property
     def n(self):
         return self.space.n
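A hedged sketch of the batched categorical case above: the `n` outcomes land along a new leading dim and are broadcast over the spec's own shape. Assumes `Categorical` from `torchrl.data`; the values are inferred from the code, not verified.

    >>> from torchrl.data import Categorical
    >>> Categorical(3, shape=(2,)).enumerate()
    tensor([[0, 0],
            [1, 1],
            [2, 2]])
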
@@ -3715,6 +3766,29 @@ def __init__(
         self.update_mask(mask)
         self.remove_singleton = remove_singleton
 
+    def enumerate(self):
+        if self.mask is not None:
+            raise RuntimeError(
+                "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
+            )
+        if self.nvec._base.ndim == 1:
+            nvec = self.nvec._base
+        else:
+            # we have to use unique() to isolate the nvec
+            nvec = self.nvec.view(-1, self.nvec.shape[-1]).unique(dim=0).squeeze(0)
+            if nvec.ndim > 1:
+                raise ValueError(
+                    f"Cannot call enumerate on heterogeneous nvecs: unique nvecs={nvec}."
+                )
+        arange = torch.meshgrid(
+            *[torch.arange(n, device=self.device, dtype=self.dtype) for n in nvec],
+            indexing="ij",
+        )
+        arange = torch.stack([arange_.reshape(-1) for arange_ in arange], dim=-1)
+        arange = arange.view(arange.shape[0], *(1,) * (self.ndim - 1), self.shape[-1])
+        arange = arange.expand(arange.shape[0], *self.shape)
+        return arange
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 
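A hedged sketch of the meshgrid-based enumeration above: with `indexing="ij"` the first component varies slowest, giving the full cartesian product of the per-dimension ranges. Assumes `MultiCategorical` from `torchrl.data`; the values are inferred from the code, not verified.

    >>> from torchrl.data import MultiCategorical
    >>> MultiCategorical([2, 3]).enumerate()
    tensor([[0, 0],
            [0, 1],
            [0, 2],
            [1, 0],
            [1, 1],
            [1, 2]])
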
@@ -3932,6 +4006,8 @@ def to_one_hot(
 
     def to_one_hot_spec(self) -> MultiOneHot:
         """Converts the spec to the equivalent one-hot spec."""
+        if self.ndim > 1:
+            return torch.stack([spec.to_one_hot_spec() for spec in self.unbind(0)])
         nvec = [_space.n for _space in self.space]
         return MultiOneHot(
             nvec,
@@ -4606,6 +4682,33 @@ def clone(self) -> Composite:
             shape=self.shape,
         )
 
+    def enumerate(self):
+        # We are going to use meshgrid to create samples of all the subspecs in here
+        # but first let's get rid of the batch size, we'll put it back later
+        self_without_batch = self
+        while self_without_batch.ndim:
+            self_without_batch = self_without_batch[0]
+        samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        if samples:
+            idx_rep = torch.meshgrid(
+                *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij"
+            )
+            idx_rep = tuple(idx.reshape(-1) for idx in idx_rep)
+            samples = {
+                key: sample[idx]
+                for ((key, sample), idx) in zip(samples.items(), idx_rep)
+            }
+            samples = TensorDict(
+                samples, batch_size=idx_rep[0].shape[:1], device=self.device
+            )
+            # Expand
+            if self.ndim:
+                samples = samples.reshape(-1, *(1,) * self.ndim)
+                samples = samples.expand(samples.shape[0], *self.shape)
+        else:
+            samples = TensorDict(batch_size=self.shape, device=self.device)
+        return samples
+
     def empty(self):
         """Create a spec like self, but with no entries."""
         try:
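A hedged sketch of the composite case above: the leaf enumerations are combined through the index meshgrid, so the resulting TensorDict has one row per combination of leaf outcomes. Assumes `Composite` and `Categorical` from `torchrl.data`; the batch size is inferred from the code, not verified.

    >>> from torchrl.data import Categorical, Composite
    >>> spec = Composite(a=Categorical(2), b=Categorical(3))
    >>> spec.enumerate().batch_size     # 2 * 3 combinations
    torch.Size([6])
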
@@ -4856,6 +4959,12 @@ def update(self, dict) -> None:
             self[key] = item
         return self
 
+    def enumerate(self):
+        dim = self.stack_dim
+        return LazyStackedTensorDict.maybe_dense_stack(
+            [spec.enumerate() for spec in self._specs], dim + 1
+        )
+
     def __eq__(self, other):
         if not isinstance(other, StackedComposite):
             return False
@@ -5150,7 +5259,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
 
 
 @TensorSpec.implements_for_spec(torch.stack)
-def _stack_specs(list_of_spec, dim, out=None):
+def _stack_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "
@@ -5187,7 +5296,7 @@ def _stack_specs(list_of_spec, dim, out=None):
 
 
 @Composite.implements_for_spec(torch.stack)
-def _stack_composite_specs(list_of_spec, dim, out=None):
+def _stack_composite_specs(list_of_spec, dim=0, out=None):
     if out is not None:
         raise NotImplementedError(
             "In-place spec modification is not a feature of torchrl, hence "
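The two signature changes above only add a default value: through the spec torch-function overrides, `torch.stack` over a list of specs should now work without an explicit `dim`. A hedged sketch, assuming `Categorical` from `torchrl.data`; the result is inferred, not verified.

    >>> import torch
    >>> from torchrl.data import Categorical
    >>> stacked = torch.stack([Categorical(3), Categorical(3)])   # dim=0 by default
    >>> stacked.shape[0]
    2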