@@ -6634,15 +6634,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
6634
6634
6635
6635
6636
6636
class RenameTransform (Transform ):
6637
- """A transform to rename entries in the output tensordict.
6637
+ """A transform to rename entries in the output tensordict (or input tensordict via the inverse keys) .
6638
6638
6639
6639
Args:
6640
- in_keys (sequence of NestedKey): the entries to rename
6640
+ in_keys (sequence of NestedKey): the entries to rename.
6641
6641
out_keys (sequence of NestedKey): the name of the entries after renaming.
6642
- in_keys_inv (sequence of NestedKey, optional): the entries to rename before
6643
- passing the input tensordict to :meth:`EnvBase._step`.
6644
- out_keys_inv (sequence of NestedKey, optional): the names of the renamed
6645
- entries passed to :meth:`EnvBase._step` .
6642
+ in_keys_inv (sequence of NestedKey, optional): the entries to rename
6643
+ in the input tensordict, which will be passed to :meth:`EnvBase._step`.
6644
+ out_keys_inv (sequence of NestedKey, optional): the names of the entries
6645
+ in the input tensordict after renaming .
6646
6646
create_copy (bool, optional): if ``True``, the entries will be copied
6647
6647
with a different name rather than being renamed. This allows for
6648
6648
renaming immutable entries such as ``"reward"`` and ``"done"``.
@@ -6713,7 +6713,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
6713
6713
out = tensordict .select (* self .in_keys , strict = not self ._missing_tolerance )
6714
6714
for in_key , out_key in zip (self .in_keys , self .out_keys ):
6715
6715
try :
6716
- tensordict .rename_key_ (in_key , out_key )
6716
+ out .rename_key_ (in_key , out_key )
6717
6717
except KeyError :
6718
6718
if not self ._missing_tolerance :
6719
6719
raise
@@ -6802,9 +6802,9 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
6802
6802
6803
6803
def transform_input_spec (self , input_spec : Composite ) -> Composite :
6804
6804
for action_key in self .parent .action_keys :
6805
- if action_key in self .in_keys :
6806
- for i , out_key in enumerate (self .out_keys ): # noqa: B007
6807
- if self .in_keys [i ] == action_key :
6805
+ if action_key in self .in_keys_inv :
6806
+ for i , out_key in enumerate (self .out_keys_inv ): # noqa: B007
6807
+ if self .in_keys_inv [i ] == action_key :
6808
6808
break
6809
6809
else :
6810
6810
# unreachable
@@ -6815,9 +6815,9 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
6815
6815
if not self .create_copy :
6816
6816
del input_spec ["full_action_spec" ][action_key ]
6817
6817
for state_key in self .parent .full_state_spec .keys (True ):
6818
- if state_key in self .in_keys :
6819
- for i , out_key in enumerate (self .out_keys ): # noqa: B007
6820
- if self .in_keys [i ] == state_key :
6818
+ if state_key in self .in_keys_inv :
6819
+ for i , out_key in enumerate (self .out_keys_inv ): # noqa: B007
6820
+ if self .in_keys_inv [i ] == state_key :
6821
6821
break
6822
6822
else :
6823
6823
# unreachable
0 commit comments