Skip to content

Commit abb24dc

Browse files
committed
[BugFix] Fix in-place modification of specs in TransformedEnv (#3076)
1 parent 4077803 commit abb24dc

File tree

5 files changed

+109
-5
lines changed

5 files changed

+109
-5
lines changed

test/test_specs.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3259,7 +3259,8 @@ def test_repr(self):
32593259
dtype=torch.float32,
32603260
domain=continuous),
32613261
device=cpu,
3262-
shape=torch.Size([3])),
3262+
shape=torch.Size([3]),
3263+
data_cls=None),
32633264
1 ->
32643265
lidar: BoundedContinuous(
32653266
shape=torch.Size([20]),
@@ -3279,7 +3280,8 @@ def test_repr(self):
32793280
dtype=torch.float32,
32803281
domain=continuous),
32813282
device=cpu,
3282-
shape=torch.Size([3])),
3283+
shape=torch.Size([3]),
3284+
data_cls=None),
32833285
2 ->
32843286
individual_2_obs: Composite(
32853287
individual_1_obs_0: UnboundedContinuous(
@@ -3291,7 +3293,8 @@ def test_repr(self):
32913293
dtype=torch.float32,
32923294
domain=continuous),
32933295
device=cpu,
3294-
shape=torch.Size([3]))}},
3296+
shape=torch.Size([3]),
3297+
data_cls=None)}},
32953298
device=cpu,
32963299
shape={torch.Size((3,))},
32973300
stack_dim={c.stack_dim})"""

test/test_transforms.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10383,6 +10383,78 @@ def test_added_transforms_are_in_eval_mode():
1038310383

1038410384

1038510385
class TestTransformedEnv:
10386+
class DummyCompositeEnv(EnvBase): # type: ignore[misc]
10387+
"""A dummy environment with a composite action set."""
10388+
10389+
def __init__(self) -> None:
10390+
super().__init__()
10391+
10392+
self.observation_spec = Composite(
10393+
observation=UnboundedContinuous((*self.batch_size, 3))
10394+
)
10395+
10396+
self.action_spec = Composite(
10397+
action=Composite(
10398+
head_0=Composite(
10399+
action=Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
10400+
),
10401+
head_1=Composite(
10402+
action=Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
10403+
),
10404+
)
10405+
)
10406+
10407+
self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
10408+
10409+
self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone()
10410+
10411+
self.reward_spec = UnboundedContinuous(*self.batch_size, 1)
10412+
10413+
def _reset(self, tensordict: TensorDict) -> TensorDict:
10414+
return TensorDict(
10415+
{"observation": torch.randn((*self.batch_size, 3)), "done": False}
10416+
)
10417+
10418+
def _step(self, tensordict: TensorDict) -> TensorDict:
10419+
return TensorDict(
10420+
{
10421+
"observation": torch.randn((*self.batch_size, 3)),
10422+
"done": False,
10423+
"reward": torch.randn((*self.batch_size, 1)),
10424+
}
10425+
)
10426+
10427+
def _set_seed(self, seed: int) -> None:
10428+
pass
10429+
10430+
class PatchedRenameTransform(RenameTransform): # type: ignore[misc]
10431+
"""
10432+
Fixes a bug in the RenameTransform due to modifying the input_spec of the `base_env` to be transformed.
10433+
This is fixed by adding a clone to prevent stateful modifications from propagating to the `base_env`.
10434+
"""
10435+
10436+
def transform_input_spec(self, input_spec: Composite) -> Composite:
10437+
input_spec = input_spec.clone()
10438+
return super().transform_input_spec(input_spec)
10439+
10440+
def test_no_modif_specs(self) -> None:
10441+
base_env = self.DummyCompositeEnv()
10442+
specs = base_env.specs.clone()
10443+
transformed_env = TransformedEnv(
10444+
base_env,
10445+
RenameTransform(
10446+
in_keys=[],
10447+
out_keys=[],
10448+
in_keys_inv=[("action", "head_0", "action")],
10449+
out_keys_inv=[("action", "head_99", "action")],
10450+
),
10451+
)
10452+
td = transformed_env.reset()
10453+
# A second reset with a TD passed fails due to override of the `input_spec`
10454+
td = transformed_env.reset(td)
10455+
specs_after = base_env.specs.clone()
10456+
assert specs == specs_after
10457+
1038610458
@pytest.mark.filterwarnings("error")
1038710459
def test_nested_transformed_env(self):
1038810460
base_env = ContinuousActionVecMockEnv()

torchrl/data/tensor_specs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5533,8 +5533,10 @@ def __repr__(self) -> str:
55335533
sub_str = [
55345534
indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items()
55355535
]
5536+
if len(sub_str) == 0:
5537+
return f"{self.__class__.__name__}(device={self._device}, shape={self.shape}, data_cls={self.data_cls})"
55365538
sub_str = ",\n".join(sub_str)
5537-
return f"Composite(\n{sub_str},\n device={self._device},\n shape={self.shape})"
5539+
return f"{self.__class__.__name__}(\n{sub_str},\n device={self._device},\n shape={self.shape},\n data_cls={self.data_cls})"
55385540

55395541
def type_check(
55405542
self,

torchrl/envs/transforms/transforms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,6 @@ def _reset(self, tensordict: TensorDictBase | None = None, **kwargs):
12111211
if tensordict is not None:
12121212
# We must avoid modifying the original tensordict so a shallow copy is necessary.
12131213
# We just select the input data and reset signal, which is all we need.
1214-
self.transform.transform_input_spec(self.base_env.input_spec.unlock_())
12151214
tensordict = tensordict.select(
12161215
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
12171216
)

torchrl/modules/llm/policies/common.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7+
import warnings
78
import weakref
89
from typing import Any, Literal, overload
910

@@ -171,6 +172,33 @@ def default_spec(
171172
step_mdp_static=True,
172173
)
173174

175+
def __post_init__(self):
176+
# Check that all history objects have one more batch dimension than the ChatHistory object
177+
if self.prompt is not None:
178+
if self.prompt.batch_dims != self.batch_dims + 1:
179+
warnings.warn(
180+
"Prompt history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
181+
f"got {self.prompt.batch_dims} and {self.batch_dims}. "
182+
"The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
183+
)
184+
self.prompt = self.prompt.unsqueeze(-1)
185+
if self.response is not None:
186+
if self.response.batch_dims != self.batch_dims + 1:
187+
warnings.warn(
188+
"Response history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
189+
f"got {self.response.batch_dims} and {self.batch_dims}. "
190+
"The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
191+
)
192+
self.response = self.response.unsqueeze(-1)
193+
if self.full is not None:
194+
if self.full.batch_dims != self.batch_dims + 1:
195+
warnings.warn(
196+
"Full history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
197+
f"got {self.full.batch_dims} and {self.batch_dims}. "
198+
"The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
199+
)
200+
self.full = self.full.unsqueeze(-1)
201+
174202

175203
class LogProbs(TensorClass["nocast"]):
176204
"""A log-probability container.

0 commit comments

Comments
 (0)