Skip to content

Commit 77beeb9

Browse files
authored
[BugFix] Fix SyncDataCollector reset (#571)
1 parent 872dd2d commit 77beeb9

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

test/test_collector.py

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,15 +25,16 @@
2525
MultiSyncDataCollector,
2626
RandomPolicy,
2727
)
28+
from torchrl.collectors.utils import split_trajectories
2829
from torchrl.data import (
2930
CompositeSpec,
3031
NdUnboundedContinuousTensorSpec,
3132
TensorDict,
3233
UnboundedContinuousTensorSpec,
3334
)
3435
from torchrl.data.tensordict.tensordict import assert_allclose_td
35-
from torchrl.envs import EnvCreator, ParallelEnv
36-
from torchrl.envs.libs.gym import _has_gym
36+
from torchrl.envs import EnvCreator, ParallelEnv, SerialEnv
37+
from torchrl.envs.libs.gym import _has_gym, GymEnv
3738
from torchrl.envs.transforms import TransformedEnv, VecNorm
3839
from torchrl.modules import (
3940
Actor,
@@ -295,6 +296,32 @@ def env_fn(seed):
295296
ccollector.shutdown()
296297

297298

299+
@pytest.mark.skipif(not _has_gym, reason="gym library is not installed")
def test_collector_env_reset():
    """Check that SyncDataCollector resets sub-environments that report done.

    Collects frames from two serial ``ALE/Pong-v5`` environments and verifies
    that ``step_count`` restarts at 1 exactly on the step that follows a
    ``done`` flag, and nowhere else.
    """
    torch.manual_seed(0)
    env = SerialEnv(2, lambda: GymEnv("ALE/Pong-v5", frame_skip=4))
    env.set_seed(0)
    collector = SyncDataCollector(
        env, total_frames=10000, frames_per_batch=10000, split_trajs=False
    )
    # Exhaust the collector and keep the last batch it yields
    # (total_frames == frames_per_batch, so presumably a single batch).
    for data in collector:
        continue
    # Shift the two entries by one position along the second-to-last dim so
    # that steps[..., i, :] is the step count that follows done[..., i, :].
    steps = data["step_count"][..., 1:, :]
    done = data["done"][..., :-1, :]
    # we don't want just one done
    assert done.sum() > 3
    # check that after a done, the next step count is always 1
    assert (steps[done] == 1).all()
    # check that if the env is not done, the next step count is > 1
    assert (steps[~done] > 1).all()
    # check that if step is 1, then the env was done before
    # (mask `done` by `steps == 1`; masking the other way round would only
    # repeat the first step-count assertion)
    assert done[steps == 1].all()
    # check that split traj has a minimum total reward of -21 (for pong only)
    data = split_trajectories(data)
    assert data["reward"].sum(-2).min() == -21
323+
324+
298325
@pytest.mark.parametrize("num_env", [1, 3])
299326
@pytest.mark.parametrize("env_name", ["vec"])
300327
def test_collector_done_persist(num_env, env_name, seed=5):

torchrl/collectors/collectors.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -585,7 +585,7 @@ def _reset_if_necessary(self) -> None:
585585
else:
586586
self._tensordict.zero_()
587587

588-
self._tensordict.update(self.env.reset(), inplace=True)
588+
self.env.reset(self._tensordict)
589589
if self._tensordict.get("done").any():
590590
raise RuntimeError(
591591
f"Got {sum(self._tensordict.get('done'))} done envs after reset."

0 commit comments

Comments
 (0)