|
42 | 42 | CatFrames, |
43 | 43 | CatTensors, |
44 | 44 | ChessEnv, |
| 45 | + ConditionalSkip, |
45 | 46 | DoubleToFloat, |
46 | 47 | EnvBase, |
47 | 48 | EnvCreator, |
|
72 | 73 | check_marl_grouping, |
73 | 74 | make_composite_from_td, |
74 | 75 | MarlGroupMapType, |
| 76 | + RandomPolicy, |
75 | 77 | step_mdp, |
76 | 78 | ) |
77 | 79 | from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator |
|
134 | 136 | EnvWithTensorClass, |
135 | 137 | HeterogeneousCountingEnv, |
136 | 138 | HeterogeneousCountingEnvPolicy, |
| 139 | + HistoryTransform, |
137 | 140 | MockBatchedLockedEnv, |
138 | 141 | MockBatchedUnLockedEnv, |
139 | 142 | MockSerialEnv, |
|
174 | 177 | EnvWithTensorClass, |
175 | 178 | HeterogeneousCountingEnv, |
176 | 179 | HeterogeneousCountingEnvPolicy, |
| 180 | + HistoryTransform, |
177 | 181 | MockBatchedLockedEnv, |
178 | 182 | MockBatchedUnLockedEnv, |
179 | 183 | MockSerialEnv, |
@@ -4398,6 +4402,124 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi |
4398 | 4402 | assert (td[3].get("next") != 0).any() |
4399 | 4403 |
|
4400 | 4404 |
|
class TestEnvWithHistory:
    """Tests for environments that carry a ``HistoryTransform``.

    Two environment flavours are exercised: a ``CountingEnv`` with a
    ``HistoryTransform`` appended, and the same environment with an extra
    ``ConditionalSkip`` transform wrapped in a ``StepCounter``. Each flavour is
    run standalone, inside batched environments and through a data collector.
    """

    @pytest.fixture(autouse=True, scope="class")
    def set_capture(self):
        """Run the class under both context managers, each called with ``False``."""
        with set_capture_non_tensor_stack(False), set_auto_unwrap_transformed_env(
            False
        ):
            yield

    def _make_env(self, device, max_steps=10):
        """Return a ``CountingEnv`` with a ``HistoryTransform`` appended.

        Args:
            device: forwarded to ``CountingEnv``.
            max_steps: forwarded to ``CountingEnv``.
        """
        return CountingEnv(device=device, max_steps=max_steps).append_transform(
            HistoryTransform()
        )

    def _make_skipping_env(self, device, max_steps=10):
        """Return the history env with ``ConditionalSkip`` and ``StepCounter`` added.

        Args:
            device: forwarded to :meth:`_make_env`.
            max_steps: forwarded to :meth:`_make_env`.
        """
        env = self._make_env(device=device, max_steps=max_steps)
        # skip every 3 steps
        env = env.append_transform(
            ConditionalSkip(lambda td: ((td["step_count"] % 3) == 2))
        )
        env = TransformedEnv(env, StepCounter())
        return env

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_env_history_base(self, device):
        """The history env passes ``check_env_specs``."""
        env = self._make_env(device)
        env.check_env_specs()

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_skipping_history_env(self, device):
        """The skipping env passes ``check_env_specs`` and can be rolled out."""
        env = self._make_skipping_env(device)
        env.check_env_specs()
        # Smoke test only: the rollout must run, its content is not inspected.
        env.rollout(100)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_env_history_base_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        """The history env passes ``check_env_specs`` inside a batched env."""
        if batch_cls == "parallel":
            # The string is a placeholder for the fixture-provided class.
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            assert not env._use_buffers
            env.check_env_specs(break_when_any_done="both")
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_skipping_history_env_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        """The skipping env passes ``check_env_specs`` inside a batched env."""
        if batch_cls == "parallel":
            # The string is a placeholder for the fixture-provided class.
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_skipping_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_env_history_base_collector(self, device_env, collector_cls):
        """Within a batch, each frame's history starts like the previous ``next`` one."""
        env = self._make_env(device_env)
        collector = collector_cls(
            env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5
        )
        try:
            for d in collector:
                for i in range(d.shape[0] - 1):
                    assert (
                        d[i + 1]["history"].content[0]
                        == d[i]["next", "history"].content[0]
                    )
        finally:
            # Release the collector even when an assertion fails, mirroring the
            # try/finally cleanup used by the batched tests of this class.
            collector.shutdown()

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_skipping_history_env_collector(self, device_env, collector_cls):
        """Check the collected history of the skipping env frame by frame.

        On frames where ``count % 3 == 2`` the ``next`` history must equal the
        previous frame's; on the other frames its last entry must be the
        previous last entry plus one.
        """
        frames_per_batch = 5
        env = self._make_skipping_env(device_env, max_steps=10)
        collector = collector_cls(
            env,
            lambda td: td.update(env.full_action_spec.one()),
            total_frames=35,
            frames_per_batch=frames_per_batch,
        )
        count = 1
        try:
            for d in collector:
                # Frame 0 has no predecessor inside the batch, so start at 1.
                for k in range(1, frames_per_batch):
                    if len(d[k]["history"].content) == 2:
                        # NOTE(review): a two-entry history presumably marks a
                        # freshly reset env - confirm against HistoryTransform.
                        count = 1
                        continue
                    if count % 3 == 2:
                        assert (
                            d[k]["next", "history"].content
                            == d[k - 1]["next", "history"].content
                        ), (d["next", "history"].content, k, count)
                    else:
                        assert d[k]["next", "history"].content[-1] == str(
                            int(d[k - 1]["next", "history"].content[-1]) + 1
                        ), (d["next", "history"].content, k, count)
                    count += 1
                # Presumably accounts for frame 0 of the next batch, which the
                # inner loop does not visit.
                count += 1
        finally:
            # Release the collector even when an assertion fails, mirroring the
            # try/finally cleanup used by the batched tests of this class.
            collector.shutdown()
| 4521 | + |
| 4522 | + |
if __name__ == "__main__":
    # Options argparse does not recognise are handed to pytest untouched.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)
0 commit comments