Describe the bug
When using the SyncDataCollector with a ParallelEnv, all workers are reset if any of the workers is done.
To Reproduce
import torch
from torchrl.collectors import SyncDataCollector
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv
if __name__ == '__main__':
    num_workers = 2
    num_steps = 4000
    create_env = lambda: GymEnv(env_name='PongNoFrameskip-v4', frame_skip=4)
    create_parallel_env = lambda: ParallelEnv(num_workers, create_env)
    collector = SyncDataCollector(create_parallel_env, frames_per_batch=num_workers, split_trajs=False)
    iterator = collector.iterator()
    episode_rewards = []
    cumulative_reward = torch.zeros(num_workers)
    for _ in range(num_steps):
        batch = next(iterator).squeeze(1)  # one transition per worker
        cumulative_reward += batch['reward'].squeeze(-1)
        done = batch['done'].squeeze(-1)
        # flush the return of every worker whose episode just ended
        episode_rewards.extend(cumulative_reward[done].tolist())
        cumulative_reward[done] = 0.
    collector.shutdown()
    print(episode_rewards)
This produces the following episode rewards, even though the lowest possible episode return in Pong is -21:
[-21.0, -37.0, -41.0, -21.0]
The reason is that the episodes of workers that are not done are reset without notice, so the partial return of the interrupted episode carries over into the next one.
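A quick way to surface the problem programmatically is the minimal check below; it only uses the episode_rewards list collected by the script above, and the -21 bound is Pong-specific:
# Any return below Pong's minimum of -21 means an episode was cut short by an
# unrequested reset and its partial return leaked into the following episode.
impossible = [r for r in episode_rewards if r < -21.0]
assert not impossible, f"episodes with impossible returns: {impossible}"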
Expected behavior
Only the workers which are actually done should be reset.
Reason and Possible fixes
In this line the ParallelEnv is reset, but the "reset_workers" entry is ignored, so all workers are reset regardless of which ones are actually done.
A quick fix would be to replace the line with:
self._tensordict.update(self.env.reset(self._tensordict), inplace=True)
With this change, the above code produces the following output, which looks correct:
[-21.0, -21.0, -21.0, -21.0, -21.0, -20.0, -20.0, -21.0]
But maybe there is a better solution.
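For completeness, here is a minimal sketch of the partial-reset behavior the quick fix relies on: resetting a ParallelEnv with a tensordict that carries a "reset_workers" mask should only reset the flagged workers. The key name and the (num_workers, 1) mask shape are assumptions based on the entry mentioned above, not a verified description of the API:
import torch
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv
if __name__ == '__main__':
    env = ParallelEnv(2, lambda: GymEnv(env_name='PongNoFrameskip-v4', frame_skip=4))
    tensordict = env.reset()  # initial reset of both workers
    # Flag only worker 0 for reset; worker 1 should keep its current state.
    tensordict['reset_workers'] = torch.tensor([[True], [False]])  # assumed key and shape
    tensordict = env.reset(tensordict)
    env.close()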
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)