diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 5116a1aa9..79824b16a 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -17,6 +17,7 @@ RecurrentRolloutBufferData, RecurrentRolloutBufferSamples, ) +from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize @@ -173,7 +174,7 @@ def __init__( self.gamma = gamma batch_shape = (self.buffer_size, self.n_envs) - device = self.device + self.device = device = get_device(device) self.observation_space_example = space_to_example((), observation_space)