## Describe the bug
Remote data collectors populate the output tensordict with the data that comes with the first collection from the environment. There are several issues with this, the main one being that if we use `init_random_frames > 0`, the initial tensordict will just be made of the environment keys, an action key, and some utility keys from the collector (trajectory id and such).
## To Reproduce
We can reproduce this with a policy that uses an RNN.
```python
import torch

from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import TensorDictModule, LSTMNet

out_features = 1
batch = 7
time_steps = 6
hidden_size = 12

net = LSTMNet(
    out_features,
    {"input_size": hidden_size, "hidden_size": hidden_size},
    {"out_features": hidden_size},
)
policy = TensorDictModule(
    net,
    in_keys=["observation", "hidden1", "hidden2"],
    out_keys=["action", "hidden1", "hidden2", "next_hidden1", "next_hidden2"],
)
env_maker = lambda: GymEnv("Pendulum-v1")
policy(env_maker().reset())
```
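As a sanity check (our own snippet, not part of the original report), we can list the keys the policy adds on top of the reset tensordict; these are exactly the keys that should appear in every collected batch:

```python
td = env_maker().reset()
td_out = policy(td.clone())  # clone so the reset tensordict is left untouched
print(sorted(set(td_out.keys()) - set(td.keys())))
# ['action', 'hidden1', 'hidden2', 'next_hidden1', 'next_hidden2']
```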
### Case 1: no `init_random_frames`

Here everything works as expected:
```python
c = MultiaSyncDataCollector(
    [env_maker], policy, total_frames=200, frames_per_batch=20, init_random_frames=0
)
for i, b in enumerate(c):
    print(b)
    if i == 5:
        # by now we should have a tensordict that is the output of the policy
        break
```
Every tensordict looks like this:
```
TensorDict(
    fields={
        action: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        hidden1: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        hidden2: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        mask: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        next_hidden1: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        next_hidden2: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        next_observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([1, 20, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([1, 20, 1, 1]), dtype=torch.int64)},
    batch_size=torch.Size([1, 20]),
    device=cpu,
    is_shared=False)
```
### Case 2: `init_random_frames > 0`

When running this:
```python
c = MultiaSyncDataCollector(
    [env_maker], policy, total_frames=200, frames_per_batch=20, init_random_frames=50
)
for i, b in enumerate(c):
    print(b)
    if i == 5:
        # by now we should have a tensordict that is the output of the policy
        break
```
every tensordict looks like this:
```
TensorDict(
    fields={
        action: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        mask: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([1, 20, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([1, 20, 1, 1]), dtype=torch.int64)},
    batch_size=torch.Size([1, 20]),
    device=cpu,
    is_shared=False)
```
i.e. the `hidden1`/`hidden2` and `next_hidden1`/`next_hidden2` keys are missing, even after we have passed the number of frames indicated by `init_random_frames`.
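A minimal check makes the failure explicit (the assertion is our own, not part of the collector API); with the buggy behaviour it fails on every batch, even after the random phase:

```python
# Keys the policy writes besides the action (see the policy's out_keys above).
expected = {"hidden1", "hidden2", "next_hidden1", "next_hidden2"}
missing = expected - set(b.keys())
assert not missing, f"keys missing from the collected batch: {sorted(missing)}"
```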
### Case 3: `init_random_frames > 0` with `SyncDataCollector`

`SyncDataCollector` is the simplest data collector: it does not spawn other processes.
```python
c = SyncDataCollector(
    env_maker, policy, total_frames=200, frames_per_batch=20, init_random_frames=50
)
for i, b in enumerate(c):
    print(b)
    if i == 5:
        # by now we should have a tensordict that is the output of the policy
        break
```
The first three tensordicts (the random phase covers ceil(50 / 20) = 3 batches) look like this:
```
TensorDict(
    fields={
        action: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        mask: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([1, 20, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([1, 20, 1, 1]), dtype=torch.int64)},
    batch_size=torch.Size([1, 20]),
    device=cpu,
    is_shared=False)
```
After that, they start being populated with the hidden keys:
```
TensorDict(
    fields={
        action: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        hidden1: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        hidden2: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        mask: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        next_hidden1: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        next_hidden2: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        next_observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([1, 20, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([1, 20, 1, 1]), dtype=torch.int64)},
    batch_size=torch.Size([1, 20]),
    device=cpu,
    is_shared=False)
```
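In other words, the key set changes across batches coming from the same collector. A compact way to see when the keys appear (the instrumentation below is our own; the printed pattern matches the outputs shown above):

```python
c = SyncDataCollector(
    env_maker, policy, total_frames=200, frames_per_batch=20, init_random_frames=50
)
for i, b in enumerate(c):
    print(i, "hidden1" in b.keys())
    if i == 5:
        break
# 0 False
# 1 False
# 2 False
# 3 True
# 4 True
# 5 True
```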
## Proposed solution
The expected behaviour is that, for all data collectors, every tensordict contains all possible keys from the first batch onwards (values that are not valid yet should be filled with 0s). Note that under the hood every collector calls `SyncDataCollector`, hence only this class needs to be fixed.
The reason for this is that we will store the collected tensordicts in a replay buffer, and we need to know from the first collection which key-value pairs will have to be stored in it.
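To see why this matters, here is a toy sketch of a pre-allocated buffer that lays out its storage from the first batch it sees (illustrative only, not torchrl's replay-buffer implementation): any key that only appears later has nowhere to be written.

```python
import torch

first_batch = {"observation": torch.zeros(20, 3), "action": torch.zeros(20, 1)}
# Storage is shaped after the keys of the first batch.
storage = {k: torch.zeros(1000, *v.shape[1:]) for k, v in first_batch.items()}

later_batch = {**first_batch, "hidden1": torch.zeros(20, 1, 12)}
dropped = set(later_batch) - set(storage)
print(sorted(dropped))  # ['hidden1'] -- silently lost unless keys are known upfront
```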
What we want is that, right after instantiating the env, the `self._tensordict_out` object is populated with all the relevant keys. The collector should perform the following operations:
1. If the policy has a non-empty `policy.spec` attribute, none of the spec's values are `None` (`policy.spec` is a `CompositeSpec` object that can have values set to `None`), and the keys in the spec match the `out_keys` of the policy, then we assume that the user has provided all the relevant information. We just need to do

   ```python
   self._tensordict_out = env.reset()
   self._tensordict_out.update(policy.spec.rand())
   ```

   If some values in the spec are `None` or if the policy `out_keys` are not all present in `policy.spec`, we fall back to option (2) below.
2. If the policy does not have a spec, we should run a mini-rollout in the env with the policy to make sure that we gather all the relevant keys:

   ```python
   self._tensordict_out = env.rollout(3, policy)[..., 0]
   ```
Finally, if at some point the tensordict collected in `SyncDataCollector.rollout` contains more keys than the `SyncDataCollector._tensordict_out` object, we should raise a warning informing the user that some data is likely to be missed during collection, for the reasons explained above.