
[BUG] data collectors may omit some keys in the output tensordict #505

Closed
@vmoens

Description

Describe the bug

Remote data collectors populate the output tensordict with the data that comes with the first collection from the environment. There are several issues with this, the main one being that if we use a value of init_random_frames > 0, the initial tensordict will be made up only of the environment keys, an action key, and some utility keys from the collector (trajectory id and such).

To Reproduce

We can create a policy that uses an RNN to reproduce this.

from torchrl.collectors import MultiaSyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import TensorDictModule, LSTMNet
import torch

out_features = 1
hidden_size = 12
net = LSTMNet(
    out_features,
    {"input_size": hidden_size, "hidden_size": hidden_size},
    {"out_features": hidden_size},
)

# the policy reads and writes recurrent hidden states in addition to the action
policy = TensorDictModule(
    net,
    in_keys=["observation", "hidden1", "hidden2"],
    out_keys=["action", "hidden1", "hidden2", "next_hidden1", "next_hidden2"],
)

env_maker = lambda: GymEnv("Pendulum-v1")

# one warm-up call on a reset tensordict
policy(env_maker().reset())

Case 1: no init_random_frames

Here, everything works as expected:

c = MultiaSyncDataCollector([env_maker], policy, total_frames=200, frames_per_batch=20, init_random_frames=0)

for i, b in enumerate(c):
    print(b)
    if i == 5:
        # by now we should have a tensordict that is the output of the policy
        break

Every tensordict looks like this:

TensorDict(
    fields={
        action: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        hidden1: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        hidden2: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        mask: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        next_hidden1: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        next_hidden2: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        next_observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([1, 20, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([1, 20, 1, 1]), dtype=torch.int64)},
    batch_size=torch.Size([1, 20]),
    device=cpu,
    is_shared=False)

Case 2: init_random_frames > 0

When running this

c = MultiaSyncDataCollector([env_maker], policy, total_frames=200, frames_per_batch=20, init_random_frames=50)

for i, b in enumerate(c):
    print(b)
    if i == 5:
        # by now we should have a tensordict that is the output of the policy
        break

every tensordict looks like this:

TensorDict(
    fields={
        action: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        mask: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([1, 20, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([1, 20, 1, 1]), dtype=torch.int64)},
    batch_size=torch.Size([1, 20]),
    device=cpu,
    is_shared=False)

i.e. the hidden and next_hidden keys are missing, even after the number of frames indicated by init_random_frames has been collected.
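The omission can be stated as a simple key-set difference (a plain-Python illustration using the key names from the dumps above, not a torchrl API):

```python
# Keys the policy writes (its out_keys) vs. keys actually present in the
# Case 2 batch above; the hidden-state keys are silently dropped.
policy_out_keys = {"action", "hidden1", "hidden2", "next_hidden1", "next_hidden2"}
case2_keys = {"action", "done", "mask", "next_observation", "observation",
              "reward", "step_count", "traj_ids"}
missing = sorted(policy_out_keys - case2_keys)
print(missing)  # → ['hidden1', 'hidden2', 'next_hidden1', 'next_hidden2']
```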

Case 3: init_random_frames > 0 with SyncDataCollector

SyncDataCollector is the simplest data collector: it does not spawn other processes.

from torchrl.collectors import SyncDataCollector

c = SyncDataCollector(env_maker, policy, total_frames=200, frames_per_batch=20, init_random_frames=50)

for i, b in enumerate(c):
    print(b)
    if i == 5:
        # by now we should have a tensordict that is the output of the policy
        break

The first three tensordicts look like this

TensorDict(
    fields={
        action: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        mask: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([1, 20, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([1, 20, 1, 1]), dtype=torch.int64)},
    batch_size=torch.Size([1, 20]),
    device=cpu,
    is_shared=False)

After that, they start being populated with the hidden keys:

TensorDict(
    fields={
        action: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        hidden1: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        hidden2: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        mask: Tensor(torch.Size([1, 20, 1]), dtype=torch.bool),
        next_hidden1: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        next_hidden2: Tensor(torch.Size([1, 20, 1, 12]), dtype=torch.float32),
        next_observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([1, 20, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([1, 20, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([1, 20, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([1, 20, 1, 1]), dtype=torch.int64)},
    batch_size=torch.Size([1, 20]),
    device=cpu,
    is_shared=False)

Proposed solution

The expected behaviour is that, for all data collectors, every tensordict has all possible keys from the first batch onwards (values that are not yet valid should be populated with 0s). Note that under the hood every collector calls SyncDataCollector, hence only this class needs to be corrected.

The reason for this is that we will store the tensordicts in a replay buffer, and we need to know from the first collection which key-value pairs will be stored in it.
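The replay-buffer constraint can be illustrated with a toy dict-based buffer (illustrative only, not torchrl's replay buffer): storage whose schema is frozen at the first insertion cannot absorb keys that appear later.

```python
class TinyBuffer:
    """Toy buffer whose key schema is fixed by the first sample (illustrative)."""
    def __init__(self):
        self.storage = None

    def add(self, sample: dict):
        if self.storage is None:
            # schema is frozen at the first insertion
            self.storage = {k: [] for k in sample}
        for k in self.storage:
            self.storage[k].append(sample[k])
        # keys absent from the first sample are silently dropped
        return sorted(set(sample) - set(self.storage))

buf = TinyBuffer()
buf.add({"observation": 1, "action": 2})  # first sample fixes the schema
dropped = buf.add({"observation": 3, "action": 4, "hidden1": 5})
print(dropped)  # → ['hidden1']
```

This mirrors the bug: if the first collected batch lacks the hidden keys, every later batch that does produce them cannot be stored in full.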

What we want is that, after instantiating the env, the self._tensordict_out object is populated with all the relevant keys.

The collector should perform the following operations:

  1. If the policy has a non-empty policy.spec attribute, none of its values is None (policy.spec is a CompositeSpec object whose values can be set to None), and the keys in the spec match the out_keys of the policy, then we assume the user has given all the relevant information, and we just need to do

     self._tensordict_out = env.reset()
     self._tensordict_out.update(policy.action_spec.rand())

     If some values in the spec are None, or if not all of the policy's out_keys are present in policy.spec, we revert to (2) below.

  2. If the policy does not have a spec, we should run a mini-rollout in the env with the policy to make sure that we gather all the relevant keys:

     self._tensordict_out = env.rollout(3, policy)[..., 0]

Finally, if at some point the tensordict collected in SyncDataCollector.rollout contains more keys than SyncDataCollector._tensordict_out, we should raise a warning informing the user that some data is likely to be missed during collection, for the reasons explained above.
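The two steps plus the warning can be sketched as follows (build_tensordict_out and missing_keys are hypothetical helper names; the spec checks mirror the conditions above, not the actual torchrl implementation):

```python
import warnings


def build_tensordict_out(env, policy):
    """Sketch of the proposed pre-population of self._tensordict_out."""
    spec = getattr(policy, "spec", None)
    spec_ok = (
        spec is not None
        and len(spec) > 0
        and all(v is not None for v in spec.values())
        and set(policy.out_keys).issubset(set(spec.keys()))
    )
    if spec_ok:
        # (1) the spec fully describes the policy output
        tensordict_out = env.reset()
        tensordict_out.update(policy.action_spec.rand())
    else:
        # (2) fall back to a mini-rollout to discover all keys
        tensordict_out = env.rollout(3, policy)[..., 0]
    return tensordict_out


def missing_keys(tensordict_out, rollout_td):
    """Warn when a rollout produced keys absent from the preallocated output."""
    extra = sorted(set(rollout_td.keys()) - set(tensordict_out.keys()))
    if extra:
        warnings.warn(f"keys {extra} are likely to be missed during collection")
    return extra
```

missing_keys would be called once per rollout; since it only compares key sets, it is cheap enough to run on every collection.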

Labels: bug (Something isn't working)