Commit b2eb1b8

tcbegley and vmoens authored
[BugFix] Restore missing keys in data collector output (#521)

* Ensure data collectors return all expected keys
* Rerun CI
* Add tests
* Format code
* correct unreachable test
* Fix broken test
* WIP: fix initialisation with policy + test
* Fix initialisation with policy + test
* Reset env after rollout initialisation
* fix build from spec
* Check policy has spec attribute before accessing
* Address comments

Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent bd0120e · commit b2eb1b8

File tree

2 files changed: +156 −7 lines changed
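Context for the fix: the collector's output buffer used to start empty and acquire its keys lazily from the first frames collected. With `init_random_frames > 0`, those first frames bypass the policy, so policy-only outputs (e.g. recurrent hidden states) never made it into the buffer. A minimal sketch of that failure mode, using plain dicts as hypothetical stand-ins for TensorDicts:

# Hypothetical sketch of the pre-fix behaviour: keys are frozen from the
# first step seen, so policy-only keys vanish when collection starts randomly.
def collect(buffer, step):
    if not buffer:
        buffer.update({k: [] for k in step})  # lazy init from first step
    for k in buffer:
        buffer[k].append(step.get(k))

buffer = {}
random_step = {"observation": 0.0, "action": 1.0, "reward": 0.0, "done": False}
policy_step = {**random_step, "hidden1": 0.5}  # the policy adds hidden state

collect(buffer, random_step)  # init_random_frames: policy not queried yet
collect(buffer, policy_step)
assert "hidden1" not in buffer  # the missing-key bug this commit fixes

The commit closes this gap by pre-allocating the buffer with every expected key before collection starts, as the diffs below show.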

test/test_collector.py

Lines changed: 100 additions & 3 deletions
@@ -25,11 +25,17 @@
     MultiSyncDataCollector,
     MultiaSyncDataCollector,
 )
+from torchrl.data import (
+    CompositeSpec,
+    NdUnboundedContinuousTensorSpec,
+    UnboundedContinuousTensorSpec,
+)
 from torchrl.data.tensordict.tensordict import assert_allclose_td
 from torchrl.envs import EnvCreator
 from torchrl.envs import ParallelEnv
 from torchrl.envs.libs.gym import _has_gym
 from torchrl.envs.transforms import TransformedEnv, VecNorm
+from torchrl.modules import LSTMNet, TensorDictModule
 from torchrl.modules import OrnsteinUhlenbeckProcessWrapper, Actor

 # torch.set_default_dtype(torch.double)
@@ -673,14 +679,22 @@ def test_collector_vecnorm_envcreator(static_seed):


 @pytest.mark.parametrize("use_async", [False, True])
-@pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="no cuda device found")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
 def test_update_weights(use_async):
-    policy = torch.nn.Linear(3, 4).cuda(1)
+    def create_env():
+        return ContinuousActionVecMockEnv()
+
+    n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1]
+    policy = TensorDictModule(
+        torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"]
+    )
+    policy(create_env().reset())
+
     collector_class = (
         MultiSyncDataCollector if not use_async else MultiaSyncDataCollector
     )
     collector = collector_class(
-        [lambda: DiscreteActionVecMockEnv()] * 3,
+        [create_env] * 3,
         policy=policy,
         devices=[torch.device("cuda:0")] * 3,
         passing_devices=[torch.device("cuda:0")] * 3,
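A note on the rewritten test: `torch.nn.LazyLinear` leaves its parameters uninitialized until the first forward pass, which is why the test calls `policy(create_env().reset())` before handing the policy to the collector. A standalone sketch of that materialization step (the observation width 7 is an arbitrary assumption):

import torch

layer = torch.nn.LazyLinear(4)       # out_features known, in_features deferred
dummy_obs = torch.randn(1, 7)        # stand-in for env.reset()["observation"]
layer(dummy_obs)                     # first call materializes the parameters
assert layer.weight.shape == (4, 7)  # (out_features, in_features)

Without this warm-up call, the uninitialized module could not be serialized to the collector's worker processes.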
@@ -769,6 +783,89 @@ def make_env():
     dummy_env.close()


+@pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv")
+@pytest.mark.parametrize(
+    "collector_class",
+    [
+        SyncDataCollector,
+        MultiaSyncDataCollector,
+        MultiSyncDataCollector,
+    ],
+)
+@pytest.mark.parametrize("init_random_frames", [0, 50])
+@pytest.mark.parametrize("explicit_spec", [True, False])
+def test_collector_output_keys(collector_class, init_random_frames, explicit_spec):
+    from torchrl.envs.libs.gym import GymEnv
+
+    out_features = 1
+    hidden_size = 12
+    total_frames = 200
+    frames_per_batch = 20
+    num_envs = 3
+
+    net = LSTMNet(
+        out_features,
+        {"input_size": hidden_size, "hidden_size": hidden_size},
+        {"out_features": hidden_size},
+    )
+
+    policy_kwargs = {
+        "module": net,
+        "in_keys": ["observation", "hidden1", "hidden2"],
+        "out_keys": ["action", "hidden1", "hidden2", "next_hidden1", "next_hidden2"],
+    }
+    if explicit_spec:
+        hidden_spec = NdUnboundedContinuousTensorSpec((1, hidden_size))
+        policy_kwargs["spec"] = CompositeSpec(
+            action=UnboundedContinuousTensorSpec(),
+            hidden1=hidden_spec,
+            hidden2=hidden_spec,
+            next_hidden1=hidden_spec,
+            next_hidden2=hidden_spec,
+        )
+
+    policy = TensorDictModule(**policy_kwargs)
+
+    env_maker = lambda: GymEnv("Pendulum-v1")
+
+    policy(env_maker().reset())
+
+    collector_kwargs = {
+        "create_env_fn": env_maker,
+        "policy": policy,
+        "total_frames": total_frames,
+        "frames_per_batch": frames_per_batch,
+        "init_random_frames": init_random_frames,
+    }
+
+    if collector_class is not SyncDataCollector:
+        collector_kwargs["create_env_fn"] = [
+            collector_kwargs["create_env_fn"] for _ in range(num_envs)
+        ]
+
+    collector = collector_class(**collector_kwargs)
+
+    keys = [
+        "action",
+        "done",
+        "hidden1",
+        "hidden2",
+        "mask",
+        "next_hidden1",
+        "next_hidden2",
+        "next_observation",
+        "observation",
+        "reward",
+        "step_count",
+        "traj_ids",
+    ]
+    b = next(iter(collector))
+
+    assert set(b.keys()) == set(keys)
+    collector.shutdown()
+    del collector
+
+
 def weight_reset(m):
     if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
         m.reset_parameters()
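The `explicit_spec=True` branch exercises the new fast path in the collector: when every policy out_key has a non-None spec entry, the output buffer can be built from the spec's `zero(...)` method without a trial rollout. A standalone sketch of that spec construction, reusing the class names imported by the test (shapes follow the test; the exact `zero()` signature is assumed to match the usage shown in this commit's diff):

from torchrl.data import (
    CompositeSpec,
    NdUnboundedContinuousTensorSpec,
    UnboundedContinuousTensorSpec,
)

hidden_size = 12
hidden_spec = NdUnboundedContinuousTensorSpec((1, hidden_size))
spec = CompositeSpec(
    action=UnboundedContinuousTensorSpec(),
    hidden1=hidden_spec,
    hidden2=hidden_spec,
    next_hidden1=hidden_spec,
    next_hidden2=hidden_spec,
)
# zero() yields a TensorDict of zeros, one entry per key; the collector
# expands it along the time dimension to pre-allocate its output buffer.
zeros = spec.zero()

With `explicit_spec=False`, the policy carries no usable spec and the collector falls back to a short probing rollout, covered by the collectors.py change below.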

torchrl/collectors/collectors.py

Lines changed: 56 additions & 4 deletions
@@ -322,10 +322,62 @@ def __init__(
         self._tensordict.set(
             "step_count", torch.zeros(*self.env.batch_size, 1, dtype=torch.int)
         )
-        self._tensordict_out = TensorDict(
-            {},
-            batch_size=[*self.env.batch_size, self.frames_per_batch],
-            device=self.passing_device,
+
+        if (
+            hasattr(policy, "spec")
+            and policy.spec is not None
+            and all(v is not None for v in policy.spec.values())
+            and set(policy.spec.keys()) == set(policy.out_keys)
+        ):
+            # if policy spec is non-empty, all the values are not None and the keys
+            # match the out_keys we assume the user has given all relevant information
+            self._tensordict_out = TensorDict(
+                {
+                    **env.observation_spec.zero(env.batch_size),
+                    "reward": env.reward_spec.zero(env.batch_size),
+                    "done": torch.zeros(
+                        env.batch_size, dtype=torch.bool, device=env.device
+                    ),
+                    **policy.spec.zero(env.batch_size),
+                },
+                env.batch_size,
+                device=env.device,
+            )
+            self._tensordict_out = (
+                self._tensordict_out.unsqueeze(-1)
+                .expand(*env.batch_size, self.frames_per_batch)
+                .to_tensordict()
+            )
+            self._tensordict_out = self._tensordict_out.update(
+                step_mdp(self._tensordict_out)
+            )  # add "observation" when there is "next_observation"
+        else:
+            # otherwise, we perform a small number of steps with the policy to
+            # determine the relevant keys with which to pre-populate _tensordict_out.
+            # See #505 for additional context.
+            self._tensordict_out = env.rollout(3, policy)
+            if env.batch_size:
+                self._tensordict_out = self._tensordict_out[..., :1]
+            else:
+                self._tensordict_out = self._tensordict_out[:1]
+            self._tensordict_out = (
+                self._tensordict_out.expand(*env.batch_size, self.frames_per_batch)
+                .to_tensordict()
+                .zero_()
+                .detach()
+            )
+            env.reset()
+
+        # in addition to outputs of the policy, we add traj_ids and step_count to
+        # _tensordict_out which will be collected during rollout
+        if len(self.env.batch_size):
+            traj_ids = torch.zeros(*self._tensordict_out.batch_size, 1)
+        else:
+            traj_ids = torch.zeros(*self._tensordict_out.batch_size, 1, 1)
+
+        self._tensordict_out.set("traj_ids", traj_ids)
+        self._tensordict_out.set(
+            "step_count", torch.zeros(*self._tensordict_out.batch_size, 1)
         )

         self.return_in_place = return_in_place
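In summary, `__init__` now pre-populates `_tensordict_out` by one of two routes: build it directly from the specs when the policy declares a complete spec, or run a short throwaway rollout to discover the keys, then zero and expand the result. A condensed sketch of the decision, with plain dicts standing in for TensorDicts and hypothetical helper names:

import torch

def preallocate(policy_spec, discover_step, frames_per_batch):
    """Sketch of the two buffer-initialization routes in this commit.

    policy_spec: {key: per-step shape} or None (hypothetical stand-in).
    discover_step: callable returning {key: tensor} for one env/policy step.
    """
    if policy_spec is not None and all(v is not None for v in policy_spec.values()):
        # Route 1: specs are complete; allocate zeros per key directly.
        return {k: torch.zeros(frames_per_batch, *shape)
                for k, shape in policy_spec.items()}
    # Route 2: probe with a short rollout, then zero and expand what came back.
    step = discover_step()
    return {k: torch.zeros(frames_per_batch, *v.shape).detach()
            for k, v in step.items()}

Either way, `traj_ids` and `step_count` are added to the buffer afterwards, since they are produced by the collector itself rather than by the policy or the environment.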
