Skip to content

Commit f186e21

Browse files
author
Vincent Moens
committed
[BugFix] Test and fix life cycle of env with dynamic non-tensor spec
ghstack-source-id: 52c8624 Pull Request resolved: #2812
1 parent d4f8846 commit f186e21

File tree

10 files changed

+277
-28
lines changed

10 files changed

+277
-28
lines changed

test/mocking_classes.py

Lines changed: 69 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@
88
import string
99
from typing import Dict, List, Optional
1010

11+
import numpy as np
12+
1113
import torch
1214
import torch.nn as nn
1315
from tensordict import tensorclass, TensorDict, TensorDictBase
@@ -26,6 +28,7 @@
2628
Unbounded,
2729
)
2830
from torchrl.data.utils import consolidate_spec
31+
from torchrl.envs import Transform
2932
from torchrl.envs.common import EnvBase
3033
from torchrl.envs.model_based.common import ModelBasedEnvBase
3134
from torchrl.envs.utils import (
@@ -34,7 +37,6 @@
3437
MarlGroupMapType,
3538
)
3639

37-
3840
spec_dict = {
3941
"bounded": Bounded,
4042
"one_hot": OneHot,
@@ -2395,3 +2397,69 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
23952397
f1 + 1,
23962398
)
23972399
return td
2400+
2401+
2402+
@tensorclass
class History:
    """A single chat-style message made of two string (non-tensor) fields."""

    # Who produced the message; HistoryTransform uses "system", "user"
    # and "assistant".
    role: str
    # Message payload; HistoryTransform stores a stringified integer counter.
    content: str
2406+
2407+
2408+
class HistoryTransform(Transform):
    """A mocking class to record history.

    Maintains a ``"history"`` entry made of stacked :class:`History` items:
    two seed items are written on reset and one extra item is appended at
    every step, so the leading dimension of the entry grows over time. The
    assertions verify that specs and data live on the parent env's device.
    """

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        """Add the ``"history"`` spec to the observation spec.

        Args:
            observation_spec: the spec to update (modified in place).

        Returns:
            The same spec, with a ``"history"`` composite whose leading
            dimension is declared as ``-1`` and whose data class is
            :class:`History`.
        """
        defaults = {
            "role": NonTensor(
                example_data="a role!",
                shape=(-1,),
            ),
            "content": NonTensor(
                example_data="a content!",
                shape=(-1,),
            ),
        }
        observation_spec["history"] = Composite(
            defaults,
            shape=(-1,),
            data_cls=History,
        )
        assert observation_spec.device == self.parent.device
        assert observation_spec["history"].device == self.parent.device
        return observation_spec

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Seed the history with a "system" and a "user" item.

        Args:
            tensordict: the input of the reset call (unused here).
            tensordict_reset: the reset output, updated in place.

        Returns:
            ``tensordict_reset`` with its ``"history"`` entry set.
        """
        assert tensordict_reset.device == self.parent.device
        tensordict_reset["history"] = torch.stack(
            [
                History(role="system", content="0"),
                History(role="user", content="1"),
            ]
        )
        assert tensordict_reset["history"].device == self.parent.device
        return tensordict_reset

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Append one item to the history and write it to the next step.

        The new item has a randomly drawn role and a content equal to the
        last content of the current history, read as an integer, plus one.

        Args:
            tensordict: the current step data; its ``"history"`` is read.
            next_tensordict: the next step data, updated in place.

        Returns:
            ``next_tensordict`` with the extended ``"history"`` entry.
        """
        assert next_tensordict.device == self.parent.device
        history = tensordict["history"]
        local_history = History(
            role=np.random.choice(["user", "system", "assistant"]),
            content=str(int(history.content[-1]) + 1),
            device=history.device,
        )
        # Re-stack the existing items with the new one: the result has one
        # more element along dim 0 than the incoming history.
        history = torch.stack([*history.unbind(0), local_history])
        assert isinstance(history, History)
        next_tensordict["history"] = history
        assert next_tensordict["history"].device == self.parent.device, (
            next_tensordict["history"],
            self.parent.device,
        )
        return next_tensordict

test/test_env.py

Lines changed: 122 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@
4242
CatFrames,
4343
CatTensors,
4444
ChessEnv,
45+
ConditionalSkip,
4546
DoubleToFloat,
4647
EnvBase,
4748
EnvCreator,
@@ -72,6 +73,7 @@
7273
check_marl_grouping,
7374
make_composite_from_td,
7475
MarlGroupMapType,
76+
RandomPolicy,
7577
step_mdp,
7678
)
7779
from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator
@@ -134,6 +136,7 @@
134136
EnvWithTensorClass,
135137
HeterogeneousCountingEnv,
136138
HeterogeneousCountingEnvPolicy,
139+
HistoryTransform,
137140
MockBatchedLockedEnv,
138141
MockBatchedUnLockedEnv,
139142
MockSerialEnv,
@@ -174,6 +177,7 @@
174177
EnvWithTensorClass,
175178
HeterogeneousCountingEnv,
176179
HeterogeneousCountingEnvPolicy,
180+
HistoryTransform,
177181
MockBatchedLockedEnv,
178182
MockBatchedUnLockedEnv,
179183
MockSerialEnv,
@@ -4398,6 +4402,124 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
43984402
assert (td[3].get("next") != 0).any()
43994403

44004404

4405+
class TestEnvWithHistory:
    """Life-cycle tests for envs carrying a dynamic non-tensor ``"history"``.

    The envs are built from ``CountingEnv`` with a ``HistoryTransform``
    appended, optionally followed by a ``ConditionalSkip`` transform, and are
    exercised on their own, in batched envs and through a data collector.
    """

    @pytest.fixture(autouse=True, scope="class")
    def set_capture(self):
        """Run every test of the class with both context flags set to False."""
        with set_capture_non_tensor_stack(False), set_auto_unwrap_transformed_env(
            False
        ):
            yield

    def _make_env(self, device, max_steps=10):
        """Return a ``CountingEnv`` with a ``HistoryTransform`` appended."""
        return CountingEnv(device=device, max_steps=max_steps).append_transform(
            HistoryTransform()
        )

    def _make_skipping_env(self, device, max_steps=10):
        """Return the history env with a conditional skip and a step counter."""
        env = self._make_env(device=device, max_steps=max_steps)
        # skip every 3 steps
        env = env.append_transform(
            ConditionalSkip(lambda td: ((td["step_count"] % 3) == 2))
        )
        env = TransformedEnv(env, StepCounter())
        return env

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_env_history_base(self, device):
        """The plain history env passes the spec check."""
        env = self._make_env(device)
        env.check_env_specs()

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_skipping_history_env(self, device):
        """The skipping env passes the spec check and can be rolled out."""
        env = self._make_skipping_env(device)
        env.check_env_specs()
        # Only checks that a rollout runs; the result is not inspected.
        env.rollout(100)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_env_history_base_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        """Two history envs in a batched env run without buffers."""
        if batch_cls == "parallel":
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            assert not env._use_buffers
            env.check_env_specs(break_when_any_done="both")
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_skipping_history_env_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        """Two skipping envs in a batched env pass the spec check."""
        if batch_cls == "parallel":
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_skipping_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_env_history_base_collector(self, device_env, collector_cls):
        """Collected data links each step's history to the previous "next"."""
        env = self._make_env(device_env)
        collector = collector_cls(
            env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5
        )
        try:
            for d in collector:
                for i in range(d.shape[0] - 1):
                    # NOTE(review): only the first content item is compared.
                    assert (
                        d[i + 1]["history"].content[0]
                        == d[i]["next", "history"].content[0]
                    )
        finally:
            # Release the collector (and its env) even if an assertion fails.
            collector.shutdown()

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_skipping_history_env_collector(self, device_env, collector_cls):
        """With skipping, every third step leaves the history unchanged."""
        frames_per_batch = 5
        env = self._make_skipping_env(device_env, max_steps=10)
        collector = collector_cls(
            env,
            lambda td: td.update(env.full_action_spec.one()),
            total_frames=35,
            frames_per_batch=frames_per_batch,
        )
        count = 1
        try:
            for d in collector:
                for k in range(1, frames_per_batch):
                    if len(d[k]["history"].content) == 2:
                        # A history of length 2 is the one written on reset:
                        # restart the counter.
                        count = 1
                        continue
                    if count % 3 == 2:
                        # Skipped step: "next" history equals the previous one.
                        assert (
                            d[k]["next", "history"].content
                            == d[k - 1]["next", "history"].content
                        ), (d["next", "history"].content, k, count)
                    else:
                        # Regular step: last content is the previous one + 1.
                        assert d[k]["next", "history"].content[-1] == str(
                            int(d[k - 1]["next", "history"].content[-1]) + 1
                        ), (d["next", "history"].content, k, count)
                    count += 1
                # Presumably accounts for index 0 of the batch, which the
                # loop above does not visit.
                count += 1
        finally:
            # Release the collector (and its env) even if an assertion fails.
            collector.shutdown()
4521+
4522+
44014523
if __name__ == "__main__":
44024524
args, unknown = argparse.ArgumentParser().parse_known_args()
44034525
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13496,7 +13496,7 @@ def check_non_tensor_match(self, td):
1349613496

1349713497
class ToString(Transform):
1349813498
def _apply_transform(self, obs: torch.Tensor) -> None:
13499-
return NonTensorData(str(obs), device=obs.device)
13499+
return NonTensorData(str(obs), device=self.parent.device)
1350013500

1350113501
def _reset(
1350213502
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase

torchrl/_utils.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -162,14 +162,20 @@ def erase():
162162
def _check_for_faulty_process(processes):
163163
terminate = False
164164
for p in processes:
165-
if not p.is_alive():
165+
if not p._closed and not p.is_alive():
166166
terminate = True
167167
for _p in processes:
168-
if _p.is_alive():
169-
_p.terminate()
170-
_p.close()
171-
if terminate:
172-
break
168+
_p: mp.Process
169+
if not _p._closed and _p.is_alive():
170+
try:
171+
_p.terminate()
172+
except Exception:
173+
_p.kill()
174+
finally:
175+
time.sleep(0.1)
176+
_p.close()
177+
if terminate:
178+
break
173179
if terminate:
174180
raise RuntimeError(
175181
"At least one process failed. Check for more infos in the log."

torchrl/collectors/collectors.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1057,7 +1057,7 @@ def cuda_check(tensor: torch.Tensor):
10571057
# This may be a bit dangerous as `torch.device("cuda")` may not have a precise
10581058
# device associated, whereas `tensor.device` always has
10591059
for spec in self.env.specs.values(True, True):
1060-
if spec.device.type == "cuda":
1060+
if spec.device is not None and spec.device.type == "cuda":
10611061
if ":" not in str(spec.device):
10621062
raise RuntimeError(
10631063
"A cuda spec did not have a device associated. Make sure to "

torchrl/data/tensor_specs.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2525,7 +2525,6 @@ def __init__(
25252525
if isinstance(shape, int):
25262526
shape = _size([shape])
25272527

2528-
_, device = _default_dtype_and_device(None, device)
25292528
domain = None
25302529
super().__init__(
25312530
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs

0 commit comments

Comments
 (0)