From 1978bc625724eb9339682253add34fd63a3dbd88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Wed, 11 Oct 2023 11:10:28 -0700 Subject: [PATCH] Assign self.device = device --- stable_baselines3/common/recurrent/buffers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py index 5116a1aa9..79824b16a 100644 --- a/stable_baselines3/common/recurrent/buffers.py +++ b/stable_baselines3/common/recurrent/buffers.py @@ -17,6 +17,7 @@ RecurrentRolloutBufferData, RecurrentRolloutBufferSamples, ) +from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize @@ -173,7 +174,7 @@ def __init__( self.gamma = gamma batch_shape = (self.buffer_size, self.n_envs) - device = self.device + self.device = device = get_device(device) self.observation_space_example = space_to_example((), observation_space)