Skip to content

Commit a0dfddc

Browse files
thomasbbrunnerVincent Moens
andauthored
[BugFix] Fixes to RenameTransform (#2442)
Co-authored-by: Vincent Moens <vmoens@meta.com>
1 parent b4d543e commit a0dfddc

File tree

2 files changed

+35
-13
lines changed

2 files changed

+35
-13
lines changed

test/test_transforms.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9331,6 +9331,28 @@ def test_transform_inverse(self, create_copy):
93319331
else:
93329332
assert "b" not in tensordict.keys()
93339333

9334+
def test_rename_action(self, create_copy):
9335+
base_env = ContinuousActionVecMockEnv()
9336+
env = base_env.append_transform(
9337+
RenameTransform(
9338+
in_keys=[],
9339+
out_keys=[],
9340+
in_keys_inv=["action"],
9341+
out_keys_inv=[("renamed", "action")],
9342+
create_copy=create_copy,
9343+
)
9344+
)
9345+
r = env.rollout(3)
9346+
assert ("renamed", "action") in env.action_keys, env.action_keys
9347+
assert ("renamed", "action") in r
9348+
assert env.full_action_spec[("renamed", "action")] is not None
9349+
if create_copy:
9350+
assert "action" in env.action_keys
9351+
assert "action" in r
9352+
else:
9353+
assert "action" not in env.action_keys
9354+
assert "action" not in r
9355+
93349356

93359357
class TestInitTracker(TransformBase):
93369358
@pytest.mark.skipif(not _has_gym, reason="no gym detected")

torchrl/envs/transforms/transforms.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6634,15 +6634,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
66346634

66356635

66366636
class RenameTransform(Transform):
6637-
"""A transform to rename entries in the output tensordict.
6637+
"""A transform to rename entries in the output tensordict (or input tensordict via the inverse keys).
66386638
66396639
Args:
6640-
in_keys (sequence of NestedKey): the entries to rename
6640+
in_keys (sequence of NestedKey): the entries to rename.
66416641
out_keys (sequence of NestedKey): the name of the entries after renaming.
6642-
in_keys_inv (sequence of NestedKey, optional): the entries to rename before
6643-
passing the input tensordict to :meth:`EnvBase._step`.
6644-
out_keys_inv (sequence of NestedKey, optional): the names of the renamed
6645-
entries passed to :meth:`EnvBase._step`.
6642+
in_keys_inv (sequence of NestedKey, optional): the entries to rename
6643+
in the input tensordict, which will be passed to :meth:`EnvBase._step`.
6644+
out_keys_inv (sequence of NestedKey, optional): the names of the entries
6645+
in the input tensordict after renaming.
66466646
create_copy (bool, optional): if ``True``, the entries will be copied
66476647
with a different name rather than being renamed. This allows for
66486648
renaming immutable entries such as ``"reward"`` and ``"done"``.
@@ -6713,7 +6713,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
67136713
out = tensordict.select(*self.in_keys, strict=not self._missing_tolerance)
67146714
for in_key, out_key in zip(self.in_keys, self.out_keys):
67156715
try:
6716-
tensordict.rename_key_(in_key, out_key)
6716+
out.rename_key_(in_key, out_key)
67176717
except KeyError:
67186718
if not self._missing_tolerance:
67196719
raise
@@ -6802,9 +6802,9 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
68026802

68036803
def transform_input_spec(self, input_spec: Composite) -> Composite:
68046804
for action_key in self.parent.action_keys:
6805-
if action_key in self.in_keys:
6806-
for i, out_key in enumerate(self.out_keys): # noqa: B007
6807-
if self.in_keys[i] == action_key:
6805+
if action_key in self.in_keys_inv:
6806+
for i, out_key in enumerate(self.out_keys_inv): # noqa: B007
6807+
if self.in_keys_inv[i] == action_key:
68086808
break
68096809
else:
68106810
# unreachable
@@ -6815,9 +6815,9 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
68156815
if not self.create_copy:
68166816
del input_spec["full_action_spec"][action_key]
68176817
for state_key in self.parent.full_state_spec.keys(True):
6818-
if state_key in self.in_keys:
6819-
for i, out_key in enumerate(self.out_keys): # noqa: B007
6820-
if self.in_keys[i] == state_key:
6818+
if state_key in self.in_keys_inv:
6819+
for i, out_key in enumerate(self.out_keys_inv): # noqa: B007
6820+
if self.in_keys_inv[i] == state_key:
68216821
break
68226822
else:
68236823
# unreachable

0 commit comments

Comments
 (0)