
[Feature Request] Batched specs of heterogeneous shape and related stacked tensordicts #766

Open
@vmoens

Description

Motivation

In multi-agent settings, each agent's individual spec can differ.
It would be nice to have a way of building heterogeneous composite specs, and to carry the corresponding data in tensordicts following the same logic.

Solution

  1. StackedCompositeSpec
    We could use a StackedCompositeSpec that would essentially work like gym's Tuple of Box spaces:

Constructor 1

input_spec = StackedCompositeSpec(
    action=[
        NdBoundedTensorSpec(-3, 3, shape=[3]),
        NdBoundedTensorSpec(-3, 3, shape=[5]),
    ]
)

Constructor 2

input_spec = StackedCompositeSpec(
    [
        CompositeSpec(action=NdBoundedTensorSpec(-3, 3, shape=[3])),
        CompositeSpec(action=NdBoundedTensorSpec(-3, 3, shape=[5])),
    ]
)

This would basically mean that the environment expects an action of shape Size([3]) for the first agent and Size([5]) for the second.
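As a rough, hypothetical sketch of the intended semantics (the class does not exist yet; the names and methods below are placeholders, not a proposed implementation), a StackedCompositeSpec could simply hold one CompositeSpec per agent, as in constructor 2, and dispatch indexing, sampling and membership checks agent-wise:

class StackedCompositeSpec:
    """Sketch only: stacks per-agent composite specs along a leading dimension."""

    def __init__(self, specs):
        # specs: one CompositeSpec per agent
        self.specs = list(specs)

    def __getitem__(self, index):
        # indexing along the stack dimension recovers the per-agent spec
        return self.specs[index]

    def rand(self):
        # one draw per agent; shapes may differ, so no single stacked tensor is returned
        return [spec.rand() for spec in self.specs]

    def is_in(self, values):
        # membership check is performed agent-wise
        return all(spec.is_in(val) for spec, val in zip(self.specs, values))

With constructor 1, the same structure could be built internally by zipping the per-key lists into one CompositeSpec per agent.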

  2. Allowing LazyStackedTensorDict to host tensors of different shapes along a dimension
  • Construction
import torch
from tensordict import TensorDict  # lives under torchrl.data in older torchrl versions

tensordict1 = TensorDict({"action": torch.randn(3)}, [])
tensordict2 = TensorDict({"action": torch.randn(5)}, [])  # batch sizes match (both empty)
out = torch.stack([tensordict1, tensordict2], 0)
print(out)

which would show the action entry as

action: Tensor([2, ...], dtype=torch.float32)

where the diverging trailing shapes are hidden.

  • Prohibited operations
    With this kind of tensordict, calling get on a key with mismatched shapes would be prohibited. Instead, one could do
out.get_nested_tensor("action")  # LazyStackedTensorDict would create a nestedtensor from it
out[0]["action"]  # returns a regular tensor

Similarly, the set_ method would not work (since we have no data format other than a nested tensor to pass the mixed-shape input).

That way we could carry data through ParallelEnv and the collectors while keeping the key with mixed attributes visible to users.
One could also access a nested tensor, provided that no more than one LazyStackedTensorDict layer is used (as we currently cannot build nested nested tensors).
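For reference, a nested tensor is the data format alluded to above: it can hold the two differently shaped action tensors in a single object. The snippet below only illustrates that format with plain tensors (get_nested_tensor is part of this proposal, not an existing API; depending on the PyTorch version, the constructor is torch.nested.nested_tensor or the older prototype torch.nested_tensor):

import torch

t1 = torch.randn(3)
t2 = torch.randn(5)
# a nested tensor stacks tensors whose trailing shapes differ
nt = torch.nested.nested_tensor([t1, t2])
print(nt.unbind())  # recovers the original tensors of shape [3] and [5]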

TensorDictSequential is already capable of handling lazy stacked tensordicts that have different keys. We could also think about allowing it to gather tensors that do not share the same shape, although this is harder to implement since not every module has a precise signature for the input tensor it expects.

cc @matteobettini
