File tree 2 files changed +8
-4
lines changed 2 files changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -577,7 +577,7 @@ def __init__(
577
577
reset_when_done : bool = True ,
578
578
interruptor = None ,
579
579
):
580
- from torchrl .envs .batched_envs import _BatchedEnv
580
+ from torchrl .envs .batched_envs import BatchedEnvBase
581
581
582
582
self .closed = True
583
583
@@ -591,7 +591,7 @@ def __init__(
591
591
else :
592
592
env = create_env_fn
593
593
if create_env_kwargs :
594
- if not isinstance (env , _BatchedEnv ):
594
+ if not isinstance (env , BatchedEnvBase ):
595
595
raise RuntimeError (
596
596
"kwargs were passed to SyncDataCollector but they can't be set "
597
597
f"on environment of type { type (create_env_fn )} ."
@@ -1201,11 +1201,11 @@ def state_dict(self) -> OrderedDict:
1201
1201
`"env_state_dict"`.
1202
1202
1203
1203
"""
1204
- from torchrl .envs .batched_envs import _BatchedEnv
1204
+ from torchrl .envs .batched_envs import BatchedEnvBase
1205
1205
1206
1206
if isinstance (self .env , TransformedEnv ):
1207
1207
env_state_dict = self .env .transform .state_dict ()
1208
- elif isinstance (self .env , _BatchedEnv ):
1208
+ elif isinstance (self .env , BatchedEnvBase ):
1209
1209
env_state_dict = self .env .state_dict ()
1210
1210
else :
1211
1211
env_state_dict = OrderedDict ()
Original file line number Diff line number Diff line change @@ -1665,3 +1665,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
1665
1665
1666
1666
def _filter_empty (tensordict ):
1667
1667
return tensordict .select (* tensordict .keys (True , True ))
1668
+
1669
+
1670
+ # Create an alias for possible imports
1671
+ _BatchedEnv = BatchedEnvBase
You can’t perform that action at this time.
0 commit comments