Skip to content

Commit f5fa9de

Browse files
Vincent Moensalbertbou92
Vincent Moens
authored andcommitted
[BugFix] Fix envpool (pytorch#1530)
1 parent acf3510 commit f5fa9de

File tree

9 files changed

+348
-320
lines changed

9 files changed

+348
-320
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ And it is `functorch` and `torch.compile` compatible!
265265
- A common [interface for environments](torchrl/envs)
266266
which supports common libraries (OpenAI gym, deepmind control lab, etc.)<sup>(1)</sup> and state-less execution
267267
(e.g. Model-based environments).
268-
The [batched environments](torchrl/envs/vec_env.py) containers allow parallel execution<sup>(2)</sup>.
268+
The [batched environments](torchrl/envs/batched_envs.py) containers allow parallel execution<sup>(2)</sup>.
269269
A common PyTorch-first class of [tensor-specification class](torchrl/data/tensor_specs.py) is also provided.
270270
TorchRL's environments API is simple but stringent and specific. Check the
271271
[documentation](https://pytorch.org/rl/reference/envs.html)

test/_utils_internal.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@
2020
from tensordict import tensorclass
2121
from torchrl._utils import implement_for, seed_generator
2222

23-
from torchrl.envs import ObservationNorm
23+
from torchrl.envs import MultiThreadedEnv, ObservationNorm
24+
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
25+
from torchrl.envs.libs.envpool import _has_envpool
2426
from torchrl.envs.libs.gym import _has_gym, GymEnv
2527
from torchrl.envs.transforms import (
2628
Compose,
2729
RewardClipping,
2830
ToTensorImage,
2931
TransformedEnv,
3032
)
31-
from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnv, ParallelEnv, SerialEnv
3233

3334
# Specified for test_utils.py
3435
__version__ = "0.3"

test/test_libs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@
4949
ParallelEnv,
5050
RenameTransform,
5151
)
52+
from torchrl.envs.batched_envs import SerialEnv
5253
from torchrl.envs.libs.brax import _has_brax, BraxEnv
5354
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper
55+
from torchrl.envs.libs.envpool import _has_envpool, MultiThreadedEnvWrapper
5456
from torchrl.envs.libs.gym import (
5557
_has_gym,
5658
_is_from_pixels,
@@ -66,7 +68,6 @@
6668
from torchrl.envs.libs.robohive import RoboHiveEnv
6769
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
6870
from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType
69-
from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnvWrapper, SerialEnv
7071
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator
7172

7273
_has_d4rl = importlib.util.find_spec("d4rl") is not None

torchrl/collectors/collectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from torchrl.collectors.utils import split_trajectories
3939
from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
4040
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
41+
from torchrl.envs.batched_envs import _BatchedEnv
4142
from torchrl.envs.common import EnvBase
4243
from torchrl.envs.transforms import StepCounter, TransformedEnv
4344
from torchrl.envs.utils import (
@@ -47,7 +48,6 @@
4748
set_exploration_type,
4849
step_mdp,
4950
)
50-
from torchrl.envs.vec_env import _BatchedEnv
5151

5252
_TIMEOUT = 1.0
5353
_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory

torchrl/envs/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from .batched_envs import ParallelEnv, SerialEnv
67
from .common import EnvBase, EnvMetaData, make_tensordict
78
from .env_creator import EnvCreator, get_env_metadata
89
from .gym_like import default_info_dict_reader, GymLikeEnv
10+
from .libs.envpool import MultiThreadedEnv
911
from .model_based import ModelBasedEnvBase
1012
from .transforms import (
1113
ActionMask,
@@ -66,4 +68,3 @@
6668
set_exploration_type,
6769
step_mdp,
6870
)
69-
from .vec_env import MultiThreadedEnv, ParallelEnv, SerialEnv

0 commit comments

Comments
 (0)