|
25 | 25 | MultiSyncDataCollector,
|
26 | 26 | MultiaSyncDataCollector,
|
27 | 27 | )
|
| 28 | +from torchrl.data import ( |
| 29 | + CompositeSpec, |
| 30 | + NdUnboundedContinuousTensorSpec, |
| 31 | + UnboundedContinuousTensorSpec, |
| 32 | +) |
28 | 33 | from torchrl.data.tensordict.tensordict import assert_allclose_td
|
29 | 34 | from torchrl.envs import EnvCreator
|
30 | 35 | from torchrl.envs import ParallelEnv
|
31 | 36 | from torchrl.envs.libs.gym import _has_gym
|
32 | 37 | from torchrl.envs.transforms import TransformedEnv, VecNorm
|
| 38 | +from torchrl.modules import LSTMNet, TensorDictModule |
33 | 39 | from torchrl.modules import OrnsteinUhlenbeckProcessWrapper, Actor
|
34 | 40 |
|
35 | 41 | # torch.set_default_dtype(torch.double)
|
@@ -673,14 +679,22 @@ def test_collector_vecnorm_envcreator(static_seed):
|
673 | 679 |
|
674 | 680 |
|
675 | 681 | @pytest.mark.parametrize("use_async", [False, True])
|
676 |
| -@pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="no cuda device found") |
| 682 | +@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") |
677 | 683 | def test_update_weights(use_async):
|
678 |
| - policy = torch.nn.Linear(3, 4).cuda(1) |
| 684 | + def create_env(): |
| 685 | + return ContinuousActionVecMockEnv() |
| 686 | + |
| 687 | + n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1] |
| 688 | + policy = TensorDictModule( |
| 689 | + torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"] |
| 690 | + ) |
| 691 | + policy(create_env().reset()) |
| 692 | + |
679 | 693 | collector_class = (
|
680 | 694 | MultiSyncDataCollector if not use_async else MultiaSyncDataCollector
|
681 | 695 | )
|
682 | 696 | collector = collector_class(
|
683 |
| - [lambda: DiscreteActionVecMockEnv()] * 3, |
| 697 | + [create_env] * 3, |
684 | 698 | policy=policy,
|
685 | 699 | devices=[torch.device("cuda:0")] * 3,
|
686 | 700 | passing_devices=[torch.device("cuda:0")] * 3,
|
@@ -769,6 +783,89 @@ def make_env():
|
769 | 783 | dummy_env.close()
|
770 | 784 |
|
771 | 785 |
|
| 786 | +@pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") |
| 787 | +@pytest.mark.parametrize( |
| 788 | + "collector_class", |
| 789 | + [ |
| 790 | + SyncDataCollector, |
| 791 | + MultiaSyncDataCollector, |
| 792 | + MultiSyncDataCollector, |
| 793 | + ], |
| 794 | +) |
| 795 | +@pytest.mark.parametrize("init_random_frames", [0, 50]) |
| 796 | +@pytest.mark.parametrize("explicit_spec", [True, False]) |
| 797 | +def test_collector_output_keys(collector_class, init_random_frames, explicit_spec): |
| 798 | + from torchrl.envs.libs.gym import GymEnv |
| 799 | + |
| 800 | + out_features = 1 |
| 801 | + hidden_size = 12 |
| 802 | + total_frames = 200 |
| 803 | + frames_per_batch = 20 |
| 804 | + num_envs = 3 |
| 805 | + |
| 806 | + net = LSTMNet( |
| 807 | + out_features, |
| 808 | + {"input_size": hidden_size, "hidden_size": hidden_size}, |
| 809 | + {"out_features": hidden_size}, |
| 810 | + ) |
| 811 | + |
| 812 | + policy_kwargs = { |
| 813 | + "module": net, |
| 814 | + "in_keys": ["observation", "hidden1", "hidden2"], |
| 815 | + "out_keys": ["action", "hidden1", "hidden2", "next_hidden1", "next_hidden2"], |
| 816 | + } |
| 817 | + if explicit_spec: |
| 818 | + hidden_spec = NdUnboundedContinuousTensorSpec((1, hidden_size)) |
| 819 | + policy_kwargs["spec"] = CompositeSpec( |
| 820 | + action=UnboundedContinuousTensorSpec(), |
| 821 | + hidden1=hidden_spec, |
| 822 | + hidden2=hidden_spec, |
| 823 | + next_hidden1=hidden_spec, |
| 824 | + next_hidden2=hidden_spec, |
| 825 | + ) |
| 826 | + |
| 827 | + policy = TensorDictModule(**policy_kwargs) |
| 828 | + |
| 829 | + env_maker = lambda: GymEnv("Pendulum-v1") |
| 830 | + |
| 831 | + policy(env_maker().reset()) |
| 832 | + |
| 833 | + collector_kwargs = { |
| 834 | + "create_env_fn": env_maker, |
| 835 | + "policy": policy, |
| 836 | + "total_frames": total_frames, |
| 837 | + "frames_per_batch": frames_per_batch, |
| 838 | + "init_random_frames": init_random_frames, |
| 839 | + } |
| 840 | + |
| 841 | + if collector_class is not SyncDataCollector: |
| 842 | + collector_kwargs["create_env_fn"] = [ |
| 843 | + collector_kwargs["create_env_fn"] for _ in range(num_envs) |
| 844 | + ] |
| 845 | + |
| 846 | + collector = collector_class(**collector_kwargs) |
| 847 | + |
| 848 | + keys = [ |
| 849 | + "action", |
| 850 | + "done", |
| 851 | + "hidden1", |
| 852 | + "hidden2", |
| 853 | + "mask", |
| 854 | + "next_hidden1", |
| 855 | + "next_hidden2", |
| 856 | + "next_observation", |
| 857 | + "observation", |
| 858 | + "reward", |
| 859 | + "step_count", |
| 860 | + "traj_ids", |
| 861 | + ] |
| 862 | + b = next(iter(collector)) |
| 863 | + |
| 864 | + assert set(b.keys()) == set(keys) |
| 865 | + collector.shutdown() |
| 866 | + del collector |
| 867 | + |
| 868 | + |
772 | 869 | def weight_reset(m):
|
773 | 870 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
774 | 871 | m.reset_parameters()
|
|
0 commit comments