Skip to content

Commit 1559a4a

Browse files
committed
amend
1 parent 324b8d7 commit 1559a4a

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

torchrl/collectors/collectors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def __init__(
577577
reset_when_done: bool = True,
578578
interruptor=None,
579579
):
580-
from torchrl.envs.batched_envs import _BatchedEnv
580+
from torchrl.envs.batched_envs import BatchedEnvBase
581581

582582
self.closed = True
583583

@@ -591,7 +591,7 @@ def __init__(
591591
else:
592592
env = create_env_fn
593593
if create_env_kwargs:
594-
if not isinstance(env, _BatchedEnv):
594+
if not isinstance(env, BatchedEnvBase):
595595
raise RuntimeError(
596596
"kwargs were passed to SyncDataCollector but they can't be set "
597597
f"on environment of type {type(create_env_fn)}."
@@ -1201,11 +1201,11 @@ def state_dict(self) -> OrderedDict:
12011201
`"env_state_dict"`.
12021202
12031203
"""
1204-
from torchrl.envs.batched_envs import _BatchedEnv
1204+
from torchrl.envs.batched_envs import BatchedEnvBase
12051205

12061206
if isinstance(self.env, TransformedEnv):
12071207
env_state_dict = self.env.transform.state_dict()
1208-
elif isinstance(self.env, _BatchedEnv):
1208+
elif isinstance(self.env, BatchedEnvBase):
12091209
env_state_dict = self.env.state_dict()
12101210
else:
12111211
env_state_dict = OrderedDict()

torchrl/envs/batched_envs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,3 +1665,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
16651665

16661666
def _filter_empty(tensordict):
16671667
return tensordict.select(*tensordict.keys(True, True))
1668+
1669+
1670+
# Create an alias for possible imports
1671+
_BatchedEnv = BatchedEnvBase

0 commit comments

Comments
 (0)