Skip to content

Commit 7df2062

Browse files
author
Vincent Moens
committed
[Feature] OrderedDict for TensorDictSequential
ghstack-source-id: a8aed1e Pull Request resolved: #1142
1 parent 2aea3dd commit 7df2062

File tree

3 files changed

+212
-37
lines changed

3 files changed

+212
-37
lines changed

tensordict/nn/probabilistic.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
import re
99
import warnings
10+
from collections.abc import MutableSequence
1011

1112
from textwrap import indent
12-
from typing import Any, Dict, List, Optional
13+
from typing import Any, Dict, List, Optional, OrderedDict, overload
1314

1415
import torch
1516

@@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
621622
log(p(z | x, y))
622623
623624
Args:
624-
*modules (sequence of TensorDictModules): An ordered sequence of
625-
:class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
625+
*modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule): An ordered sequence of
626+
:class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
626627
to be run sequentially.
628+
The modules can be instances of TensorDictModuleBase or any other callable that takes a TensorDictBase as input and returns a TensorDictBase.
629+
Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
630+
and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
627631
628632
Keyword Args:
629633
partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some
@@ -791,6 +795,28 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
791795
792796
"""
793797

798+
@overload
799+
def __init__(
800+
self,
801+
modules: OrderedDict[str, TensorDictModuleBase | ProbabilisticTensorDictModule],
802+
partial_tolerant: bool = False,
803+
return_composite: bool | None = None,
804+
aggregate_probabilities: bool | None = None,
805+
include_sum: bool | None = None,
806+
inplace: bool | None = None,
807+
) -> None: ...
808+
809+
@overload
810+
def __init__(
811+
self,
812+
modules: List[TensorDictModuleBase | ProbabilisticTensorDictModule],
813+
partial_tolerant: bool = False,
814+
return_composite: bool | None = None,
815+
aggregate_probabilities: bool | None = None,
816+
include_sum: bool | None = None,
817+
inplace: bool | None = None,
818+
) -> None: ...
819+
794820
def __init__(
795821
self,
796822
*modules: TensorDictModuleBase | ProbabilisticTensorDictModule,
@@ -805,7 +831,14 @@ def __init__(
805831
"ProbabilisticTensorDictSequential must consist of zero or more "
806832
"TensorDictModules followed by a ProbabilisticTensorDictModule"
807833
)
808-
if not return_composite and not isinstance(
834+
self._ordered_dict = False
835+
if len(modules) == 1 and isinstance(modules[0], (OrderedDict, MutableSequence)):
836+
if isinstance(modules[0], OrderedDict):
837+
modules_list = list(modules[0].values())
838+
self._ordered_dict = True
839+
else:
840+
modules = modules_list = list(modules[0])
841+
elif not return_composite and not isinstance(
809842
modules[-1],
810843
(ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
811844
):
@@ -814,13 +847,22 @@ def __init__(
814847
"an instance of ProbabilisticTensorDictModule or another "
815848
"ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
816849
)
850+
else:
851+
modules_list = list(modules)
852+
817853
# if the modules not including the final probabilistic module return the sampled
818-
# key we wont be sampling it again, in that case
854+
# key we won't be sampling it again, in that case
819855
# ProbabilisticTensorDictSequential is presumably used to return the
820856
# distribution using `get_dist` or to sample log_probabilities
821-
_, out_keys = self._compute_in_and_out_keys(modules[:-1])
822-
self._requires_sample = modules[-1].out_keys[0] not in set(out_keys)
823-
self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
857+
_, out_keys = self._compute_in_and_out_keys(modules_list[:-1])
858+
self._requires_sample = modules_list[-1].out_keys[0] not in set(out_keys)
859+
if self._ordered_dict:
860+
self.__dict__["_det_part"] = TensorDictSequential(
861+
OrderedDict(list(modules[0].items())[:-1])
862+
)
863+
else:
864+
self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
865+
824866
super().__init__(*modules, partial_tolerant=partial_tolerant)
825867
self.return_composite = return_composite
826868
self.aggregate_probabilities = aggregate_probabilities
@@ -861,7 +903,7 @@ def get_dist_params(
861903
tds = self.det_part
862904
type = interaction_type()
863905
if type is None:
864-
for m in reversed(self.module):
906+
for m in reversed(list(self._module_iter())):
865907
if hasattr(m, "default_interaction_type"):
866908
type = m.default_interaction_type
867909
break
@@ -873,7 +915,7 @@ def get_dist_params(
873915
@property
874916
def num_samples(self):
875917
num_samples = ()
876-
for tdm in self.module:
918+
for tdm in self._module_iter():
877919
if isinstance(
878920
tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
879921
):
@@ -917,7 +959,7 @@ def get_dist(
917959

918960
td_copy = tensordict.copy()
919961
dists = {}
920-
for i, tdm in enumerate(self.module):
962+
for i, tdm in enumerate(self._module_iter()):
921963
if isinstance(
922964
tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
923965
):
@@ -957,12 +999,21 @@ def default_interaction_type(self):
957999
encountered is returned. If no such value is found, a default `interaction_type()` is returned.
9581000
9591001
"""
960-
for m in reversed(self.module):
1002+
for m in reversed(list(self._module_iter())):
9611003
interaction = getattr(m, "default_interaction_type", None)
9621004
if interaction is not None:
9631005
return interaction
9641006
return interaction_type()
9651007

1008+
@property
1009+
def _last_module(self):
1010+
if not self._ordered_dict:
1011+
return self.module[-1]
1012+
mod = None
1013+
for mod in self._module_iter(): # noqa: B007
1014+
continue
1015+
return mod
1016+
9661017
def log_prob(
9671018
self,
9681019
tensordict,
@@ -1079,7 +1130,7 @@ def log_prob(
10791130
include_sum=include_sum,
10801131
**kwargs,
10811132
)
1082-
last_module: ProbabilisticTensorDictModule = self.module[-1]
1133+
last_module: ProbabilisticTensorDictModule = self._last_module
10831134
out = last_module.log_prob(tensordict_inp, dist=dist, **kwargs)
10841135
if is_tensor_collection(out):
10851136
if tensordict_out is not None:
@@ -1138,7 +1189,7 @@ def forward(
11381189
else:
11391190
tensordict_exec = tensordict
11401191
if self.return_composite:
1141-
for m in self.module:
1192+
for m in self._module_iter():
11421193
if isinstance(
11431194
m, (ProbabilisticTensorDictModule, ProbabilisticTensorDictModule)
11441195
):
@@ -1149,7 +1200,7 @@ def forward(
11491200
tensordict_exec = m(tensordict_exec, **kwargs)
11501201
else:
11511202
tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
1152-
tensordict_exec = self.module[-1](
1203+
tensordict_exec = self._last_module(
11531204
tensordict_exec, _requires_sample=self._requires_sample
11541205
)
11551206
if tensordict_out is not None:

tensordict/nn/sequence.py

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
from __future__ import annotations
77

8+
import collections
89
import logging
910
from copy import deepcopy
10-
from typing import Any, Iterable, List
11+
from typing import Any, Callable, Iterable, List, OrderedDict, overload
1112

1213
from tensordict._nestedkey import NestedKey
1314

@@ -52,14 +53,18 @@ class TensorDictSequential(TensorDictModule):
5253
buffers) will be concatenated in a single list.
5354
5455
Args:
55-
modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially.
56+
modules (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]):
57+
ordered sequence of callables that take a TensorDictBase as input and return a TensorDictBase.
58+
These can be instances of TensorDictModuleBase or any other function that matches this signature.
59+
Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
60+
and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
5661
Keyword Args:
5762
partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
5863
If so, the only modules that will be executed are those that can be executed given the keys that
5964
are present.
6065
Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
6166
stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
62-
looking for those that have the required keys, if any.
67+
looking for those that have the required keys, if any. Defaults to False.
6368
selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
6469
``out_keys`` will be written.
6570
@@ -170,19 +175,57 @@ class TensorDictSequential(TensorDictModule):
170175
module: nn.ModuleList
171176
_select_before_return = False
172177

178+
@overload
173179
def __init__(
174180
self,
175-
*modules: TensorDictModuleBase,
181+
modules: OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]],
182+
*,
183+
partial_tolerant: bool = False,
184+
selected_out_keys: List[NestedKey] | None = None,
185+
) -> None: ...
186+
187+
@overload
188+
def __init__(
189+
self,
190+
modules: List[Callable[[TensorDictBase], TensorDictBase]],
191+
*,
192+
partial_tolerant: bool = False,
193+
selected_out_keys: List[NestedKey] | None = None,
194+
) -> None: ...
195+
196+
def __init__(
197+
self,
198+
*modules: Callable[[TensorDictBase], TensorDictBase],
176199
partial_tolerant: bool = False,
177200
selected_out_keys: List[NestedKey] | None = None,
178201
) -> None:
179-
modules = self._convert_modules(modules)
180-
in_keys, out_keys = self._compute_in_and_out_keys(modules)
181-
self._complete_out_keys = list(out_keys)
182202

183-
super().__init__(
184-
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
185-
)
203+
if len(modules) == 1 and isinstance(modules[0], collections.OrderedDict):
204+
modules_vals = self._convert_modules(modules[0].values())
205+
in_keys, out_keys = self._compute_in_and_out_keys(modules_vals)
206+
self._complete_out_keys = list(out_keys)
207+
modules = collections.OrderedDict(
208+
**{key: val for key, val in zip(modules[0], modules_vals)}
209+
)
210+
super().__init__(
211+
module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys
212+
)
213+
elif len(modules) == 1 and isinstance(
214+
modules[0], collections.abc.MutableSequence
215+
):
216+
modules = self._convert_modules(modules[0])
217+
in_keys, out_keys = self._compute_in_and_out_keys(modules)
218+
self._complete_out_keys = list(out_keys)
219+
super().__init__(
220+
module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys
221+
)
222+
else:
223+
modules = self._convert_modules(modules)
224+
in_keys, out_keys = self._compute_in_and_out_keys(modules)
225+
self._complete_out_keys = list(out_keys)
226+
super().__init__(
227+
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
228+
)
186229

187230
self.partial_tolerant = partial_tolerant
188231
if selected_out_keys:
@@ -408,7 +451,7 @@ def select_subsequence(
408451
out_keys = deepcopy(self.out_keys)
409452
out_keys = unravel_key_list(out_keys)
410453

411-
module_list = list(self.module)
454+
module_list = list(self._module_iter())
412455
id_to_keep = set(range(len(module_list)))
413456
for i, module in enumerate(module_list):
414457
if (
@@ -445,8 +488,14 @@ def select_subsequence(
445488
raise ValueError(
446489
"No modules left after selection. Make sure that in_keys and out_keys are coherent."
447490
)
448-
449-
return type(self)(*modules)
491+
if isinstance(self.module, nn.ModuleList):
492+
return type(self)(*modules)
493+
else:
494+
keys = [key for key in self.module if self.module[key] in modules]
495+
modules_dict = collections.OrderedDict(
496+
**{key: val for key, val in zip(keys, modules)}
497+
)
498+
return type(self)(modules_dict)
450499

451500
def _run_module(
452501
self,
@@ -466,6 +515,12 @@ def _run_module(
466515
module(sub_td, **kwargs)
467516
return tensordict
468517

518+
def _module_iter(self):
519+
if isinstance(self.module, nn.ModuleDict):
520+
yield from self.module.children()
521+
else:
522+
yield from self.module
523+
469524
@dispatch(auto_batch_size=False)
470525
@_set_skip_existing_None()
471526
def forward(
@@ -481,7 +536,7 @@ def forward(
481536
else:
482537
tensordict_exec = tensordict
483538
if not len(kwargs):
484-
for module in self.module:
539+
for module in self._module_iter():
485540
tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
486541
else:
487542
raise RuntimeError(
@@ -510,14 +565,16 @@ def forward(
510565
def __len__(self) -> int:
511566
return len(self.module)
512567

513-
def __getitem__(self, index: int | slice) -> TensorDictModuleBase:
514-
if isinstance(index, int):
568+
def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase:
569+
if isinstance(index, (int, str)):
515570
return self.module.__getitem__(index)
516571
else:
517572
return type(self)(*self.module.__getitem__(index))
518573

519-
def __setitem__(self, index: int, tensordict_module: TensorDictModuleBase) -> None:
574+
def __setitem__(
575+
self, index: int | slice | str, tensordict_module: TensorDictModuleBase
576+
) -> None:
520577
return self.module.__setitem__(idx=index, module=tensordict_module)
521578

522-
def __delitem__(self, index: int | slice) -> None:
579+
def __delitem__(self, index: int | slice | str) -> None:
523580
self.module.__delitem__(idx=index)

0 commit comments

Comments
 (0)