diff --git a/.circleci/config.yml b/.circleci/config.yml
index 9f8a77ee2..0b1f89fe5 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -9,20 +9,20 @@ parameters:
   docker_img_version:
     # Docker image version for running tests.
     type: string
-    default: "9fcf583"
+    default: "b4b3a80-main"

 workflows:
   test-jobs:
     when:
       equal: [oncommit, << pipeline.parameters.action >>]
     jobs:
-      - format-and-mypy:
+      - py-tests:
           context:
             - ghcr-auth
-      - pytype:
+      - format-and-mypy:
           context:
             - ghcr-auth
-      - py-tests:
+      - pytype:
           context:
             - ghcr-auth

@@ -71,13 +71,13 @@ jobs:
         password: "$GHCR_DOCKER_TOKEN"
     resource_class: medium
     working_directory: /workspace/third_party/stable-baselines3
-    parallelism: 16
+    parallelism: 14
     steps:
       - checkout
       - run:
           name: Run tests
           command: |
-            /workspace/dist_test.py --worker-out-dir /workspace/test-results . -k 'not test_save_load_large_model'
+            /workspace/dist_test.py --worker-out-dir /workspace/test-results -- . -k 'not test_save_load_large_model'
          environment:
            OMP_NUM_THREADS: "2"
      - save-worker-test-results
diff --git a/stable_baselines3/common/recurrent/buffers.py b/stable_baselines3/common/recurrent/buffers.py
index 284e28a7c..cfa85c19b 100644
--- a/stable_baselines3/common/recurrent/buffers.py
+++ b/stable_baselines3/common/recurrent/buffers.py
@@ -134,7 +134,8 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None:  # type: igno
         :param hidden_states: Hidden state of the RNN
         """
         new_data = dataclasses.replace(
-            data, actions=data.actions.reshape((self.n_envs, self.action_dim))  # type: ignore[misc]
+            data,
+            actions=data.actions.reshape((self.n_envs, self.action_dim)),  # type: ignore[misc]
         )

         tree_map(
@@ -148,34 +149,32 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None:  # type: igno
             self.full = True

     def get(  # type: ignore[override]
-        self, batch_size: Optional[int] = None
+        self,
+        batch_time: int,
+        batch_envs: int,
     ) -> Generator[RecurrentRolloutBufferSamples, None, None]:
         assert self.full, "Rollout buffer must be full before sampling from it"
-
-        # Return everything, don't create minibatches
-        if batch_size is None:
-            batch_size = self.buffer_size * self.n_envs
-
-        if batch_size % self.buffer_size != 0:
-            raise ValueError(f"Batch size must be divisible by sequence length, but {batch_size=} and len={self.buffer_size}")
-
-        indices = th.randperm(self.n_envs)
-        adjusted_batch_size = batch_size // self.buffer_size
-
-        if adjusted_batch_size >= self.n_envs:
-            yield self._get_samples(slice(None))
-
-        start_idx = 0
-        while start_idx < self.n_envs:
-            yield self._get_samples(indices[start_idx : start_idx + adjusted_batch_size])
-            start_idx += adjusted_batch_size
+        if batch_envs >= self.n_envs:
+            for time_start in range(0, self.buffer_size, batch_time):
+                yield self._get_samples(seq_inds=slice(time_start, time_start + batch_time), batch_inds=slice(None))
+
+        else:
+            env_indices = th.randperm(self.n_envs)
+            for env_start in range(0, self.n_envs, batch_envs):
+                for time_start in range(0, self.buffer_size, batch_time):
+                    yield self._get_samples(
+                        seq_inds=slice(time_start, time_start + batch_time),
+                        batch_inds=env_indices[env_start : env_start + batch_envs],
+                    )

     def _get_samples(  # type: ignore[override]
         self,
+        seq_inds: slice,
         batch_inds: Union[slice, th.Tensor],
     ) -> RecurrentRolloutBufferSamples:
-        idx = (slice(None), batch_inds)
-        hidden_states_idx = (0, slice(None), batch_inds)
+        idx = (seq_inds, batch_inds)
+        # hidden_states: time, n_layers, batch
+        first_hidden_state_idx = (seq_inds.start, slice(None), batch_inds)
         return RecurrentRolloutBufferSamples(
             observations=tree_index(self.data.observations, idx),
@@ -184,7 +183,7 @@ def _get_samples(  # type: ignore[override]
             old_log_prob=self.data.log_probs[idx],
             advantages=self.advantages[idx],
             returns=self.returns[idx],
-            hidden_states=tree_index(self.data.hidden_states, hidden_states_idx),  # Return only the first hidden state
+            hidden_states=tree_index(self.data.hidden_states, first_hidden_state_idx),
             episode_starts=self.data.episode_starts[idx],
         )

diff --git a/stable_baselines3/ppo_recurrent/ppo_recurrent.py b/stable_baselines3/ppo_recurrent/ppo_recurrent.py
index e4a3c9dd5..b72b2485e 100644
--- a/stable_baselines3/ppo_recurrent/ppo_recurrent.py
+++ b/stable_baselines3/ppo_recurrent/ppo_recurrent.py
@@ -106,7 +106,8 @@ def __init__(
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 3e-4,
         n_steps: int = 128,
-        batch_size: int = 128,
+        batch_envs: int = 128,
+        batch_time: Optional[int] = None,
         n_epochs: int = 10,
         gamma: float = 0.99,
         gae_lambda: float = 0.95,
@@ -153,32 +154,47 @@ def __init__(
                 spaces.MultiBinary,
             ),
         )
+        if batch_time is None:
+            batch_time = self.n_steps
         # Sanity check, otherwise it will lead to noisy gradient and NaN
         # because of the advantage normalization
         if normalize_advantage:
             assert (
-                batch_size > 1
+                batch_envs * batch_time > 1
             ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

         if self.env is not None:
             # Check that `n_steps * n_envs > 1` to avoid NaN
             # when doing advantage normalization
-            buffer_size = self.env.num_envs * self.n_steps
-            assert buffer_size > 1 or (
-                not normalize_advantage
-            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
+            num_envs = self.env.num_envs
+            assert (
+                num_envs > 1 or batch_time > 1 or (not normalize_advantage)
+            ), f"`num_envs` or `batch_time` must be greater than 1. Currently num_envs={num_envs} and batch_time={batch_time}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
-            untruncated_batches = buffer_size // batch_size
-            if buffer_size % batch_size > 0:
+            if (truncated_batch_size := num_envs % batch_envs) > 0:
+                untruncated_batches = num_envs // batch_envs
                 warnings.warn(
-                    f"You have specified a mini-batch size of {batch_size},"
-                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
+                    f"You have specified an environment mini-batch size of {batch_envs},"
+                    f" but because the `RecurrentRolloutBuffer` has `n_envs = {self.env.num_envs}`,"
                     f" after every {untruncated_batches} untruncated mini-batches,"
-                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
-                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
-                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
+                    f" there will be a truncated mini-batch of size {truncated_batch_size}\n"
+                    f"We recommend using a `batch_envs` that is a factor of `n_envs`.\n"
+                    f"Info: (n_envs={self.env.num_envs})"
                 )
-        self.batch_size = batch_size
+
+            if (truncated_batch_size := self.n_steps % batch_time) > 0:
+                untruncated_batches = self.n_steps // batch_time
+                warnings.warn(
+                    f"You have specified a time mini-batch size of {batch_time},"
+                    f" but because the `RecurrentRolloutBuffer` has `n_steps = {self.n_steps}`,"
+                    f" after every {untruncated_batches} untruncated mini-batches,"
+                    f" there will be a truncated mini-batch of size {truncated_batch_size}\n"
+                    f"We recommend using a `batch_time` that is a factor of `n_steps`.\n"
+                    f"Info: (n_steps={self.n_steps})"
+                )
+
+        self.batch_envs = batch_envs
+        self.batch_time = batch_time
         self.n_epochs = n_epochs
         self.clip_range: Schedule = clip_range  # type: ignore
         self.clip_range_vf: Schedule = clip_range_vf  # type: ignore
@@ -217,9 +233,7 @@ def _setup_model(self) -> None:
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
         )
-        self._last_lstm_states = tree_map(  # type: ignore
-            lambda x: x[0].clone().contiguous(), self.rollout_buffer.data.hidden_states
-        )
+        self._last_lstm_states = tree_map(lambda x: th.zeros_like(x, memory_format=th.contiguous_format), hidden_state_example)

         # Initialize schedules for policy/value clipping
         self.clip_range = get_schedule_fn(self.clip_range)
@@ -372,14 +386,14 @@ def train(self) -> None:
         # train for n_epochs epochs
         for epoch in range(self.n_epochs):
             # Do a complete pass on the rollout buffer
-            for rollout_data in self.rollout_buffer.get(self.batch_size):
+            for rollout_data in self.rollout_buffer.get(batch_time=self.batch_time, batch_envs=self.batch_envs):
                 actions = rollout_data.actions
                 if isinstance(self.action_space, spaces.Discrete):
                     actions = rollout_data.actions.squeeze(-1)

                 # Re-sample the noise matrix because the log_std has changed
                 if self.use_sde:
-                    self.policy.reset_noise(self.batch_size)
+                    self.policy.reset_noise(self.batch_envs * self.batch_time)

                 values, log_prob, entropy = self.policy.evaluate_actions(
                     rollout_data.observations,  # type: ignore[arg-type]
diff --git a/tests/test_buffers.py b/tests/test_buffers.py
index babf749d0..83a13f19a 100644
--- a/tests/test_buffers.py
+++ b/tests/test_buffers.py
@@ -127,8 +127,9 @@ def test_replay_buffer_normalization(replay_buffer_cls):

 @pytest.mark.parametrize(
     "replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer, RecurrentRolloutBuffer]
 )
+@pytest.mark.parametrize("n_envs", [1, 4])
 @pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
-def test_device_buffer(replay_buffer_cls, device):
+def test_device_buffer(replay_buffer_cls, n_envs, device):
     if device == "cuda" and not th.cuda.is_available():
         pytest.skip("CUDA not available")
@@ -139,40 +140,47 @@ def test_device_buffer(replay_buffer_cls, device):
         DictReplayBuffer: DummyDictEnv,
         RecurrentRolloutBuffer: DummyDictEnv,
     }[replay_buffer_cls]
-    env = make_vec_env(env)
+    env = make_vec_env(env, n_envs=n_envs)

     hidden_states_shape = HIDDEN_STATES_EXAMPLE["a"]["b"].shape
     N_ENVS_HIDDEN_STATES = {"a": {"b": th.zeros((hidden_states_shape[0], env.num_envs, *hidden_states_shape[1:]))}}

     if replay_buffer_cls == RecurrentRolloutBuffer:
         buffer = RecurrentRolloutBuffer(
-            EP_LENGTH, env.observation_space, env.action_space, hidden_state_example=N_ENVS_HIDDEN_STATES, device=device
+            EP_LENGTH,
+            env.observation_space,
+            env.action_space,
+            hidden_state_example=N_ENVS_HIDDEN_STATES,
+            device=device,
+            n_envs=n_envs,
         )
     else:
-        buffer = replay_buffer_cls(EP_LENGTH, env.observation_space, env.action_space, device=device)
+        buffer = replay_buffer_cls(EP_LENGTH, env.observation_space, env.action_space, device=device, n_envs=n_envs)

     # Interract and store transitions
     obs = env.reset()
+    episode_start, values, log_prob = th.zeros(n_envs), th.zeros(n_envs), th.ones(n_envs)
+
     for _ in range(EP_LENGTH):
-        action = th.as_tensor(env.action_space.sample())
+        action = th.as_tensor([env.action_space.sample() for _ in range(n_envs)])
         next_obs, reward, done, info = env.step(action)
         if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
-            episode_start, values, log_prob = th.zeros(1), th.zeros(1), th.ones(1)
             buffer.add(obs, action, reward, episode_start, values, log_prob)
         elif replay_buffer_cls == RecurrentRolloutBuffer:
-            episode_start, values, log_prob = th.zeros(1), th.zeros(1), th.ones(1)
             buffer.add(RecurrentRolloutBufferData(obs, action, reward, episode_start, values, log_prob, N_ENVS_HIDDEN_STATES))
         else:
             buffer.add(obs, next_obs, action, reward, done, info)
         obs = next_obs

     # Get data from the buffer
+    batch_envs = max(1, env.num_envs // 2)
+    batch_time = EP_LENGTH // 2
     if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
-        data = buffer.get(50)
+        data = buffer.get(batch_time)
     elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
-        data = [buffer.sample(50)]
+        data = [buffer.sample(batch_time)]
     elif replay_buffer_cls == RecurrentRolloutBuffer:
-        data = buffer.get(EP_LENGTH)
+        data = buffer.get(batch_envs=batch_envs, batch_time=batch_time)

     # Check that all data are on the desired device
     desired_device = get_device(device).type
@@ -182,3 +190,12 @@ def test_device_buffer(replay_buffer_cls, device):
         for value in flattened_tensors:
             assert isinstance(value, th.Tensor)
             assert value.device.type == desired_device
+
+        # Check that data are of the desired shape
+        if isinstance(minibatch, (ReplayBufferSamples, DictReplayBufferSamples)):
+            assert minibatch.rewards.shape == (batch_time, 1)
+        else:
+            if replay_buffer_cls == RecurrentRolloutBuffer:
+                assert minibatch.old_log_prob.shape == (batch_time, batch_envs)
+            else:
+                assert minibatch.old_log_prob.shape == (batch_time,)
diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py
index 8a741ca04..6bd45b711 100644
--- a/tests/test_deterministic.py
+++ b/tests/test_deterministic.py
@@ -27,7 +27,7 @@ def test_deterministic_training_common(algo):
         kwargs.update({"n_steps": 64, "n_epochs": 4})
     elif algo == RecurrentPPO:
         kwargs.update({"policy_kwargs": dict(net_arch=[], enable_critic_lstm=True, lstm_hidden_size=8)})
-        kwargs.update({"n_steps": 50, "n_epochs": 4, "batch_size": 100})
+        kwargs.update({"n_steps": 50, "n_epochs": 4, "batch_time": 25, "batch_envs": 1})

     policy_str = "MlpLstmPolicy" if algo == RecurrentPPO else "MlpPolicy"
     for i in range(2):
diff --git a/tests/test_identity.py b/tests/test_identity.py
index 34412f7d2..212d35e1e 100644
--- a/tests/test_identity.py
+++ b/tests/test_identity.py
@@ -66,25 +66,35 @@

 @pytest.mark.parametrize("model_class", [A2C, PPO, DQN, RecurrentPPO])
-@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
-def test_discrete(model_class, env):
-    env_ = DummyVecEnv([lambda: env])
-    kwargs: dict[str, Any] = {}
-    n_steps = 10000
+@pytest.mark.parametrize(
+    "env_fn", [lambda: IdentityEnv(DIM), lambda: IdentityEnvMultiDiscrete(DIM), lambda: IdentityEnvMultiBinary(DIM)]
+)
+def test_discrete(model_class, env_fn):
+    # Use multiple envs so we can test that batching works correctly
+    env_ = DummyVecEnv([env_fn] * 4)
+    kwargs: dict[str, Any] = dict()
+    total_n_steps = 10000
     if model_class == DQN:
         kwargs = dict(learning_starts=0)

     # DQN only support discrete actions
-    if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
+    if isinstance(env_.envs[0], (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
         return

+    if model_class in (RecurrentPPO, PPO):
+        kwargs["target_kl"] = 0.02
+        kwargs["n_epochs"] = 30
+
     if model_class == RecurrentPPO:
         # Ensure that there's not an MLP on top of the LSTM that the default Policy creates.
         kwargs["policy_kwargs"] = dict(net_arch=dict(vf=[], pi=[]))
+        kwargs["batch_time"] = 128
+        kwargs["batch_envs"] = 4
+        kwargs["n_steps"] = 256

-    model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps)
+    model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(total_n_steps)

     evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=99, warn=False)
-    obs, _ = env.reset()
+    obs, _ = env_.envs[0].reset()

     assert np.shape(model.predict(obs)[0]) == np.shape(obs)

@@ -93,7 +103,7 @@ def test_discrete(model_class, env):
 def test_continuous(model_class):
     env = IdentityEnvBox(eps=0.5)

-    n_steps = 2000 if issubclass(model_class, OnPolicyAlgorithm) else 400
+    total_n_steps = 2000 if issubclass(model_class, OnPolicyAlgorithm) else 400

     kwargs: dict[str, Any] = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95)

@@ -109,6 +119,6 @@ def test_continuous(model_class):
         # Ensure that there's not an MLP on top of the LSTM that the default Policy creates.
         kwargs["policy_kwargs"]["net_arch"] = dict(vf=[], pi=[])

-    model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(n_steps)
+    model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(total_n_steps)

     evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
diff --git a/tests/test_lstm.py b/tests/test_lstm.py
index 3e5e5d9e9..02361f43e 100644
--- a/tests/test_lstm.py
+++ b/tests/test_lstm.py
@@ -260,6 +260,7 @@ def make_env():
     N_ENVS = 16
     N_STEPS = 32
+    BATCH_TIME = 4

     env = VecNormalize(make_vec_env(make_env, n_envs=N_ENVS))

     eval_callback = EvalCallback(
@@ -282,7 +283,8 @@ def make_env():
         n_steps=N_STEPS,
         learning_rate=0.0007,
         verbose=1,
-        batch_size=N_ENVS * N_STEPS,
+        batch_envs=N_ENVS,
+        batch_time=BATCH_TIME,
         seed=1,
         n_epochs=10,
         max_grad_norm=1,