Skip to content

Commit 3b6b559

Browse files
author
Vincent Moens
committed
[Test] Add PEnv tests for devices
ghstack-source-id: 3653d73 Pull Request resolved: #2843
1 parent 1813e8e commit 3b6b559

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

test/_utils_internal.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,21 +278,23 @@ def _make_envs(
278278
transformed_in,
279279
transformed_out,
280280
N,
281-
device="cpu",
281+
p_env_device=None,
282+
env_device=None,
283+
# device="cpu",
282284
kwargs=None,
283285
local_mp_ctx=mp_ctx,
284286
):
285287
torch.manual_seed(0)
286288
if not transformed_in:
287289

288290
def create_env_fn():
289-
return GymEnv(env_name, frame_skip=frame_skip, device=device)
291+
return GymEnv(env_name, frame_skip=frame_skip, device=env_device)
290292

291293
else:
292294
if env_name == PONG_VERSIONED():
293295

294296
def create_env_fn():
295-
base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
297+
base_env = GymEnv(env_name, frame_skip=frame_skip, device=env_device)
296298
in_keys = list(base_env.observation_spec.keys(True, True))[:1]
297299
return TransformedEnv(
298300
base_env,
@@ -303,7 +305,7 @@ def create_env_fn():
303305

304306
def create_env_fn():
305307

306-
base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
308+
base_env = GymEnv(env_name, frame_skip=frame_skip, device=env_device)
307309
in_keys = list(base_env.observation_spec.keys(True, True))[:1]
308310

309311
return TransformedEnv(
@@ -316,9 +318,15 @@ def create_env_fn():
316318

317319
env0 = create_env_fn()
318320
env_parallel = ParallelEnv(
319-
N, create_env_fn, create_env_kwargs=kwargs, mp_start_method=local_mp_ctx
321+
N,
322+
create_env_fn,
323+
create_env_kwargs=kwargs,
324+
mp_start_method=local_mp_ctx,
325+
device=p_env_device,
326+
)
327+
env_serial = SerialEnv(
328+
N, create_env_fn, create_env_kwargs=kwargs, device=p_env_device
320329
)
321-
env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs)
322330

323331
for key in env0.observation_spec.keys(True, True):
324332
obs_key = key

test/test_env.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,12 +1471,29 @@ def make_env():
14711471
"transformed_in,transformed_out", [[True, True], [False, False]]
14721472
) # 1226: effociency
14731473
@pytest.mark.parametrize("static_seed", [False, True])
1474+
@pytest.mark.parametrize("penv_device", ["cpu", None])
1475+
@pytest.mark.parametrize("env_device", ["cpu", None])
1476+
@pytest.mark.parametrize("bwad", [True, False])
14741477
def test_parallel_env_seed(
1475-
self, env_name, frame_skip, transformed_in, transformed_out, static_seed
1478+
self,
1479+
env_name,
1480+
frame_skip,
1481+
transformed_in,
1482+
transformed_out,
1483+
static_seed,
1484+
penv_device,
1485+
env_device,
1486+
bwad,
14761487
):
14771488
env_name = env_name()
14781489
env_parallel, env_serial, _, _ = _make_envs(
1479-
env_name, frame_skip, transformed_in, transformed_out, 5
1490+
env_name,
1491+
frame_skip,
1492+
transformed_in,
1493+
transformed_out,
1494+
5,
1495+
p_env_device=penv_device,
1496+
env_device=env_device,
14801497
)
14811498
try:
14821499
out_seed_serial = env_serial.set_seed(0, static_seed=static_seed)
@@ -1486,7 +1503,10 @@ def test_parallel_env_seed(
14861503
torch.manual_seed(0)
14871504

14881505
td_serial = env_serial.rollout(
1489-
max_steps=10, auto_reset=False, tensordict=td0_serial
1506+
max_steps=10,
1507+
auto_reset=False,
1508+
tensordict=td0_serial,
1509+
break_when_any_done=bwad,
14901510
).contiguous()
14911511
key = "pixels" if "pixels" in td_serial.keys() else "observation"
14921512
torch.testing.assert_close(
@@ -1501,7 +1521,10 @@ def test_parallel_env_seed(
15011521
torch.manual_seed(0)
15021522
assert out_seed_parallel == out_seed_serial
15031523
td_parallel = env_parallel.rollout(
1504-
max_steps=10, auto_reset=False, tensordict=td0_parallel
1524+
max_steps=10,
1525+
auto_reset=False,
1526+
tensordict=td0_parallel,
1527+
break_when_any_done=bwad,
15051528
).contiguous()
15061529
torch.testing.assert_close(
15071530
td_parallel[:, :-1].get(("next", key)), td_parallel[:, 1:].get(key)
@@ -1677,7 +1700,7 @@ def test_parallel_env_device(
16771700
frame_skip,
16781701
transformed_in=transformed_in,
16791702
transformed_out=transformed_out,
1680-
device=device,
1703+
env_device=device,
16811704
N=N,
16821705
local_mp_ctx="spawn",
16831706
)

0 commit comments

Comments
 (0)