Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/collect equal episode num in all envs #1127

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Refactor input validation of collector
  • Loading branch information
bordeauxred committed Apr 2, 2024
commit b5a0a98af5725637af8b5e2546e16213bf05f01b
2 changes: 1 addition & 1 deletion test/base/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_collector() -> None:
Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
with pytest.raises(TypeError):
Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
with pytest.raises(TypeError):
with pytest.raises(ValueError):
c_dummy_venv_4_envs.collect()

# test NXEnv
Expand Down
91 changes: 66 additions & 25 deletions tianshou/data/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,60 @@ def _compute_action_policy_hidden(
)
return act_RA, act_normalized_RA, policy_R, hidden_state_RH

def _validate_collect_input_and_get_ready_env_ids(
self,
n_episode: int | None,
n_step: int | None,
sample_equal_num_episodes_per_worker: bool,
) -> np.ndarray:
"""Check that exactly one of n_step or n_episode is specified.
Returns the idx of non-idle envs that will be used for the collection.
"""
if n_step is not None and n_episode is not None:
raise ValueError(
f"Only one of n_step or n_episode is allowed in Collector."
f"collect, got {n_step=}, {n_episode=}.",
)

if n_step is not None:
if sample_equal_num_episodes_per_worker:
raise ValueError(
"sample_equal_num_episodes_per_worker can only be used if `n_episode` is specified but"
"got `n_step` instead.",
)
if n_step < 1:
raise ValueError(f"n_step should be an integer larger than 0, but got {n_step=}.")

if n_step % self.env_num:
warnings.warn(
f"{n_step=} is not a multiple of ({self.env_num=}). "
"This may cause extra transitions to be collected into the buffer.",
)
return np.arange(self.env_num)

elif n_episode is not None:
if n_episode < 1:
raise ValueError(
f"{n_episode=} should be an integer larger than 0.",
)
if n_episode < self.env_num:
warnings.warn(
f"{n_episode=} should be larger than or equal to {self.env_num=} "
f"(otherwise you will get idle workers and won't collect at"
f"least one trajectory in each env).",
)
if sample_equal_num_episodes_per_worker and n_episode % self.env_num != 0:
raise ValueError(
f"{n_episode=} must be a multiple of {self.env_num=} "
f"when using {sample_equal_num_episodes_per_worker=}.",
)
return np.arange(min(self.env_num, n_episode))

else:
raise ValueError(
f"At least one of {n_step=} and {n_episode=} should be specified as int larger than 0.",
)

# TODO: reduce complexity, remove the noqa
def collect(
self,
Expand All @@ -318,6 +372,7 @@ def collect(
no_grad: bool = True,
reset_before_collect: bool = False,
gym_reset_kwargs: dict[str, Any] | None = None,
sample_equal_num_episodes_per_worker: bool = False,
) -> CollectStats:
"""Collect a specified number of steps or episodes.

Expand All @@ -337,6 +392,8 @@ def collect(
(The collector needs the initial obs and info to function properly.)
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Only used if reset_before_collect is True.
:param sample_equal_num_episodes_per_worker: whether to sample the same number
of episodes from each worker. Only used if n_episode is set.

.. note::

Expand Down Expand Up @@ -364,31 +421,12 @@ def collect(

# Input validation
assert not self.env.is_async, "Please use AsyncCollector if using async venv."
if n_step is not None:
assert n_episode is None, (
f"Only one of n_step or n_episode is allowed in Collector."
f"collect, got {n_step=}, {n_episode=}."
)
assert n_step > 0
if n_step % self.env_num != 0:
warnings.warn(
f"{n_step=} is not a multiple of ({self.env_num=}), "
"which may cause extra transitions being collected into the buffer.",
)
ready_env_ids_R = np.arange(self.env_num)
elif n_episode is not None:
assert n_episode > 0
if self.env_num > n_episode:
warnings.warn(
f"{n_episode=} should be larger than {self.env_num=} to "
f"collect at least one trajectory in each environment.",
)
ready_env_ids_R = np.arange(min(self.env_num, n_episode))
else:
raise TypeError(
"Please specify at least one (either n_step or n_episode) "
"in AsyncCollector.collect().",
)

ready_env_ids_R = self._validate_collect_input_and_get_ready_env_ids(
n_episode,
n_step,
sample_equal_num_episodes_per_worker=False,
)

start_time = time.time()

Expand Down Expand Up @@ -668,6 +706,7 @@ def collect(
no_grad: bool = True,
reset_before_collect: bool = False,
gym_reset_kwargs: dict[str, Any] | None = None,
sample_equal_num_episodes_per_worker: bool = False,
) -> CollectStats:
"""Collect a specified number of steps or episodes with async env setting.

Expand All @@ -689,6 +728,8 @@ def collect(
(The collector needs the initial obs and info to function properly.)
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Defaults to None (extra keyword arguments)
:param sample_equal_num_episodes_per_worker: Not applicable to async collector.
#todo this is only used to keep the signatures of collect in Collector and AsyncCollector the same, maybe introduce some base class with collect as abstract method?

.. note::

Expand Down