7
7
8
8
import re
9
9
import warnings
10
+ from collections .abc import MutableSequence
10
11
11
12
from textwrap import indent
12
- from typing import Any , Dict , List , Optional
13
+ from typing import Any , Dict , List , Optional , OrderedDict , overload
13
14
14
15
import torch
15
16
@@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
621
622
log(p(z | x, y))
622
623
623
624
Args:
624
- *modules (sequence of TensorDictModules ): An ordered sequence of
625
- :class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
625
+ *modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule ): An ordered sequence of
626
+ :class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
626
627
to be run sequentially.
628
+ The modules can be instances of TensorDictModuleBase or any other function that matches this signature.
629
+ Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
630
+ and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
627
631
628
632
Keyword Args:
629
633
partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some
@@ -791,6 +795,28 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
791
795
792
796
"""
793
797
798
+ @overload
799
+ def __init__ (
800
+ self ,
801
+ modules : OrderedDict [str , TensorDictModuleBase | ProbabilisticTensorDictModule ],
802
+ partial_tolerant : bool = False ,
803
+ return_composite : bool | None = None ,
804
+ aggregate_probabilities : bool | None = None ,
805
+ include_sum : bool | None = None ,
806
+ inplace : bool | None = None ,
807
+ ) -> None : ...
808
+
809
+ @overload
810
+ def __init__ (
811
+ self ,
812
+ modules : List [TensorDictModuleBase | ProbabilisticTensorDictModule ],
813
+ partial_tolerant : bool = False ,
814
+ return_composite : bool | None = None ,
815
+ aggregate_probabilities : bool | None = None ,
816
+ include_sum : bool | None = None ,
817
+ inplace : bool | None = None ,
818
+ ) -> None : ...
819
+
794
820
def __init__ (
795
821
self ,
796
822
* modules : TensorDictModuleBase | ProbabilisticTensorDictModule ,
@@ -805,7 +831,14 @@ def __init__(
805
831
"ProbabilisticTensorDictSequential must consist of zero or more "
806
832
"TensorDictModules followed by a ProbabilisticTensorDictModule"
807
833
)
808
- if not return_composite and not isinstance (
834
+ self ._ordered_dict = False
835
+ if len (modules ) == 1 and isinstance (modules [0 ], (OrderedDict , MutableSequence )):
836
+ if isinstance (modules [0 ], OrderedDict ):
837
+ modules_list = list (modules [0 ].values ())
838
+ self ._ordered_dict = True
839
+ else :
840
+ modules = modules_list = list (modules [0 ])
841
+ elif not return_composite and not isinstance (
809
842
modules [- 1 ],
810
843
(ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential ),
811
844
):
@@ -814,13 +847,22 @@ def __init__(
814
847
"an instance of ProbabilisticTensorDictModule or another "
815
848
"ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
816
849
)
850
+ else :
851
+ modules_list = list (modules )
852
+
817
853
# if the modules not including the final probabilistic module return the sampled
818
- # key we wont be sampling it again, in that case
854
+ # key we won't be sampling it again, in that case
819
855
# ProbabilisticTensorDictSequential is presumably used to return the
820
856
# distribution using `get_dist` or to sample log_probabilities
821
- _ , out_keys = self ._compute_in_and_out_keys (modules [:- 1 ])
822
- self ._requires_sample = modules [- 1 ].out_keys [0 ] not in set (out_keys )
823
- self .__dict__ ["_det_part" ] = TensorDictSequential (* modules [:- 1 ])
857
+ _ , out_keys = self ._compute_in_and_out_keys (modules_list [:- 1 ])
858
+ self ._requires_sample = modules_list [- 1 ].out_keys [0 ] not in set (out_keys )
859
+ if self ._ordered_dict :
860
+ self .__dict__ ["_det_part" ] = TensorDictSequential (
861
+ OrderedDict (list (modules [0 ].items ())[:- 1 ])
862
+ )
863
+ else :
864
+ self .__dict__ ["_det_part" ] = TensorDictSequential (* modules [:- 1 ])
865
+
824
866
super ().__init__ (* modules , partial_tolerant = partial_tolerant )
825
867
self .return_composite = return_composite
826
868
self .aggregate_probabilities = aggregate_probabilities
@@ -861,7 +903,7 @@ def get_dist_params(
861
903
tds = self .det_part
862
904
type = interaction_type ()
863
905
if type is None :
864
- for m in reversed (self .module ):
906
+ for m in reversed (list ( self ._module_iter ()) ):
865
907
if hasattr (m , "default_interaction_type" ):
866
908
type = m .default_interaction_type
867
909
break
@@ -873,7 +915,7 @@ def get_dist_params(
873
915
@property
874
916
def num_samples (self ):
875
917
num_samples = ()
876
- for tdm in self .module :
918
+ for tdm in self ._module_iter () :
877
919
if isinstance (
878
920
tdm , (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential )
879
921
):
@@ -917,7 +959,7 @@ def get_dist(
917
959
918
960
td_copy = tensordict .copy ()
919
961
dists = {}
920
- for i , tdm in enumerate (self .module ):
962
+ for i , tdm in enumerate (self ._module_iter () ):
921
963
if isinstance (
922
964
tdm , (ProbabilisticTensorDictModule , ProbabilisticTensorDictSequential )
923
965
):
@@ -957,12 +999,21 @@ def default_interaction_type(self):
957
999
encountered is returned. If no such value is found, a default `interaction_type()` is returned.
958
1000
959
1001
"""
960
- for m in reversed (self .module ):
1002
+ for m in reversed (list ( self ._module_iter ()) ):
961
1003
interaction = getattr (m , "default_interaction_type" , None )
962
1004
if interaction is not None :
963
1005
return interaction
964
1006
return interaction_type ()
965
1007
1008
+ @property
1009
+ def _last_module (self ):
1010
+ if not self ._ordered_dict :
1011
+ return self .module [- 1 ]
1012
+ mod = None
1013
+ for mod in self ._module_iter (): # noqa: B007
1014
+ continue
1015
+ return mod
1016
+
966
1017
def log_prob (
967
1018
self ,
968
1019
tensordict ,
@@ -1079,7 +1130,7 @@ def log_prob(
1079
1130
include_sum = include_sum ,
1080
1131
** kwargs ,
1081
1132
)
1082
- last_module : ProbabilisticTensorDictModule = self .module [ - 1 ]
1133
+ last_module : ProbabilisticTensorDictModule = self ._last_module
1083
1134
out = last_module .log_prob (tensordict_inp , dist = dist , ** kwargs )
1084
1135
if is_tensor_collection (out ):
1085
1136
if tensordict_out is not None :
@@ -1138,7 +1189,7 @@ def forward(
1138
1189
else :
1139
1190
tensordict_exec = tensordict
1140
1191
if self .return_composite :
1141
- for m in self .module :
1192
+ for m in self ._module_iter () :
1142
1193
if isinstance (
1143
1194
m , (ProbabilisticTensorDictModule , ProbabilisticTensorDictModule )
1144
1195
):
@@ -1149,7 +1200,7 @@ def forward(
1149
1200
tensordict_exec = m (tensordict_exec , ** kwargs )
1150
1201
else :
1151
1202
tensordict_exec = self .get_dist_params (tensordict_exec , ** kwargs )
1152
- tensordict_exec = self .module [ - 1 ] (
1203
+ tensordict_exec = self ._last_module (
1153
1204
tensordict_exec , _requires_sample = self ._requires_sample
1154
1205
)
1155
1206
if tensordict_out is not None :
0 commit comments