48
48
49
49
50
50
def _check_start (fun ):
51
- def decorated_fun (self : _BatchedEnv , * args , ** kwargs ):
51
+ def decorated_fun (self : BatchedEnvBase , * args , ** kwargs ):
52
52
if self .is_closed :
53
53
self ._create_td ()
54
54
self ._start_workers ()
@@ -121,7 +121,7 @@ def __call__(cls, *args, **kwargs):
121
121
return super ().__call__ (* args , ** kwargs )
122
122
123
123
124
- class _BatchedEnv (EnvBase ):
124
+ class BatchedEnvBase (EnvBase ):
125
125
"""Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.
126
126
127
127
Those queries will return a list of length equal to the number of workers containing the
@@ -169,6 +169,9 @@ class _BatchedEnv(EnvBase):
169
169
serial_for_single (bool, optional): if ``True``, creating a parallel environment
170
170
with a single worker will return a :class:`~SerialEnv` instead.
171
171
This option has no effect with :class:`~SerialEnv`. Defaults to ``False``.
172
+ non_blocking (bool, optional): if ``True``, device moves will be done using the
173
+ ``non_blocking=True`` option. Defaults to ``True`` for batched environments
174
+ on cuda devices, and ``False`` otherwise.
172
175
173
176
Examples:
174
177
>>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator
@@ -179,8 +182,8 @@ class _BatchedEnv(EnvBase):
179
182
>>> env = ParallelEnv(2, [
180
183
... lambda: DMControlEnv("humanoid", "stand"),
181
184
... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands
182
- >>> r = env.rollout(10) # executes 10 random steps in the environment
183
- >>> r [0] # data for Humanoid stand
185
+ >>> rollout = env.rollout(10) # executes 10 random steps in the environment
186
+ >>> rollout [0] # data for Humanoid stand
184
187
TensorDict(
185
188
fields={
186
189
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
@@ -211,7 +214,7 @@ class _BatchedEnv(EnvBase):
211
214
batch_size=torch.Size([10]),
212
215
device=cpu,
213
216
is_shared=False)
214
- >>> r [1] # data for Humanoid walk
217
+ >>> rollout [1] # data for Humanoid walk
215
218
TensorDict(
216
219
fields={
217
220
action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False),
@@ -242,6 +245,7 @@ class _BatchedEnv(EnvBase):
242
245
batch_size=torch.Size([10]),
243
246
device=cpu,
244
247
is_shared=False)
248
+ >>> # serial_for_single to avoid creating parallel envs if not necessary
245
249
>>> env = ParallelEnv(1, make_env, serial_for_single=True)
246
250
>>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary
247
251
"""
@@ -270,6 +274,7 @@ def __init__(
270
274
num_threads : int = None ,
271
275
num_sub_threads : int = 1 ,
272
276
serial_for_single : bool = False ,
277
+ non_blocking : bool = False ,
273
278
):
274
279
super ().__init__ (device = device )
275
280
self .serial_for_single = serial_for_single
@@ -327,6 +332,15 @@ def __init__(
327
332
# self._prepare_dummy_env(create_env_fn, create_env_kwargs)
328
333
self ._properties_set = False
329
334
self ._get_metadata (create_env_fn , create_env_kwargs )
335
+ self ._non_blocking = non_blocking
336
+
337
+ @property
338
+ def non_blocking (self ):
339
+ nb = self ._non_blocking
340
+ if nb is None :
341
+ nb = self .device is not None and self .device .type == "cuda"
342
+ self ._non_blocking = nb
343
+ return nb
330
344
331
345
def _get_metadata (
332
346
self , create_env_fn : List [Callable ], create_env_kwargs : List [Dict ]
@@ -654,6 +668,7 @@ def start(self) -> None:
654
668
self ._start_workers ()
655
669
656
670
def to (self , device : DEVICE_TYPING ):
671
+ self ._non_blocking = None
657
672
device = torch .device (device )
658
673
if device == self .device :
659
674
return self
@@ -675,10 +690,10 @@ def to(self, device: DEVICE_TYPING):
675
690
return self
676
691
677
692
678
- class SerialEnv (_BatchedEnv ):
693
+ class SerialEnv (BatchedEnvBase ):
679
694
"""Creates a series of environments in the same process."""
680
695
681
- __doc__ += _BatchedEnv .__doc__
696
+ __doc__ += BatchedEnvBase .__doc__
682
697
683
698
_share_memory = False
684
699
@@ -769,7 +784,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
769
784
else :
770
785
env_device = _env .device
771
786
if env_device != self .device and env_device is not None :
772
- tensordict_ = tensordict_ .to (env_device , non_blocking = True )
787
+ tensordict_ = tensordict_ .to (
788
+ env_device , non_blocking = self .non_blocking
789
+ )
773
790
else :
774
791
tensordict_ = tensordict_ .clone (False )
775
792
else :
@@ -798,7 +815,7 @@ def select_and_clone(name, tensor):
798
815
if device is None :
799
816
out = out .clear_device_ ()
800
817
else :
801
- out = out .to (device , non_blocking = True )
818
+ out = out .to (device , non_blocking = self . non_blocking )
802
819
return out
803
820
804
821
def _reset_proc_data (self , tensordict , tensordict_reset ):
@@ -819,7 +836,9 @@ def _step(
819
836
# There may be unexpected keys, such as "_reset", that we should comfortably ignore here.
820
837
env_device = self ._envs [i ].device
821
838
if env_device != self .device and env_device is not None :
822
- data_in = tensordict_in [i ].to (env_device , non_blocking = True )
839
+ data_in = tensordict_in [i ].to (
840
+ env_device , non_blocking = self .non_blocking
841
+ )
823
842
else :
824
843
data_in = tensordict_in [i ]
825
844
out_td = self ._envs [i ]._step (data_in )
@@ -839,7 +858,7 @@ def select_and_clone(name, tensor):
839
858
if device is None :
840
859
out = out .clear_device_ ()
841
860
elif out .device != device :
842
- out = out .to (device , non_blocking = True )
861
+ out = out .to (device , non_blocking = self . non_blocking )
843
862
return out
844
863
845
864
def __getattr__ (self , attr : str ) -> Any :
@@ -885,14 +904,14 @@ def to(self, device: DEVICE_TYPING):
885
904
return self
886
905
887
906
888
- class ParallelEnv (_BatchedEnv , metaclass = _PEnvMeta ):
907
+ class ParallelEnv (BatchedEnvBase , metaclass = _PEnvMeta ):
889
908
"""Creates one environment per process.
890
909
891
910
TensorDicts are passed via shared memory or memory map.
892
911
893
912
"""
894
913
895
- __doc__ += _BatchedEnv .__doc__
914
+ __doc__ += BatchedEnvBase .__doc__
896
915
__doc__ += """
897
916
898
917
.. warning::
@@ -1167,14 +1186,14 @@ def step_and_maybe_reset(
1167
1186
tensordict_ = tensordict_ .clone ()
1168
1187
elif device is not None :
1169
1188
next_td = next_td ._fast_apply (
1170
- lambda x : x .to (device , non_blocking = True )
1189
+ lambda x : x .to (device , non_blocking = self . non_blocking )
1171
1190
if x .device != device
1172
1191
else x .clone (),
1173
1192
device = device ,
1174
1193
filter_empty = True ,
1175
1194
)
1176
1195
tensordict_ = tensordict_ ._fast_apply (
1177
- lambda x : x .to (device , non_blocking = True )
1196
+ lambda x : x .to (device , non_blocking = self . non_blocking )
1178
1197
if x .device != device
1179
1198
else x .clone (),
1180
1199
device = device ,
@@ -1239,7 +1258,7 @@ def select_and_clone(name, tensor):
1239
1258
if device is None :
1240
1259
out .clear_device_ ()
1241
1260
else :
1242
- out = out .to (device , non_blocking = True )
1261
+ out = out .to (device , non_blocking = self . non_blocking )
1243
1262
return out
1244
1263
1245
1264
@_check_start
@@ -1325,7 +1344,7 @@ def select_and_clone(name, tensor):
1325
1344
if device is None :
1326
1345
out .clear_device_ ()
1327
1346
else :
1328
- out = out .to (device , non_blocking = True )
1347
+ out = out .to (device , non_blocking = self . non_blocking )
1329
1348
return out
1330
1349
1331
1350
@_check_start
@@ -1644,12 +1663,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
1644
1663
child_pipe .send (("_" .join ([cmd , "done" ]), None ))
1645
1664
1646
1665
1647
- def _update_cuda (t_dest , t_source ):
1648
- if t_source is None :
1649
- return
1650
- t_dest .copy_ (t_source .pin_memory (), non_blocking = True )
1651
- return
1652
-
1653
-
1654
1666
def _filter_empty (tensordict ):
1655
1667
return tensordict .select (* tensordict .keys (True , True ))
1668
+
1669
+
1670
+ # Create an alias for possible imports
1671
+ _BatchedEnv = BatchedEnvBase
0 commit comments