Commit 4c11731
Allow more than 1 step to think when evaluating policy
rhaps0dy committed Feb 19, 2024
1 parent 2722c1f commit 4c11731
Showing 2 changed files with 18 additions and 5 deletions.
15 changes: 13 additions & 2 deletions stable_baselines3/common/evaluation.py
@@ -25,6 +25,7 @@ def evaluate_policy(
     reward_threshold: Optional[float] = None,
     return_episode_rewards: bool = False,
     warn: bool = True,
+    n_steps_to_think: int = 0,
 ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
     """
     Runs policy for ``n_eval_episodes`` episodes and returns average reward.
@@ -87,15 +88,25 @@ def evaluate_policy(

     # Hardcode episode counts and the reward accumulators to use CPU. They're used for bookkeeping and don't involve
     # much computation.
 
     episode_counts = th.zeros(n_envs, dtype=th.int64, device="cpu")
     # Divides episodes among different sub environments in the vector as evenly as possible
     episode_count_targets = th.tensor([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype=th.int64, device="cpu")
 
+    episode_starts_yes = th.ones((env.num_envs,), dtype=th.bool, device=model.device)
+    episode_starts_no = th.zeros((env.num_envs,), dtype=th.bool, device=model.device)
+    states = None
+    for step_i in range(n_steps_to_think):
+        assert (episode_count_targets <= 1).all(), "If episodes count several times the steps to think won't be accurate."
+        _, states = model.predict(
+            observations,  # type: ignore[arg-type]
+            state=states,
+            episode_start=(episode_starts_yes if step_i == 0 else episode_starts_no),
+            deterministic=deterministic,
+        )
 
     current_rewards = th.zeros(n_envs, dtype=th.float32, device="cpu")
     current_lengths = th.zeros(n_envs, dtype=th.int64, device="cpu")
-    episode_starts = th.ones((env.num_envs,), dtype=th.bool, device=model.device)
+    episode_starts = episode_starts_yes
     while (episode_counts < episode_count_targets).any():
         with th.no_grad():
             actions, states = model.predict(
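For quick orientation on the change above, here is a minimal usage sketch, not part of the commit: it mirrors the call pattern in tests/test_lstm.py below, but the import paths, the CartPole setup, and the training budget are assumptions.

    # Illustrative sketch only: let the recurrent policy "think" for a few extra
    # predict() calls on the initial observations before any evaluation episode starts.
    from sb3_contrib import RecurrentPPO  # assumed import path in this fork
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.evaluation import evaluate_policy

    env = make_vec_env("CartPole-v1", n_envs=4)
    model = RecurrentPPO("MlpLstmPolicy", env, n_steps=32).learn(1_000)

    # With n_steps_to_think > 0, keep n_eval_episodes <= n_envs (at most one episode
    # per sub-env), otherwise the assert added in evaluate_policy fires.
    mean_reward, std_reward = evaluate_policy(
        model,
        env,
        n_eval_episodes=4,
        n_steps_to_think=4,
    )
    print(mean_reward, std_reward)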
8 changes: 5 additions & 3 deletions tests/test_lstm.py
@@ -236,10 +236,12 @@ def test_run_sde_recurrent_extractor():
         ),
     ],
 )
-def test_dict_obs(policy_kwargs):
-    env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
+@pytest.mark.parametrize("n_steps_to_think", [0, 1, 4])
+def test_dict_obs(policy_kwargs, n_steps_to_think):
+    N_ENVS = 10
+    env = make_vec_env("CartPole-v1", n_envs=N_ENVS, wrapper_class=ToDictWrapper)
     model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32, policy_kwargs=policy_kwargs).learn(64)
-    evaluate_policy(model, env, warn=False)
+    evaluate_policy(model, env, n_eval_episodes=N_ENVS, warn=False, n_steps_to_think=n_steps_to_think)
 
 
 def test_dict_obs_recurrent_extractor():
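The test now uses N_ENVS = 10 sub-envs and evaluates exactly N_ENVS episodes because, with thinking steps enabled, each sub-env may contribute at most one episode. A small sketch (values are illustrative) of the episode_count_targets check that the new assert relies on:

    import torch as th

    n_envs = 10

    # One episode per env: every target is 1, so the assert in evaluate_policy holds.
    n_eval_episodes = 10
    targets = th.tensor([(n_eval_episodes + i) // n_envs for i in range(n_envs)])
    assert (targets <= 1).all()

    # More episodes than envs: some targets become 2, and the extra thinking steps
    # would not be repeated for those second episodes, so the assert rejects this.
    n_eval_episodes = 15
    targets = th.tensor([(n_eval_episodes + i) // n_envs for i in range(n_envs)])
    assert not (targets <= 1).all()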
