Skip to content

Commit a10cdbf

Browse files
[Feature] More flexibility in loading PettingZoo (#1817)
1 parent e98ee38 commit a10cdbf

File tree

1 file changed

+57
-4
lines changed

1 file changed

+57
-4
lines changed

torchrl/envs/libs/pettingzoo.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import copy
88
import importlib
9+
import warnings
910
from typing import Dict, List, Tuple, Union
1011

1112
import torch
@@ -27,11 +28,54 @@
2728
def _get_envs():
2829
if not _has_pettingzoo:
2930
raise ImportError("PettingZoo is not installed in your virtual environment.")
30-
from pettingzoo.utils.all_modules import all_environments
31+
try:
32+
from pettingzoo.utils.all_modules import all_environments
33+
except ModuleNotFoundError as err:
34+
warnings.warn(
35+
f"PettingZoo failed to load all modules with error message {err}, trying to load individual modules."
36+
)
37+
all_environments = _load_available_envs()
3138

3239
return list(all_environments.keys())
3340

3441

42+
def _load_available_envs() -> Dict:
43+
all_environments = {}
44+
try:
45+
from pettingzoo.mpe.all_modules import mpe_environments
46+
47+
all_environments.update(mpe_environments)
48+
except ModuleNotFoundError as err:
49+
warnings.warn(f"MPE environments failed to load with error message {err}.")
50+
try:
51+
from pettingzoo.sisl.all_modules import sisl_environments
52+
53+
all_environments.update(sisl_environments)
54+
except ModuleNotFoundError as err:
55+
warnings.warn(f"SISL environments failed to load with error message {err}.")
56+
try:
57+
from pettingzoo.classic.all_modules import classic_environments
58+
59+
all_environments.update(classic_environments)
60+
except ModuleNotFoundError as err:
61+
warnings.warn(f"Classic environments failed to load with error message {err}.")
62+
try:
63+
from pettingzoo.atari.all_modules import atari_environments
64+
65+
all_environments.update(atari_environments)
66+
except ModuleNotFoundError as err:
67+
warnings.warn(f"Atari environments failed to load with error message {err}.")
68+
try:
69+
from pettingzoo.butterfly.all_modules import butterfly_environments
70+
71+
all_environments.update(butterfly_environments)
72+
except ModuleNotFoundError as err:
73+
warnings.warn(
74+
f"Butterfly environments failed to load with error message {err}."
75+
)
76+
return all_environments
77+
78+
3579
class PettingZooWrapper(_EnvWrapper):
3680
"""PettingZoo environment wrapper.
3781
@@ -834,7 +878,8 @@ class PettingZooEnv(PettingZooWrapper):
834878
neural network.
835879
836880
Args:
837-
task (str): the name of the pettingzoo task to create (for example, "multiwalker_v9").
881+
task (str): the name of the pettingzoo task to create in the "<env>/<task>" format (for example, "sisl/multiwalker_v9")
882+
or "<task>" format (for example, "multiwalker_v9").
838883
parallel (bool): if to construct the ``pettingzoo.ParallelEnv`` version of the task or the ``pettingzoo.AECEnv``.
839884
return_state (bool, optional): whether to return the global state from pettingzoo
840885
(not available in all environments). Defaults to ``False``.
@@ -919,7 +964,13 @@ def _build_env(
919964
]:
920965
self.task_name = task
921966

922-
from pettingzoo.utils.all_modules import all_environments
967+
try:
968+
from pettingzoo.utils.all_modules import all_environments
969+
except ModuleNotFoundError as err:
970+
warnings.warn(
971+
f"PettingZoo failed to load all modules with error message {err}, trying to load individual modules."
972+
)
973+
all_environments = _load_available_envs()
923974

924975
if task not in all_environments:
925976
# Try looking at the literal translation of values
@@ -929,7 +980,9 @@ def _build_env(
929980
task_module = value
930981
break
931982
if task_module is None:
932-
raise RuntimeError(f"Specified task not in {_get_envs()}")
983+
raise RuntimeError(
984+
f"Specified task not in available environments {all_environments}"
985+
)
933986
else:
934987
task_module = all_environments[task]
935988

0 commit comments

Comments
 (0)