Batches over time and space for recurrent PPO (#20)
Add support for mini-batches that span fewer timesteps than `n_steps`, with
separate `batch_time` and `batch_envs` parameters for `RecurrentPPO` (see the usage sketch below).
rhaps0dy authored Feb 12, 2024
2 parents cc0e970 + 1b7a72a commit fa71467
Showing 7 changed files with 112 additions and 70 deletions.
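
As a rough usage sketch of the new parameters (the environment name, hyperparameter values, and the import path in this fork are assumptions for illustration, not taken from the commit), each optimization mini-batch now covers `batch_envs` environments and `batch_time` consecutive timesteps, replacing the single `batch_size` that had to be a multiple of `n_steps`:

    # Sketch only: assumed import path and illustrative hyperparameters.
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.ppo_recurrent import RecurrentPPO

    env = make_vec_env("CartPole-v1", n_envs=8)
    model = RecurrentPPO(
        "MlpLstmPolicy",
        env,
        n_steps=128,    # rollout length collected per environment
        batch_envs=4,   # environments per optimization mini-batch
        batch_time=32,  # timesteps per mini-batch (defaults to n_steps when omitted)
        n_epochs=10,
    )
    model.learn(total_timesteps=10_000)
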
12 changes: 6 additions & 6 deletions .circleci/config.yml
@@ -9,20 +9,20 @@ parameters:
docker_img_version:
# Docker image version for running tests.
type: string
default: "9fcf583"
default: "b4b3a80-main"

workflows:
test-jobs:
when:
equal: [oncommit, << pipeline.parameters.action >>]
jobs:
- format-and-mypy:
- py-tests:
context:
- ghcr-auth
- pytype:
- format-and-mypy:
context:
- ghcr-auth
- py-tests:
- pytype:
context:
- ghcr-auth

@@ -71,13 +71,13 @@ jobs:
password: "$GHCR_DOCKER_TOKEN"
resource_class: medium
working_directory: /workspace/third_party/stable-baselines3
parallelism: 16
parallelism: 14
steps:
- checkout
- run:
name: Run tests
command: |
/workspace/dist_test.py --worker-out-dir /workspace/test-results . -k 'not test_save_load_large_model'
/workspace/dist_test.py --worker-out-dir /workspace/test-results -- . -k 'not test_save_load_large_model'
environment:
OMP_NUM_THREADS: "2"
- save-worker-test-results
45 changes: 22 additions & 23 deletions stable_baselines3/common/recurrent/buffers.py
@@ -134,7 +134,8 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None:  # type: igno
:param hidden_states: Hidden state of the RNN
"""
new_data = dataclasses.replace(
data, actions=data.actions.reshape((self.n_envs, self.action_dim)) # type: ignore[misc]
data,
actions=data.actions.reshape((self.n_envs, self.action_dim)), # type: ignore[misc]
)

tree_map(
@@ -148,34 +149,32 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None:  # type: igno
self.full = True

def get( # type: ignore[override]
self, batch_size: Optional[int] = None
self,
batch_time: int,
batch_envs: int,
) -> Generator[RecurrentRolloutBufferSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"

# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs

if batch_size % self.buffer_size != 0:
raise ValueError(f"Batch size must be divisible by sequence length, but {batch_size=} and len={self.buffer_size}")

indices = th.randperm(self.n_envs)
adjusted_batch_size = batch_size // self.buffer_size

if adjusted_batch_size >= self.n_envs:
yield self._get_samples(slice(None))

start_idx = 0
while start_idx < self.n_envs:
yield self._get_samples(indices[start_idx : start_idx + adjusted_batch_size])
start_idx += adjusted_batch_size
if batch_envs >= self.n_envs:
for time_start in range(0, self.buffer_size, batch_time):
yield self._get_samples(seq_inds=slice(time_start, time_start + batch_time), batch_inds=slice(None))

else:
env_indices = th.randperm(self.n_envs)
for env_start in range(0, self.n_envs, batch_envs):
for time_start in range(0, self.buffer_size, batch_time):
yield self._get_samples(
seq_inds=slice(time_start, time_start + batch_time),
batch_inds=env_indices[env_start : env_start + batch_envs],
)

def _get_samples( # type: ignore[override]
self,
seq_inds: slice,
batch_inds: Union[slice, th.Tensor],
) -> RecurrentRolloutBufferSamples:
idx = (slice(None), batch_inds)
hidden_states_idx = (0, slice(None), batch_inds)
idx = (seq_inds, batch_inds)
# hidden_states: time, n_layers, batch
first_hidden_state_idx = (seq_inds.start, slice(None), batch_inds)

return RecurrentRolloutBufferSamples(
observations=tree_index(self.data.observations, idx),
@@ -184,7 +183,7 @@ def _get_samples(  # type: ignore[override]
old_log_prob=self.data.log_probs[idx],
advantages=self.advantages[idx],
returns=self.returns[idx],
hidden_states=tree_index(self.data.hidden_states, hidden_states_idx), # Return only the first hidden state
hidden_states=tree_index(self.data.hidden_states, first_hidden_state_idx),
episode_starts=self.data.episode_starts[idx],
)

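The sampling scheme above can be summarized with a standalone sketch (an illustration of the pattern, not the buffer's actual code): environments are shuffled, the rollout is tiled into windows of `batch_time` timesteps by `batch_envs` environments, and only the hidden state at each window's first timestep is returned, so the LSTM is re-unrolled from that state during training.

    # Standalone sketch of the (time, env) mini-batch iteration; tensors are
    # assumed to be laid out as (time, env, ...), as in the buffer above.
    import torch as th

    def minibatch_indices(n_steps: int, n_envs: int, batch_time: int, batch_envs: int):
        if batch_envs >= n_envs:
            env_batches = [slice(None)]  # every environment fits in one batch
        else:
            perm = th.randperm(n_envs)
            env_batches = [perm[i : i + batch_envs] for i in range(0, n_envs, batch_envs)]
        for env_inds in env_batches:
            for t0 in range(0, n_steps, batch_time):
                # the window's initial hidden state lives at timestep t0
                yield slice(t0, t0 + batch_time), env_inds
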
52 changes: 33 additions & 19 deletions stable_baselines3/ppo_recurrent/ppo_recurrent.py
@@ -106,7 +106,8 @@ def __init__(
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 3e-4,
n_steps: int = 128,
batch_size: int = 128,
batch_envs: int = 128,
batch_time: Optional[int] = None,
n_epochs: int = 10,
gamma: float = 0.99,
gae_lambda: float = 0.95,
@@ -153,32 +154,47 @@ def __init__(
spaces.MultiBinary,
),
)
if batch_time is None:
batch_time = self.n_steps
# Sanity check, otherwise it will lead to noisy gradient and NaN
# because of the advantage normalization
if normalize_advantage:
assert (
batch_size > 1
batch_envs * batch_time > 1
), "`batch_envs * batch_time` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

if self.env is not None:
# Check that `n_steps * n_envs > 1` to avoid NaN
# when doing advantage normalization
buffer_size = self.env.num_envs * self.n_steps
assert buffer_size > 1 or (
not normalize_advantage
), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
num_envs = self.env.num_envs
assert (
num_envs > 1 or batch_time > 1 or (not normalize_advantage)
), f"`num_envs` or `batch_time` must be greater than 1. Currently num_envs={num_envs} and batch_time={batch_time}"
# Check that the rollout buffer size is a multiple of the mini-batch size
untruncated_batches = buffer_size // batch_size
if buffer_size % batch_size > 0:
if (truncated_batch_size := num_envs % batch_envs) > 0:
untruncated_batches = num_envs // batch_envs
warnings.warn(
f"You have specified a mini-batch size of {batch_size},"
f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
f"You have specified an environment mini-batch size of {batch_envs},"
f" but because the `RecurrentRolloutBuffer` has `n_envs = {self.env.num_envs}`,"
f" after every {untruncated_batches} untruncated mini-batches,"
f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
f" there will be a truncated mini-batch of size {truncated_batch_size}\n"
f"We recommend using a `batch_envs` that is a factor of `n_envs`.\n"
f"Info: (n_envs={self.env.num_envs})"
)
self.batch_size = batch_size

if (truncated_batch_size := self.n_steps % batch_time) > 0:
untruncated_batches = self.n_steps // batch_time
warnings.warn(
f"You have specified a time mini-batch size of {batch_time},"
f" but because the `RecurrentRolloutBuffer` has `n_steps = {self.n_steps}`,"
f" after every {untruncated_batches} untruncated mini-batches,"
f" there will be a truncated mini-batch of size {truncated_batch_size}\n"
f"We recommend using a `batch_time` that is a factor of `n_steps`.\n"
f"Info: (n_steps={self.n_steps})"
)

self.batch_envs = batch_envs
self.batch_time = batch_time
self.n_epochs = n_epochs
self.clip_range: Schedule = clip_range # type: ignore
self.clip_range_vf: Schedule = clip_range_vf # type: ignore
@@ -217,9 +233,7 @@ def _setup_model(self) -> None:
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
)
self._last_lstm_states = tree_map( # type: ignore
lambda x: x[0].clone().contiguous(), self.rollout_buffer.data.hidden_states
)
self._last_lstm_states = tree_map(lambda x: th.zeros_like(x, memory_format=th.contiguous_format), hidden_state_example)

# Initialize schedules for policy/value clipping
self.clip_range = get_schedule_fn(self.clip_range)
@@ -372,14 +386,14 @@ def train(self) -> None:
# train for n_epochs epochs
for epoch in range(self.n_epochs):
# Do a complete pass on the rollout buffer
for rollout_data in self.rollout_buffer.get(self.batch_size):
for rollout_data in self.rollout_buffer.get(batch_time=self.batch_time, batch_envs=self.batch_envs):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
actions = rollout_data.actions.squeeze(-1)

# Re-sample the noise matrix because the log_std has changed
if self.use_sde:
self.policy.reset_noise(self.batch_size)
self.policy.reset_noise(self.batch_envs * self.batch_time)

values, log_prob, entropy = self.policy.evaluate_actions(
rollout_data.observations, # type: ignore[arg-type]
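
A back-of-the-envelope check of the mini-batch bookkeeping above (the helper below is hypothetical, written only for illustration): with `n_envs=16`, `n_steps=128`, `batch_envs=4` and `batch_time=32`, each epoch visits (16 / 4) * (128 / 32) = 16 mini-batches of 4 * 32 = 128 samples, which is also the value now passed to `reset_noise` when `use_sde` is enabled.

    # Hypothetical helper mirroring the warnings above: when the division is
    # not exact, the last mini-batch along that dimension is truncated.
    import math

    def minibatches_per_epoch(n_envs: int, n_steps: int, batch_envs: int, batch_time: int) -> int:
        return math.ceil(n_envs / batch_envs) * math.ceil(n_steps / batch_time)

    assert minibatches_per_epoch(16, 128, 4, 32) == 16  # divides evenly, no truncation
    assert minibatches_per_epoch(16, 128, 5, 32) == 16  # 16 envs / 5 -> 4 env batches, last truncated
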
37 changes: 27 additions & 10 deletions tests/test_buffers.py
@@ -127,8 +127,9 @@ def test_replay_buffer_normalization(replay_buffer_cls):
@pytest.mark.parametrize(
"replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer, RecurrentRolloutBuffer]
)
@pytest.mark.parametrize("n_envs", [1, 4])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_device_buffer(replay_buffer_cls, device):
def test_device_buffer(replay_buffer_cls, n_envs, device):
if device == "cuda" and not th.cuda.is_available():
pytest.skip("CUDA not available")

@@ -139,40 +140,47 @@ def test_device_buffer(replay_buffer_cls, device):
DictReplayBuffer: DummyDictEnv,
RecurrentRolloutBuffer: DummyDictEnv,
}[replay_buffer_cls]
env = make_vec_env(env)
env = make_vec_env(env, n_envs=n_envs)
hidden_states_shape = HIDDEN_STATES_EXAMPLE["a"]["b"].shape
N_ENVS_HIDDEN_STATES = {"a": {"b": th.zeros((hidden_states_shape[0], env.num_envs, *hidden_states_shape[1:]))}}

if replay_buffer_cls == RecurrentRolloutBuffer:
buffer = RecurrentRolloutBuffer(
EP_LENGTH, env.observation_space, env.action_space, hidden_state_example=N_ENVS_HIDDEN_STATES, device=device
EP_LENGTH,
env.observation_space,
env.action_space,
hidden_state_example=N_ENVS_HIDDEN_STATES,
device=device,
n_envs=n_envs,
)
else:
buffer = replay_buffer_cls(EP_LENGTH, env.observation_space, env.action_space, device=device)
buffer = replay_buffer_cls(EP_LENGTH, env.observation_space, env.action_space, device=device, n_envs=n_envs)

# Interact and store transitions
obs = env.reset()
episode_start, values, log_prob = th.zeros(n_envs), th.zeros(n_envs), th.ones(n_envs)

for _ in range(EP_LENGTH):
action = th.as_tensor(env.action_space.sample())
action = th.as_tensor([env.action_space.sample() for _ in range(n_envs)])

next_obs, reward, done, info = env.step(action)
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
episode_start, values, log_prob = th.zeros(1), th.zeros(1), th.ones(1)
buffer.add(obs, action, reward, episode_start, values, log_prob)
elif replay_buffer_cls == RecurrentRolloutBuffer:
episode_start, values, log_prob = th.zeros(1), th.zeros(1), th.ones(1)
buffer.add(RecurrentRolloutBufferData(obs, action, reward, episode_start, values, log_prob, N_ENVS_HIDDEN_STATES))
else:
buffer.add(obs, next_obs, action, reward, done, info)
obs = next_obs

# Get data from the buffer
batch_envs = max(1, env.num_envs // 2)
batch_time = EP_LENGTH // 2
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
data = buffer.get(50)
data = buffer.get(batch_time)
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
data = [buffer.sample(50)]
data = [buffer.sample(batch_time)]
elif replay_buffer_cls == RecurrentRolloutBuffer:
data = buffer.get(EP_LENGTH)
data = buffer.get(batch_envs=batch_envs, batch_time=batch_time)

# Check that all data are on the desired device
desired_device = get_device(device).type
@@ -182,3 +190,12 @@ def test_device_buffer(replay_buffer_cls, device):
for value in flattened_tensors:
assert isinstance(value, th.Tensor)
assert value.device.type == desired_device

# Check that data are of the desired shape
if isinstance(minibatch, (ReplayBufferSamples, DictReplayBufferSamples)):
assert minibatch.rewards.shape == (batch_time, 1)
else:
if replay_buffer_cls == RecurrentRolloutBuffer:
assert minibatch.old_log_prob.shape == (batch_time, batch_envs)
else:
assert minibatch.old_log_prob.shape == (batch_time,)
2 changes: 1 addition & 1 deletion tests/test_deterministic.py
@@ -27,7 +27,7 @@ def test_deterministic_training_common(algo):
kwargs.update({"n_steps": 64, "n_epochs": 4})
elif algo == RecurrentPPO:
kwargs.update({"policy_kwargs": dict(net_arch=[], enable_critic_lstm=True, lstm_hidden_size=8)})
kwargs.update({"n_steps": 50, "n_epochs": 4, "batch_size": 100})
kwargs.update({"n_steps": 50, "n_epochs": 4, "batch_time": 25, "batch_envs": 1})

policy_str = "MlpLstmPolicy" if algo == RecurrentPPO else "MlpPolicy"
for i in range(2):
30 changes: 20 additions & 10 deletions tests/test_identity.py
@@ -66,25 +66,35 @@


@pytest.mark.parametrize("model_class", [A2C, PPO, DQN, RecurrentPPO])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
env_ = DummyVecEnv([lambda: env])
kwargs: dict[str, Any] = {}
n_steps = 10000
@pytest.mark.parametrize(
"env_fn", [lambda: IdentityEnv(DIM), lambda: IdentityEnvMultiDiscrete(DIM), lambda: IdentityEnvMultiBinary(DIM)]
)
def test_discrete(model_class, env_fn):
# Use multiple envs so we can test that batching works correctly
env_ = DummyVecEnv([env_fn] * 4)
kwargs: dict[str, Any] = dict()
total_n_steps = 10000
if model_class == DQN:
kwargs = dict(learning_starts=0)
# DQN only support discrete actions
if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
if isinstance(env_.envs[0], (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
return

if model_class in (RecurrentPPO, PPO):
kwargs["target_kl"] = 0.02
kwargs["n_epochs"] = 30

if model_class == RecurrentPPO:
# Ensure that there's not an MLP on top of the LSTM that the default Policy creates.
kwargs["policy_kwargs"] = dict(net_arch=dict(vf=[], pi=[]))
kwargs["batch_time"] = 128
kwargs["batch_envs"] = 4
kwargs["n_steps"] = 256

model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps)
model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(total_n_steps)

evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=99, warn=False)
obs, _ = env.reset()
obs, _ = env_.envs[0].reset()

assert np.shape(model.predict(obs)[0]) == np.shape(obs)

@@ -93,7 +103,7 @@ def test_discrete(model_class, env):
def test_continuous(model_class):
env = IdentityEnvBox(eps=0.5)

n_steps = 2000 if issubclass(model_class, OnPolicyAlgorithm) else 400
total_n_steps = 2000 if issubclass(model_class, OnPolicyAlgorithm) else 400

kwargs: dict[str, Any] = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95)

@@ -109,6 +119,6 @@ def test_continuous(model_class):
# Ensure that there's not an MLP on top of the LSTM that the default Policy creates.
kwargs["policy_kwargs"]["net_arch"] = dict(vf=[], pi=[])

model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(n_steps)
model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(total_n_steps)

evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False)
4 changes: 3 additions & 1 deletion tests/test_lstm.py
@@ -260,6 +260,7 @@ def make_env():

N_ENVS = 16
N_STEPS = 32
BATCH_TIME = 4
env = VecNormalize(make_vec_env(make_env, n_envs=N_ENVS))

eval_callback = EvalCallback(
@@ -282,7 +283,8 @@ def make_env():
n_steps=N_STEPS,
learning_rate=0.0007,
verbose=1,
batch_size=N_ENVS * N_STEPS,
batch_envs=N_ENVS,
batch_time=BATCH_TIME,
seed=1,
n_epochs=10,
max_grad_norm=1,
