## Description

### Motivation
In multi-agent settings, each agent's individual spec can differ.
It would be nice to have a way of building heterogeneous composite specs, and to carry data with tensordict following the same logic.
### Solution

#### `StackedCompositeSpec`
We could use a `StackedCompositeSpec` that would essentially work like a tuple of boxes in gym:
**Constructor 1**

```python
input_spec = StackedCompositeSpec(
    action=[
        NdBoundedTensorSpec(-3, 3, shape=[3]),
        NdBoundedTensorSpec(-3, 3, shape=[5]),
    ]
)
```
**Constructor 2**

```python
input_spec = StackedCompositeSpec(
    [
        CompositeSpec(action=NdBoundedTensorSpec(-3, 3, shape=[3])),
        CompositeSpec(action=NdBoundedTensorSpec(-3, 3, shape=[5])),
    ]
)
```
This would basically mean that the environment expects an action of shape `Size([3])` for the first agent and of shape `Size([5])` for the second.
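To make the intended semantics concrete, here is a minimal, framework-free sketch (all class names are hypothetical stand-ins for the torchrl spec types, which additionally carry dtypes, devices, bounds, etc.): indexing the stacked spec yields the spec of one agent, and a batch of actions is valid only if each agent's entry matches its own shape.

```python
class ActionSpec:
    """Hypothetical stand-in for a per-agent spec: only tracks a 1-D shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def is_in(self, value):
        # A value matches if its length equals the (1-D) expected shape.
        return len(value) == self.shape[0]


class StackedCompositeSpec:
    """Tuple-of-specs: index i gives the spec for agent i."""

    def __init__(self, *specs):
        self.specs = specs

    def __getitem__(self, i):
        return self.specs[i]

    def is_in(self, values):
        # The stacked spec accepts exactly one value per agent,
        # each validated against that agent's own spec.
        return len(values) == len(self.specs) and all(
            spec.is_in(value) for spec, value in zip(self.specs, values)
        )


spec = StackedCompositeSpec(ActionSpec([3]), ActionSpec([5]))
print(spec[0].shape)                         # → (3,)
print(spec.is_in([[0.0] * 3, [0.0] * 5]))    # → True: shapes match per agent
print(spec.is_in([[0.0] * 3, [0.0] * 3]))    # → False: agent 1 expects 5 dims
```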
- Allowing `LazyStackedTensorDict` to host tensors of different shapes across a dimension
- Construction:

  ```python
  tensordict1 = TensorDict({"action": torch.randn(3)}, [])
  tensordict2 = TensorDict({"action": torch.randn(5)}, [])  # batch sizes match
  out = torch.stack([tensordict1, tensordict2], 0)
  print(out)
  ```

  which would show a tensor

  ```
  action: Tensor([2, ...], dtype=torch.float32)
  ```

  where the diverging shapes have been hidden.
- Prohibited operations:

  With this kind of tensordict, the `get` operation would be prohibited. Instead, one could do

  ```python
  out.get_nested_tensor("action")  # LazyStackedTensorDict would build a nestedtensor from it
  out[0]["action"]  # returns a regular tensor
  ```

  Similarly, the `set_` method would not work (as we have no data format for the mixed input tensor other than nestedtensor).
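The access rules above can be sketched in a few lines of framework-free Python (the class and its methods are hypothetical stand-ins; in the real proposal `get_nested_tensor` would return a `torch.nested` tensor rather than a plain list):

```python
class LazyStack:
    """Sketch of a lazily stacked container whose entries may disagree in shape.

    Hypothetical stand-in for the proposed ragged LazyStackedTensorDict;
    tensors are plain lists here so the example stays framework-free.
    """

    def __init__(self, dicts):
        self.dicts = dicts

    def get(self, key):
        # A plain `get` is prohibited: no dense tensor can hold mixed shapes.
        raise RuntimeError(
            f"cannot `get` {key!r}: entries have heterogeneous shapes; "
            "use `get_nested_tensor` or index a single element instead"
        )

    def get_nested_tensor(self, key):
        # Stand-in for assembling a nestedtensor from the ragged entries.
        return [d[key] for d in self.dicts]

    def __getitem__(self, i):
        # Indexing one element recovers a regular, fixed-shape dict.
        return self.dicts[i]


out = LazyStack([{"action": [0.0] * 3}, {"action": [0.0] * 5}])
print(out.get_nested_tensor("action"))  # ragged entries of length 3 and 5
print(out[0]["action"])                 # regular per-element access
```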
That way we could carry data in `ParallelEnv` and through the collectors while keeping keys with mixed attributes visible to the user.
One could also access a nestedtensor, provided that no more than one `LazyStackedTensorDict` layer is involved (as we currently cannot build nested nestedtensors).
`TensorDictSequential` is already capable of handling lazily stacked tensordicts that have different keys. We could also think about allowing it to gather tensors that do not share the same shape, although this is harder to implement since not every module has a precise signature for the input tensor it expects.