
[BUG] SyncDataCollector resets all workers instead of one #570

Closed
@jrobine

Describe the bug

When using the SyncDataCollector with a ParallelEnv, all workers are reset if any of the workers is done.

To Reproduce

import torch
from torchrl.collectors import SyncDataCollector
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

if __name__ == '__main__':
    num_workers = 2
    num_steps = 4000

    create_env = lambda: GymEnv(env_name='PongNoFrameskip-v4', frame_skip=4)
    create_parallel_env = lambda: ParallelEnv(num_workers, create_env)

    # frames_per_batch=num_workers: each batch holds one step from every worker
    collector = SyncDataCollector(create_parallel_env, frames_per_batch=num_workers, split_trajs=False)
    iterator = collector.iterator()
    episode_rewards = []
    cumulative_reward = torch.zeros(num_workers)  # running return per worker
    for _ in range(num_steps):
        batch = next(iterator).squeeze(1)  # drop the singleton time dimension
        cumulative_reward += batch['reward'].squeeze(-1)
        done = batch['done'].squeeze(-1)
        # record the return of every worker that reports done, then start over
        episode_rewards.extend(cumulative_reward[done].tolist())
        cumulative_reward[done] = 0.
    collector.shutdown()
    print(episode_rewards)

This produces the following episode rewards, even though the lowest possible episode reward in Pong is -21:

[-21.0, -37.0, -41.0, -21.0]

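The reason is that some episodes are reset without notice: when one worker is done, the other workers are reset as well, but their "done" flag is never set. The script therefore never records or zeroes their running reward, and it carries over into the next episode.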
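The effect of such a silent reset on the bookkeeping in the script can be illustrated with a toy model in plain Python. This is not torchrl code: the `simulate` function and its episode model (episode lengths cycling through a fixed list, a reward of -1 per step) are hypothetical. Under that model no single episode can return less than `-max(lengths)`, here -4, just as no Pong episode can return less than -21.

```python
"""Toy model of the reward bookkeeping in the issue (not torchrl code).

Each hypothetical worker runs episodes whose lengths cycle through a fixed
list and pays a reward of -1 per step, so no single episode can return
less than -max(lengths).
"""


def simulate(episode_lengths, num_steps, reset_all):
    """Return the episode returns recorded by the loop from the issue.

    episode_lengths: one list of episode lengths per worker.
    reset_all: if True, reset every worker whenever any worker is done
    (the buggy behavior); if False, reset only the done workers.
    """
    n = len(episode_lengths)
    episode_idx = [0] * n       # which episode each worker is playing
    step_in_episode = [0] * n   # steps taken in the current episode
    cumulative = [0] * n        # per-worker running return, as in the issue
    recorded = []

    def reset(i):
        # Start worker i's next episode without reporting done.
        episode_idx[i] += 1
        step_in_episode[i] = 0

    for _ in range(num_steps):
        done = []
        for i in range(n):
            step_in_episode[i] += 1
            cumulative[i] += -1
            worker_lengths = episode_lengths[i]
            length = worker_lengths[episode_idx[i] % len(worker_lengths)]
            done.append(step_in_episode[i] == length)
        # Same bookkeeping as the reproduction script: only workers that
        # report done have their return recorded and zeroed.
        for i in range(n):
            if done[i]:
                recorded.append(cumulative[i])
                cumulative[i] = 0
        if any(done):
            for i in range(n):
                if reset_all or done[i]:
                    reset(i)
    return recorded


lengths = [[2, 4], [3]]
print(simulate(lengths, 8, reset_all=True))   # [-2, -5, -5]
print(simulate(lengths, 8, reset_all=False))  # [-2, -3, -4, -3, -2]
```

With `reset_all=True` the recorded returns drop below the per-episode minimum of -4, mirroring the -37 and -41 above; resetting only the done workers keeps every return within bounds.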

Expected behavior

Only the workers which are actually done should be reset.

Reason and Possible fixes

In this line, the ParallelEnv is reset, but the "reset_workers" entry of the tensordict is ignored.
A quick fix would be to replace the line with:

self._tensordict.update(self.env.reset(self._tensordict), inplace=True)
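The difference between ignoring and honoring the mask can be sketched with a hypothetical batched env, assuming the mask only reaches the env through the tensordict passed to `reset`. A plain dict stands in for the collector's tensordict, and `ToyBatchedEnv` is illustrative only; it is not torchrl's ParallelEnv implementation.

```python
"""Hypothetical stand-in for a batched env whose reset() honors a mask.

A plain dict plays the role of the collector's tensordict; this is not
torchrl's ParallelEnv implementation.
"""


class ToyBatchedEnv:
    def __init__(self, num_workers):
        self.step_count = [0] * num_workers
        self.resets = [0] * num_workers  # how often each worker was reset

    def step(self):
        for i in range(len(self.step_count)):
            self.step_count[i] += 1

    def reset(self, tensordict=None):
        # Without a "reset_workers" mask, every worker is reset, analogous
        # to the behavior described in the issue.
        mask = None if tensordict is None else tensordict.get("reset_workers")
        if mask is None:
            mask = [True] * len(self.step_count)
        for i, flag in enumerate(mask):
            if flag:
                self.step_count[i] = 0
                self.resets[i] += 1
        return {"step_count": list(self.step_count)}


env = ToyBatchedEnv(num_workers=2)
for _ in range(3):
    env.step()
env.reset()                                  # no mask: both workers restart
print(env.step_count)                        # [0, 0]
for _ in range(3):
    env.step()
env.reset({"reset_workers": [True, False]})  # only worker 0 restarts
print(env.step_count)                        # [0, 3]
```

Passing the container that carries the mask, rather than calling `reset()` bare, is what lets the env leave the other workers' episodes running.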

With this change, the code above produces the following output, which looks correct:

[-21.0, -21.0, -21.0, -21.0, -21.0, -20.0, -20.0, -21.0]

But maybe there is a better solution.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels: bug (Something isn't working)
