@@ -10383,6 +10383,78 @@ def test_added_transforms_are_in_eval_mode():
10383
10383
10384
10384
10385
10385
class TestTransformedEnv:
10386
+ class DummyCompositeEnv(EnvBase): # type: ignore[misc]
10387
+ """A dummy environment with a composite action set."""
10388
+
10389
+ def __init__(self) -> None:
10390
+ super().__init__()
10391
+
10392
+ self.observation_spec = Composite(
10393
+ observation=UnboundedContinuous((*self.batch_size, 3))
10394
+ )
10395
+
10396
+ self.action_spec = Composite(
10397
+ action=Composite(
10398
+ head_0=Composite(
10399
+ action=Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
10400
+ ),
10401
+ head_1=Composite(
10402
+ action=Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
10403
+ ),
10404
+ )
10405
+ )
10406
+
10407
+ self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
10408
+
10409
+ self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone()
10410
+
10411
+ self.reward_spec = UnboundedContinuous(*self.batch_size, 1)
10412
+
10413
+ def _reset(self, tensordict: TensorDict) -> TensorDict:
10414
+ return TensorDict(
10415
+ {"observation": torch.randn((*self.batch_size, 3)), "done": False}
10416
+ )
10417
+
10418
+ def _step(self, tensordict: TensorDict) -> TensorDict:
10419
+ return TensorDict(
10420
+ {
10421
+ "observation": torch.randn((*self.batch_size, 3)),
10422
+ "done": False,
10423
+ "reward": torch.randn((*self.batch_size, 1)),
10424
+ }
10425
+ )
10426
+
10427
+ def _set_seed(self, seed: int) -> None:
10428
+ pass
10429
+
10430
+ class PatchedRenameTransform(RenameTransform): # type: ignore[misc]
10431
+ """
10432
+ Fixes a bug in the RenameTransform due to modifying the input_spec of the `base_env` to be transformed.
10433
+ This is fixed by adding a clone to break stateful modifications to proapagate to the `base_env`.
10434
+ """
10435
+
10436
+ def transform_input_spec(self, input_spec: Composite) -> Composite:
10437
+ input_spec = input_spec.clone()
10438
+ return super().transform_input_spec(input_spec)
10439
+
10440
+ def test_no_modif_specs(self) -> None:
10441
+ base_env = self.DummyCompositeEnv()
10442
+ specs = base_env.specs.clone()
10443
+ transformed_env = TransformedEnv(
10444
+ base_env,
10445
+ RenameTransform(
10446
+ in_keys=[],
10447
+ out_keys=[],
10448
+ in_keys_inv=[("action", "head_0", "action")],
10449
+ out_keys_inv=[("action", "head_99", "action")],
10450
+ ),
10451
+ )
10452
+ td = transformed_env.reset()
10453
+ # A second reset with a TD passed fails due to override of the `input_spec`
10454
+ td = transformed_env.reset(td)
10455
+ specs_after = base_env.specs.clone()
10456
+ assert specs == specs_after
10457
+
10386
10458
@pytest.mark.filterwarnings("error")
10387
10459
def test_nested_transformed_env(self):
10388
10460
base_env = ContinuousActionVecMockEnv()
0 commit comments