Skip to content

Commit

Permalink
Explicit arguments for get()
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed Feb 1, 2024
1 parent b2fd61b commit e027120
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 9 deletions.
10 changes: 3 additions & 7 deletions stable_baselines3/common/recurrent/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,11 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: # type: igno
self.full = True

def get( # type: ignore[override]
self, batch_shape: Optional[tuple[int, int]] = None
self,
batch_time: int,
batch_envs: int,
) -> Generator[RecurrentRolloutBufferSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"

# Return everything, don't create minibatches
if batch_shape is None:
batch_shape = (self.buffer_size, self.n_envs)
batch_time, batch_envs = batch_shape

if batch_envs >= self.n_envs:
for time_start in range(0, self.buffer_size, batch_time):
yield self._get_samples(slice(None), slice(time_start, time_start + batch_time))
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def train(self) -> None:
# train for n_epochs epochs
for epoch in range(self.n_epochs):
# Do a complete pass on the rollout buffer
for rollout_data in self.rollout_buffer.get(batch_shape=(self.batch_time, self.batch_envs)):
for rollout_data in self.rollout_buffer.get(batch_time=self.batch_time, batch_envs=self.batch_envs):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
actions = rollout_data.actions.squeeze(-1)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_device_buffer(replay_buffer_cls, device):
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
data = [buffer.sample(50)]
elif replay_buffer_cls == RecurrentRolloutBuffer:
data = buffer.get(EP_LENGTH)
data = buffer.get(batch_envs=env.num_envs // 2, batch_time=EP_LENGTH // 2)

# Check that all data are on the desired device
desired_device = get_device(device).type
Expand Down

0 comments on commit e027120

Please sign in to comment.