Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3259,7 +3259,8 @@ def test_repr(self):
dtype=torch.float32,
domain=continuous),
device=cpu,
shape=torch.Size([3])),
shape=torch.Size([3]),
data_cls=None),
1 ->
lidar: BoundedContinuous(
shape=torch.Size([20]),
Expand All @@ -3279,7 +3280,8 @@ def test_repr(self):
dtype=torch.float32,
domain=continuous),
device=cpu,
shape=torch.Size([3])),
shape=torch.Size([3]),
data_cls=None),
2 ->
individual_2_obs: Composite(
individual_1_obs_0: UnboundedContinuous(
Expand All @@ -3291,7 +3293,8 @@ def test_repr(self):
dtype=torch.float32,
domain=continuous),
device=cpu,
shape=torch.Size([3]))}},
shape=torch.Size([3]),
data_cls=None)}},
device=cpu,
shape={torch.Size((3,))},
stack_dim={c.stack_dim})"""
Expand Down
72 changes: 72 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10383,6 +10383,78 @@ def test_added_transforms_are_in_eval_mode():


class TestTransformedEnv:
class DummyCompositeEnv(EnvBase):  # type: ignore[misc]
    """A dummy environment with a composite (nested, multi-head) action set.

    Used to verify that transforms do not mutate the specs of the wrapped
    base environment.
    NOTE(review): assumes the torchrl ``EnvBase`` contract for
    ``_reset``/``_step``/``_set_seed`` — confirm against the installed
    torchrl version.
    """

    def __init__(self) -> None:
        super().__init__()

        # Single 3-dimensional continuous observation.
        self.observation_spec = Composite(
            observation=UnboundedContinuous((*self.batch_size, 3))
        )

        # Two action "heads", each a boolean categorical, nested under
        # ("action", "head_*", "action") — the nesting is what the
        # RenameTransform regression test exercises.
        self.action_spec = Composite(
            action=Composite(
                head_0=Composite(
                    action=Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
                ),
                head_1=Composite(
                    action=Categorical(2, (*self.batch_size, 1), dtype=torch.bool)
                ),
            )
        )

        self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool)

        # Mirror "terminated" into "truncated" so both done flavors exist.
        self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone()

        self.reward_spec = UnboundedContinuous(*self.batch_size, 1)

    def _reset(self, tensordict: TensorDict) -> TensorDict:
        # Fresh random observation; never done at reset time.
        return TensorDict(
            {"observation": torch.randn((*self.batch_size, 3)), "done": False}
        )

    def _step(self, tensordict: TensorDict) -> TensorDict:
        # Random observation and reward; the episode never terminates.
        return TensorDict(
            {
                "observation": torch.randn((*self.batch_size, 3)),
                "done": False,
                "reward": torch.randn((*self.batch_size, 1)),
            }
        )

    def _set_seed(self, seed: int) -> None:
        # Determinism is not needed for these tests.
        pass

class PatchedRenameTransform(RenameTransform):  # type: ignore[misc]
    """Fixes a bug in ``RenameTransform`` that modified the ``input_spec`` of the
    ``base_env`` being transformed.

    The fix adds a clone so that stateful modifications do not propagate to the
    ``base_env``.
    """

    def transform_input_spec(self, input_spec: Composite) -> Composite:
        # Clone first so the parent implementation mutates a copy rather
        # than the base environment's own input_spec.
        input_spec = input_spec.clone()
        return super().transform_input_spec(input_spec)
Comment on lines +10430 to +10438
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vmoens This class is not needed by this test anymore

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, removed in #3079


def test_no_modif_specs(self) -> None:
    """Resetting a TransformedEnv must not mutate the base env's specs.

    Regression test: renaming a nested action key used to rewrite the
    wrapped environment's ``input_spec`` in place on the second reset.
    """
    base_env = self.DummyCompositeEnv()
    # Snapshot the specs before any transform machinery runs.
    specs = base_env.specs.clone()
    transformed_env = TransformedEnv(
        base_env,
        RenameTransform(
            in_keys=[],
            out_keys=[],
            in_keys_inv=[("action", "head_0", "action")],
            out_keys_inv=[("action", "head_99", "action")],
        ),
    )
    td = transformed_env.reset()
    # A second reset with a TD passed fails due to override of the `input_spec`
    td = transformed_env.reset(td)
    specs_after = base_env.specs.clone()
    assert specs == specs_after

@pytest.mark.filterwarnings("error")
def test_nested_transformed_env(self):
base_env = ContinuousActionVecMockEnv()
Expand Down
4 changes: 3 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
def __repr__(self) -> str:
    """Return a readable, indented description of this composite spec."""
    entries = [
        indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items()
    ]
    cls_name = self.__class__.__name__
    if not entries:
        # Empty composite: render everything on a single line.
        return f"{cls_name}(device={self._device}, shape={self.shape}, data_cls={self.data_cls})"
    joined = ",\n".join(entries)
    return f"{cls_name}(\n{joined},\n device={self._device},\n shape={self.shape},\n data_cls={self.data_cls})"

def type_check(
self,
Expand Down
1 change: 0 additions & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,6 @@ def _reset(self, tensordict: TensorDictBase | None = None, **kwargs):
if tensordict is not None:
# We must avoid modifying the original tensordict so a shallow copy is necessary.
# We just select the input data and reset signal, which is all we need.
self.transform.transform_input_spec(self.base_env.input_spec.unlock_())
tensordict = tensordict.select(
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
)
Expand Down
28 changes: 28 additions & 0 deletions torchrl/modules/llm/policies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings
import weakref
from typing import Any, Literal, overload

Expand Down Expand Up @@ -171,6 +172,33 @@ def default_spec(
step_mdp_static=True,
)

def __post_init__(self):
    """Normalize the batch dimensions of the nested history fields.

    Each of ``prompt``, ``response`` and ``full`` is expected to carry one
    more batch dimension than the ChatHistory object itself (the extra,
    trailing dimension indexing the turns of a multi-turn conversation).
    When a field does not, a warning is emitted and the field is unsqueezed
    along its last dimension so the invariant holds.
    """
    # (attribute name, human-readable label used in the warning message).
    # The three fields previously had three copy-pasted, identical blocks;
    # a single data-driven loop keeps the messages and behavior in sync.
    for attr, label in (("prompt", "Prompt"), ("response", "Response"), ("full", "Full")):
        history = getattr(self, attr)
        if history is None:
            continue
        if history.batch_dims != self.batch_dims + 1:
            warnings.warn(
                f"{label} history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
                f"got {history.batch_dims} and {self.batch_dims}. "
                "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
            )
            setattr(self, attr, history.unsqueeze(-1))


class LogProbs(TensorClass["nocast"]):
"""A log-probability container.
Expand Down
Loading