     MultiSyncDataCollector,
     RandomPolicy,
 )
+from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
     CompositeSpec,
     NdUnboundedContinuousTensorSpec,
     TensorDict,
     UnboundedContinuousTensorSpec,
 )
 from torchrl.data.tensordict.tensordict import assert_allclose_td
-from torchrl.envs import EnvCreator, ParallelEnv
-from torchrl.envs.libs.gym import _has_gym
+from torchrl.envs import EnvCreator, ParallelEnv, SerialEnv
+from torchrl.envs.libs.gym import _has_gym, GymEnv
 from torchrl.envs.transforms import TransformedEnv, VecNorm
 from torchrl.modules import (
     Actor,
@@ -295,6 +296,32 @@ def env_fn(seed):
     ccollector.shutdown()


+@pytest.mark.skipif(not _has_gym, reason="gym library is not installed")
+def test_collector_env_reset():
+    torch.manual_seed(0)
+    env = SerialEnv(2, lambda: GymEnv("ALE/Pong-v5", frame_skip=4))
+    # env = SerialEnv(3, lambda: GymEnv("CartPole-v1", frame_skip=4))
+    env.set_seed(0)
+    collector = SyncDataCollector(
+        env, total_frames=10000, frames_per_batch=10000, split_trajs=False
+    )
+    for data in collector:
+        continue
+    steps = data["step_count"][..., 1:, :]
+    done = data["done"][..., :-1, :]
+    # we don't want just one done
+    assert done.sum() > 3
+    # check that after a done, the next step count is always 1
+    assert (steps[done] == 1).all()
+    # check that if the env is not done, the next step count is > 1
+    assert (steps[~done] > 1).all()
+    # check that if step is 1, then the env was done before
+    assert (steps == 1)[done].all()
+    # check that split traj has a minimum total reward of -21 (for pong only)
+    data = split_trajectories(data)
+    assert data["reward"].sum(-2).min() == -21
+
+
 @pytest.mark.parametrize("num_env", [1, 3])
 @pytest.mark.parametrize("env_name", ["vec"])
 def test_collector_done_persist(num_env, env_name, seed=5):