Skip to content

Commit

Permalink
Fix pytype and lstm_states indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed Feb 27, 2024
1 parent 09cdebc commit b46fce8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion stable_baselines3/common/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,9 @@ def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any
if maybe_is_success is not None:
self._is_success_buffer.append(maybe_is_success)

def _evaluate_policy(self) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
def _evaluate_policy(
self,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: # pytype: disable=bad-return-type
return evaluate_policy(
self.model,
self.eval_env,
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/ppo_recurrent/ppo_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def think_for_n_steps(
return lstm_states
# ignore because TorchGymObs and TensorTree do not match
obs_for_start_envs: TorchGymObs = tree_index(obs_tensor, (episode_starts,)) # type: ignore[type-var]
lstm_states_for_start_envs = tree_index(lstm_states, (episode_starts,))
lstm_states_for_start_envs = tree_index(lstm_states, (slice(None), episode_starts))
for _ in range(n_steps):
_, _, _, lstm_states_for_start_envs = self.policy.forward(
obs_for_start_envs,
Expand Down

0 comments on commit b46fce8

Please sign in to comment.