
Commit 6f6c896

Author: Vincent Moens

[BugFix] Adaptable non-blocking for mps and non cuda device in batched-envs (#1900)

1 parent: 1647fa4

File tree: 2 files changed, +44 -28 lines

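The substance of the fix, in brief: device transfers in batched envs were hard-coded to non_blocking=True, presumably because non-blocking copies are only well-behaved on CUDA streams; elsewhere (notably mps) they can be unsafe or pointless. The diff below adds a non_blocking argument to BatchedEnvBase and resolves it lazily from the env's device. A minimal sketch of that resolution rule as we read it from the new property (the helper name resolve_non_blocking is ours, for illustration only):

import torch

def resolve_non_blocking(non_blocking, device):
    # Mirrors the new BatchedEnvBase.non_blocking property: an explicit
    # user setting wins; when unset (None), non-blocking copies are
    # enabled for CUDA devices only.
    if non_blocking is not None:
        return non_blocking
    return device is not None and device.type == "cuda"

assert resolve_non_blocking(None, torch.device("cuda")) is True
assert resolve_non_blocking(None, torch.device("mps")) is False
assert resolve_non_blocking(None, torch.device("cpu")) is False
assert resolve_non_blocking(True, torch.device("cpu")) is True  # explicit override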

torchrl/collectors/collectors.py

Lines changed: 4 additions & 4 deletions

@@ -577,7 +577,7 @@ def __init__(
         reset_when_done: bool = True,
         interruptor=None,
     ):
-        from torchrl.envs.batched_envs import _BatchedEnv
+        from torchrl.envs.batched_envs import BatchedEnvBase
 
         self.closed = True
 
@@ -591,7 +591,7 @@ def __init__(
         else:
             env = create_env_fn
             if create_env_kwargs:
-                if not isinstance(env, _BatchedEnv):
+                if not isinstance(env, BatchedEnvBase):
                     raise RuntimeError(
                         "kwargs were passed to SyncDataCollector but they can't be set "
                         f"on environment of type {type(create_env_fn)}."
@@ -1201,11 +1201,11 @@ def state_dict(self) -> OrderedDict:
            `"env_state_dict"`.

        """
-        from torchrl.envs.batched_envs import _BatchedEnv
+        from torchrl.envs.batched_envs import BatchedEnvBase
 
        if isinstance(self.env, TransformedEnv):
            env_state_dict = self.env.transform.state_dict()
-        elif isinstance(self.env, _BatchedEnv):
+        elif isinstance(self.env, BatchedEnvBase):
            env_state_dict = self.env.state_dict()
        else:
            env_state_dict = OrderedDict()

torchrl/envs/batched_envs.py

Lines changed: 40 additions & 24 deletions

@@ -48,7 +48,7 @@
 
 
 def _check_start(fun):
-    def decorated_fun(self: _BatchedEnv, *args, **kwargs):
+    def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
         if self.is_closed:
             self._create_td()
             self._start_workers()
@@ -121,7 +121,7 @@ def __call__(cls, *args, **kwargs):
         return super().__call__(*args, **kwargs)
 
 
-class _BatchedEnv(EnvBase):
+class BatchedEnvBase(EnvBase):
     """Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.
 
     Those queries will return a list of length equal to the number of workers containing the
@@ -169,6 +169,9 @@ class _BatchedEnv(EnvBase):
         serial_for_single (bool, optional): if ``True``, creating a parallel environment
             with a single worker will return a :class:`~SerialEnv` instead.
             This option has no effect with :class:`~SerialEnv`. Defaults to ``False``.
+        non_blocking (bool, optional): if ``True``, device moves will be done using the
+            ``non_blocking=True`` option. Defaults to ``True`` for batched environments
+            on cuda devices, and ``False`` otherwise.
 
     Examples:
         >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
@@ -179,8 +182,8 @@ class _BatchedEnv(EnvBase):
         >>> env = ParallelEnv(2, [
         ...     lambda: DMControlEnv("humanoid", "stand"),
         ...     lambda: DMControlEnv("humanoid", "walk")])  # Creates two independent copies of Humanoid, one that walks one that stands
-        >>> r = env.rollout(10)  # executes 10 random steps in the environment
-        >>> r[0]  # data for Humanoid stand
+        >>> rollout = env.rollout(10)  # executes 10 random steps in the environment
+        >>> rollout[0]  # data for Humanoid stand
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
@@ -211,7 +214,7 @@ class _BatchedEnv(EnvBase):
             batch_size=torch.Size([10]),
             device=cpu,
             is_shared=False)
-        >>> r[1]  # data for Humanoid walk
+        >>> rollout[1]  # data for Humanoid walk
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
@@ -242,6 +245,7 @@ class _BatchedEnv(EnvBase):
             batch_size=torch.Size([10]),
             device=cpu,
             is_shared=False)
+        >>> # serial_for_single to avoid creating parallel envs if not necessary
         >>> env = ParallelEnv(1, make_env, serial_for_single=True)
         >>> assert isinstance(env, SerialEnv)  # serial_for_single allows you to avoid creating parallel envs when not necessary
     """
@@ -270,6 +274,7 @@ def __init__(
         num_threads: int = None,
         num_sub_threads: int = 1,
         serial_for_single: bool = False,
+        non_blocking: bool = False,
     ):
         super().__init__(device=device)
         self.serial_for_single = serial_for_single
@@ -327,6 +332,15 @@ def __init__(
         # self._prepare_dummy_env(create_env_fn, create_env_kwargs)
         self._properties_set = False
         self._get_metadata(create_env_fn, create_env_kwargs)
+        self._non_blocking = non_blocking
+
+    @property
+    def non_blocking(self):
+        nb = self._non_blocking
+        if nb is None:
+            nb = self.device is not None and self.device.type == "cuda"
+            self._non_blocking = nb
+        return nb
 
     def _get_metadata(
         self, create_env_fn: List[Callable], create_env_kwargs: List[Dict]
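Since SerialEnv and ParallelEnv inherit this constructor, the keyword can be passed straight through. A hypothetical usage sketch (the Pendulum env is just a stand-in, not part of the diff):

from torchrl.envs import GymEnv, SerialEnv

# Force fully blocking copies regardless of device, e.g. when chasing
# a suspected device-transfer race:
env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"), non_blocking=False)
rollout = env.rollout(3)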
@@ -654,6 +668,7 @@ def start(self) -> None:
             self._start_workers()
 
     def to(self, device: DEVICE_TYPING):
+        self._non_blocking = None
         device = torch.device(device)
         if device == self.device:
             return self
@@ -675,10 +690,10 @@ def to(self, device: DEVICE_TYPING):
         return self
 
 
-class SerialEnv(_BatchedEnv):
+class SerialEnv(BatchedEnvBase):
     """Creates a series of environments in the same process."""
 
-    __doc__ += _BatchedEnv.__doc__
+    __doc__ += BatchedEnvBase.__doc__
 
     _share_memory = False
 
@@ -769,7 +784,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
             else:
                 env_device = _env.device
                 if env_device != self.device and env_device is not None:
-                    tensordict_ = tensordict_.to(env_device, non_blocking=True)
+                    tensordict_ = tensordict_.to(
+                        env_device, non_blocking=self.non_blocking
+                    )
                 else:
                     tensordict_ = tensordict_.clone(False)
         else:
@@ -798,7 +815,7 @@ def select_and_clone(name, tensor):
            if device is None:
                out = out.clear_device_()
            else:
-                out = out.to(device, non_blocking=True)
+                out = out.to(device, non_blocking=self.non_blocking)
        return out
 
     def _reset_proc_data(self, tensordict, tensordict_reset):
@@ -819,7 +836,9 @@ def _step(
             # There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
             env_device = self._envs[i].device
             if env_device != self.device and env_device is not None:
-                data_in = tensordict_in[i].to(env_device, non_blocking=True)
+                data_in = tensordict_in[i].to(
+                    env_device, non_blocking=self.non_blocking
+                )
             else:
                 data_in = tensordict_in[i]
             out_td = self._envs[i]._step(data_in)
@@ -839,7 +858,7 @@ def select_and_clone(name, tensor):
            if device is None:
                out = out.clear_device_()
            elif out.device != device:
-                out = out.to(device, non_blocking=True)
+                out = out.to(device, non_blocking=self.non_blocking)
        return out
 
     def __getattr__(self, attr: str) -> Any:
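These hunks, and the remaining ones below, swap each hard-coded non_blocking=True for the resolved self.non_blocking. The hazard this default avoids off-CUDA: in PyTorch, a non-blocking device-to-CPU copy may return before the data has actually landed, so reading the result without synchronizing is a race. A toy illustration (assumes a CUDA build; not part of the diff):

import torch

if torch.cuda.is_available():
    gpu = torch.randn(1024, device="cuda")
    cpu = gpu.to("cpu", non_blocking=True)  # copy may still be in flight
    torch.cuda.synchronize()                # required before trusting `cpu`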
@@ -885,14 +904,14 @@ def to(self, device: DEVICE_TYPING):
         return self
 
 
-class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta):
+class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
     """Creates one environment per process.
 
     TensorDicts are passed via shared memory or memory map.
 
     """
 
-    __doc__ += _BatchedEnv.__doc__
+    __doc__ += BatchedEnvBase.__doc__
     __doc__ += """
 
 .. warning::
@@ -1167,14 +1186,14 @@ def step_and_maybe_reset(
             tensordict_ = tensordict_.clone()
         elif device is not None:
             next_td = next_td._fast_apply(
-                lambda x: x.to(device, non_blocking=True)
+                lambda x: x.to(device, non_blocking=self.non_blocking)
                 if x.device != device
                 else x.clone(),
                 device=device,
                 filter_empty=True,
             )
             tensordict_ = tensordict_._fast_apply(
-                lambda x: x.to(device, non_blocking=True)
+                lambda x: x.to(device, non_blocking=self.non_blocking)
                 if x.device != device
                 else x.clone(),
                 device=device,
@@ -1239,7 +1258,7 @@ def select_and_clone(name, tensor):
            if device is None:
                out.clear_device_()
            else:
-                out = out.to(device, non_blocking=True)
+                out = out.to(device, non_blocking=self.non_blocking)
        return out
 
     @_check_start
@@ -1325,7 +1344,7 @@ def select_and_clone(name, tensor):
            if device is None:
                out.clear_device_()
            else:
-                out = out.to(device, non_blocking=True)
+                out = out.to(device, non_blocking=self.non_blocking)
        return out
 
     @_check_start
@@ -1644,12 +1663,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
         child_pipe.send(("_".join([cmd, "done"]), None))
 
 
-def _update_cuda(t_dest, t_source):
-    if t_source is None:
-        return
-    t_dest.copy_(t_source.pin_memory(), non_blocking=True)
-    return
-
-
 def _filter_empty(tensordict):
     return tensordict.select(*tensordict.keys(True, True))
+
+
+# Create an alias for possible imports
+_BatchedEnv = BatchedEnvBase
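The alias keeps old imports working, including downstream code that reached for the private class; a quick sanity check (sketch):

from torchrl.envs.batched_envs import BatchedEnvBase, _BatchedEnv

assert _BatchedEnv is BatchedEnvBase  # backward-compatible alias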
