Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,70 @@ def test_qvalue_mask(self, action_space, action_n):
else:
assert action_mask.gather(-1, td.get("action").unsqueeze(-1)).all()

def test_qvalue_actor_strict_shape_auto(self):
    """strict_shape='auto' reshapes the emitted action to the spec's shape (issue #3059).

    The value network emits one value per sample, so the raw action is flat;
    with a (1, 1)-shaped Categorical spec the actor is expected to reshape it
    to (batch, 1) instead of warning or raising.
    """
    spec = Categorical(4, shape=torch.Size((1, 1)), dtype=torch.int64)
    value_net = TensorDictModule(
        module=nn.Linear(3, 1),
        in_keys=("observation",),
        out_keys=("action_value",),
    )
    actor = QValueActor(
        module=value_net,
        in_keys=["observation"],
        spec=spec,
        strict_shape="auto",
    )
    data = TensorDict({"observation": torch.randn(12, 3)})
    actor(data)
    assert data["action"].shape == torch.Size([12, 1])

def test_qvalue_actor_strict_shape_true_raises(self):
    """With strict_shape=True, a spec/action shape mismatch must raise RuntimeError."""
    spec = Categorical(4, shape=torch.Size((1, 1)), dtype=torch.int64)
    value_net = TensorDictModule(
        module=nn.Linear(3, 1),
        in_keys=("observation",),
        out_keys=("action_value",),
    )
    actor = QValueActor(
        module=value_net,
        in_keys=["observation"],
        spec=spec,
        strict_shape=True,
    )
    data = TensorDict({"observation": torch.randn(12, 3)})
    with pytest.raises(RuntimeError, match="does not match expected shape"):
        actor(data)

def test_qvalue_actor_strict_shape_none_warns(self):
    """Default strict_shape=None must emit a FutureWarning on a shape mismatch."""
    spec = Categorical(4, shape=torch.Size((1, 1)), dtype=torch.int64)
    value_net = TensorDictModule(
        module=nn.Linear(3, 1),
        in_keys=("observation",),
        out_keys=("action_value",),
    )
    # strict_shape deliberately left at its default (None) to exercise the
    # deprecation path.
    actor = QValueActor(module=value_net, in_keys=["observation"], spec=spec)
    data = TensorDict({"observation": torch.randn(12, 3)})
    with pytest.warns(FutureWarning, match="does not match expected shape"):
        actor(data)

def test_qvalue_actor_strict_shape_normal_no_warning(self):
    """A matching action shape must be silent even under strict_shape='auto'."""
    import warnings

    spec = OneHot(4)
    value_net = TensorDictModule(
        module=nn.Linear(3, 4),
        in_keys=("observation",),
        out_keys=("action_value",),
    )
    actor = QValueActor(
        module=value_net,
        in_keys=["observation"],
        spec=spec,
        strict_shape="auto",
    )
    data = TensorDict({"observation": torch.randn(5, 3)})
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        actor(data)
    # No FutureWarning may have been recorded on the happy path.
    assert not [c for c in caught if issubclass(c.category, FutureWarning)]
    assert data["action"].shape == torch.Size([5, 4])


@pytest.mark.parametrize("device", get_default_devices())
def test_value_based_policy(device):
Expand Down
46 changes: 46 additions & 0 deletions torchrl/modules/tensordict_module/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,12 @@ def __init__(
var_nums: int | None = None,
spec: TensorSpec | None = None,
safe: bool = False,
strict_shape: bool | str | None = None,
):
if isinstance(action_space, TensorSpec):
raise TypeError("Using specs in action_space is deprecated")
action_space, spec = _process_action_space_spec(action_space, spec)
self.strict_shape = strict_shape
self.action_space = action_space
self.var_nums = var_nums
self.action_func_mapping = {
Expand Down Expand Up @@ -618,6 +620,48 @@ def forward(self, tensordict: torch.Tensor) -> TensorDictBase:
self.action_space, self._default_action_value
)
chosen_action_value = action_value_func(action_values, action)

# Enforce action shape to match spec (after chosen_action_value computation)
action_key = self.out_keys[0]
action_spec = (
self.spec.get(action_key, None)
if isinstance(self.spec, Composite)
else None
)
if action_spec is not None and self.strict_shape is not False:
composite_batch_ndim = len(self.spec.shape)
per_sample_shape = action_spec.shape[composite_batch_ndim:]
batch_shape = action_values.shape[:-1]
target_shape = torch.Size(list(batch_shape) + list(per_sample_shape))

if action.shape != target_shape:
if self.strict_shape is True:
raise RuntimeError(
f"Action shape {action.shape} does not match expected shape {target_shape} "
f"(per-sample spec shape: {per_sample_shape}). "
f"Set strict_shape='auto' to attempt automatic reshaping."
)
elif self.strict_shape == "auto":
try:
action = action.reshape(target_shape)
except RuntimeError:
raise RuntimeError(
f"Cannot reshape action from {action.shape} to {target_shape}."
)
elif self.strict_shape is None:
import warnings

warnings.warn(
f"Action shape {action.shape} does not match expected shape {target_shape} "
f"(per-sample spec shape: {per_sample_shape}). "
f"In v0.14, this will raise an error. "
f"Set strict_shape='auto' to automatically reshape, "
f"strict_shape=True to raise immediately, "
f"or strict_shape=False to silence this warning.",
FutureWarning,
stacklevel=2,
)

tensordict.update(
dict(zip(self.out_keys, (action, action_values, chosen_action_value)))
)
Expand Down Expand Up @@ -1127,6 +1171,7 @@ def __init__(
action_space: str | None = None,
action_value_key=None,
action_mask_key: NestedKey | None = None,
strict_shape: bool | str | None = None,
):
if isinstance(action_space, TensorSpec):
raise RuntimeError(
Expand Down Expand Up @@ -1172,6 +1217,7 @@ def __init__(
safe=safe,
action_space=action_space,
action_mask_key=action_mask_key,
strict_shape=strict_shape,
)

super().__init__(module, qvalue)
Expand Down
Loading