Skip to content

Commit

Permalink
No need to detach buffers -- they already don't have gradients
Browse files Browse the repository at this point in the history
This reverts commit c8deb63.
  • Loading branch information
rhaps0dy committed Feb 1, 2024
1 parent 4f7f1d4 commit 6319f74
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 26 deletions.
50 changes: 25 additions & 25 deletions stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,22 +287,22 @@ def add(

# Copy to avoid modification by reference
assert isinstance(self.observations, th.Tensor)
self.observations[self.pos].copy_(obs.detach(), non_blocking=True)
self.observations[self.pos].copy_(obs, non_blocking=True)

if self.optimize_memory_usage:
self.observations[(self.pos + 1) % self.buffer_size].copy_(next_obs.detach(), non_blocking=True)
self.observations[(self.pos + 1) % self.buffer_size].copy_(next_obs, non_blocking=True)
else:
assert isinstance(self.next_observations, th.Tensor)
self.next_observations[self.pos].copy_(next_obs.detach(), non_blocking=True)
self.next_observations[self.pos].copy_(next_obs, non_blocking=True)

self.actions[self.pos].copy_(action.detach(), non_blocking=True)
self.rewards[self.pos].copy_(reward.detach(), non_blocking=True)
self.dones[self.pos].copy_(done.detach(), non_blocking=True)
self.actions[self.pos].copy_(action, non_blocking=True)
self.rewards[self.pos].copy_(reward, non_blocking=True)
self.dones[self.pos].copy_(done, non_blocking=True)

if self.handle_timeout_termination:
for i, info in enumerate(infos):
self.timeouts[self.pos, i].copy_(
th.as_tensor(info.get("TimeLimit.truncated", False)).squeeze().detach(), non_blocking=True
th.as_tensor(info.get("TimeLimit.truncated", False)).squeeze(), non_blocking=True
)

self.pos += 1
Expand Down Expand Up @@ -498,12 +498,12 @@ def add(
action = action.reshape((self.n_envs, self.action_dim))

assert isinstance(self.observations, th.Tensor)
self.observations[self.pos].copy_(obs.detach(), non_blocking=True)
self.actions[self.pos].copy_(action.detach(), non_blocking=True)
self.rewards[self.pos].copy_(reward.detach(), non_blocking=True)
self.episode_starts[self.pos].copy_(episode_start.detach(), non_blocking=True)
self.values[self.pos].copy_(value.flatten().detach(), non_blocking=True)
self.log_probs[self.pos].copy_(log_prob.detach(), non_blocking=True)
self.observations[self.pos].copy_(obs, non_blocking=True)
self.actions[self.pos].copy_(action, non_blocking=True)
self.rewards[self.pos].copy_(reward, non_blocking=True)
self.episode_starts[self.pos].copy_(episode_start, non_blocking=True)
self.values[self.pos].copy_(value.flatten(), non_blocking=True)
self.log_probs[self.pos].copy_(log_prob, non_blocking=True)
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
Expand Down Expand Up @@ -670,24 +670,24 @@ def add(
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs[key] = obs[key].reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos].copy_(obs[key].detach(), non_blocking=True)
self.observations[key][self.pos].copy_(obs[key], non_blocking=True)

for key in self.next_observations.keys():
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key])
self.next_observations[key][self.pos].copy_(next_obs[key].detach(), non_blocking=True)
self.next_observations[key][self.pos].copy_(next_obs[key], non_blocking=True)

# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos].copy_(th.as_tensor(action).detach(), non_blocking=True)
self.rewards[self.pos].copy_(reward.detach(), non_blocking=True)
self.dones[self.pos].copy_(done.detach(), non_blocking=True)
self.actions[self.pos].copy_(th.as_tensor(action), non_blocking=True)
self.rewards[self.pos].copy_(reward, non_blocking=True)
self.dones[self.pos].copy_(done, non_blocking=True)

if self.handle_timeout_termination:
for i, info in enumerate(infos):
self.timeouts[self.pos, i].copy_(
th.as_tensor(info.get("TimeLimit.truncated", False)).squeeze().detach(), non_blocking=True
th.as_tensor(info.get("TimeLimit.truncated", False)).squeeze(), non_blocking=True
)

self.pos += 1
Expand Down Expand Up @@ -848,16 +848,16 @@ def add(
# as torch cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos].copy_(obs_.detach(), non_blocking=True)
self.observations[key][self.pos].copy_(obs_, non_blocking=True)

# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos].copy_(action.detach(), non_blocking=True)
self.rewards[self.pos].copy_(reward.detach(), non_blocking=True)
self.episode_starts[self.pos].copy_(episode_start.detach(), non_blocking=True)
self.values[self.pos].copy_(value.flatten().detach(), non_blocking=True)
self.log_probs[self.pos].copy_(log_prob.detach(), non_blocking=True)
self.actions[self.pos].copy_(action, non_blocking=True)
self.rewards[self.pos].copy_(reward, non_blocking=True)
self.episode_starts[self.pos].copy_(episode_start, non_blocking=True)
self.values[self.pos].copy_(value.flatten(), non_blocking=True)
self.log_probs[self.pos].copy_(log_prob, non_blocking=True)
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/recurrent/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def add(self, data: RecurrentRolloutBufferData, **kwargs) -> None: # type: igno
)

tree_map(
lambda buf, x: buf[self.pos].copy_((x if x.ndim + 1 == buf.ndim else x.unsqueeze(-1)).detach(), non_blocking=True),
lambda buf, x: buf[self.pos].copy_(x if x.ndim + 1 == buf.ndim else x.unsqueeze(-1), non_blocking=True),
self.data,
new_data,
)
Expand Down

0 comments on commit 6319f74

Please sign in to comment.