Skip to content

Commit 3e6cb84

Browse files
author
Vincent Moens
authored
[Quality] Remove global seeding in set_seed (#2195)
1 parent 2370d6e commit 3e6cb84

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

test/test_transforms.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9225,6 +9225,7 @@ def test_transform_compose(self):
92259225
)
92269226
def test_transform_env(self, out_key):
92279227
base_env = self.envclass()
9228+
torch.manual_seed(0)
92289229
actor = self._make_actor()
92299230
# we need to patch the env and create a sample_log_prob spec to make check_env_specs happy
92309231
env = TransformedEnv(
@@ -9234,6 +9235,7 @@ def test_transform_env(self, out_key):
92349235
KLRewardTransform(actor, out_keys=out_key),
92359236
),
92369237
)
9238+
torch.manual_seed(0)
92379239
actor = self._make_actor()
92389240
td1 = env.rollout(3, actor)
92399241
tdparams = TensorDict(dict(actor.named_parameters()), []).unflatten_keys(".")

torchrl/envs/batched_envs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,8 +1877,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
18771877
if cmd == "seed":
18781878
if not initialized:
18791879
raise RuntimeError("call 'init' before closing")
1880-
# torch.manual_seed(data)
1881-
# np.random.seed(data)
1880+
torch.manual_seed(data[0])
18821881
new_seed = env.set_seed(data[0], static_seed=data[1])
18831882
child_pipe.send(("seeded", new_seed))
18841883

torchrl/envs/common.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,17 +2218,16 @@ def set_seed(
22182218
"""Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).
22192219
22202220
Args:
2221-
seed (int): seed to be set
2221+
seed (int): seed to be set. The seed is set only locally in the environment. To handle the global seed,
2222+
see :func:`~torch.manual_seed`.
22222223
static_seed (bool, optional): if ``True``, the seed is not incremented.
22232224
Defaults to False
22242225
22252226
Returns:
22262227
integer representing the "next seed": i.e. the seed that should be
2227-
used for another environment if created concomittently to this environment.
2228+
used for another environment if created concomitantly to this environment.
22282229
22292230
"""
2230-
if seed is not None:
2231-
torch.manual_seed(seed)
22322231
self._set_seed(seed)
22332232
if seed is not None and not static_seed:
22342233
new_seed = seed_generator(seed)

0 commit comments

Comments
 (0)