Description
I am in the process of adding multiprocessing (vectorized envs) support to the off-policy algorithms (TD3, SAC, DDPG, etc.). I've added support for sampling one action per environment and for advancing the timestep counter by the number of vectorized environments. The modified code runs without throwing an error, but the algorithms no longer converge.
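Conceptually, the rollout side of my change boils down to the standalone sketch below (a simplified illustration, not my actual diff; import paths may differ depending on the SB3 version):

```python
import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv

n_envs = 2
vec_env = make_vec_env('Pendulum-v0', n_envs=n_envs, vec_env_cls=DummyVecEnv)

num_timesteps = 0
obs = vec_env.reset()
for _ in range(10):
    # one action per environment, shape (n_envs, action_dim)
    actions = np.stack([vec_env.action_space.sample() for _ in range(n_envs)])
    obs, rewards, dones, infos = vec_env.step(actions)
    # the timestep counter now advances by n_envs per step() instead of by 1
    num_timesteps += n_envs
```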
I tried it on OpenAI Gym's Pendulum-v0. A single-instance env created with make_vec_env('Pendulum-v0', n_envs=1, vec_env_cls=DummyVecEnv) trains fine. If I specify multiple instances, e.g. make_vec_env('Pendulum-v0', n_envs=2, vec_env_cls=DummyVecEnv) or make_vec_env('Pendulum-v0', n_envs=2, vec_env_cls=SubprocVecEnv), the algorithms don't converge at all.
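A rough sketch of the test I'm running (against my branch, with default hyperparameters; total_timesteps is just an example value):

```python
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv

# Converges: a single environment instance
env = make_vec_env('Pendulum-v0', n_envs=1, vec_env_cls=DummyVecEnv)
SAC('MlpPolicy', env, verbose=1).learn(total_timesteps=20000)

# Does not converge: two instances (same outcome with vec_env_cls=SubprocVecEnv)
env = make_vec_env('Pendulum-v0', n_envs=2, vec_env_cls=DummyVecEnv)
SAC('MlpPolicy', env, verbose=1).learn(total_timesteps=20000)
```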
Here's the warning message I get, which I suspect is closely related to the non-convergence:
```
/home/me/code/stable-baselines3/stable_baselines3/sac/sac.py:237: UserWarning: Using a target size (torch.Size([256, 2])) that is different to the input size (torch.Size([256, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  critic_loss = 0.5 * sum([F.mse_loss(current_q, q_backup) for current_q in current_q_estimates])
```
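For reference, the mismatch can be reproduced in isolation: F.mse_loss broadcasts a (256, 1) input against a (256, 2) target and emits the same warning, so the critic loss is computed over mismatched error terms instead of failing loudly.

```python
import torch
import torch.nn.functional as F

current_q = torch.randn(256, 1)  # critic output: one Q-value per sampled transition
q_backup = torch.randn(256, 2)   # target that kept an extra n_envs dimension

# Emits the same UserWarning and silently broadcasts to (256, 2)
critic_loss = F.mse_loss(current_q, q_backup)
print(critic_loss)
```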
It appears to me that the replay buffer isn't handling the extra n_envs dimension when storing or sampling transitions, so the loss target comes back with shape (256, 2) (one column per env) instead of (256, 1), and the MSE falls back on broadcasting.
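One direction I'm considering is flattening the env dimension as transitions are added, so each env's transition is stored as its own row and sampled rewards/dones come back as (batch_size, 1). A rough, untested sketch of the idea (the hypothetical VecReplayBuffer below assumes the single-env add signature, which may differ between SB3 versions):

```python
from stable_baselines3.common.buffers import ReplayBuffer


class VecReplayBuffer(ReplayBuffer):
    """Hypothetical sketch: split a batched transition from a VecEnv into
    one buffer row per environment, so sampling keeps the (batch_size, 1)
    shapes that the critic loss expects."""

    def add(self, obs, next_obs, action, reward, done):
        # obs/next_obs: (n_envs, *obs_shape), action: (n_envs, action_dim),
        # reward/done: (n_envs,)
        for env_idx in range(len(reward)):
            super().add(
                obs[env_idx],
                next_obs[env_idx],
                action[env_idx],
                reward[env_idx],
                done[env_idx],
            )
```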
Some pointers on modifying the replay buffer so that it supports multiprocessing would be much appreciated! If the authors would like, I can create a PR.