Description
Bug description
I was testing the `imitation` library with a custom Gym environment and ran into a limitation in `imitation/util/util.py`. I get the error message below:
```
Traceback (most recent call last):
  File "/home/anaconda3/envs/env/lib/python3.8/site-packages/gymnasium/envs/registration.py", line 802, in make
    env = env_creator(**env_spec_kwargs)
TypeError: __init__() missing 1 required positional argument: 'config'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "runner.py", line 118, in <module>
    main()
  File "runner.py", line 58, in main
    env = make_vec_env(
  File "/home/anaconda3/envs/educagent_libtraffic/lib/python3.8/site-packages/imitation/util/util.py", line 117, in make_vec_env
    tmp_env = gym.make(env_name)
  File "/home/anaconda3/envs/env/lib/python3.8/site-packages/gymnasium/envs/registration.py", line 814, in make
    raise type(e)(
TypeError: __init__() missing 1 required positional argument: 'config' was raised from the environment creator for custom/mycustom-env with kwargs ({})
```
Below is an example of how I define and register the custom Gym environment, and pass the config dictionary to the environment instance:

```python
import gymnasium as gym
import numpy as np
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env

# MyCustomEnv is my own gym.Env subclass; its __init__ requires a `config` argument.

# Define your config dictionary
config = {
    "parameters": "example",
    "render_mode": "human",
}

env_name = "custom/mycustom-env"
gym.register(
    id=env_name,
    entry_point=MyCustomEnv,
    max_episode_steps=500,
)

env = make_vec_env(
    env_name=env_name,
    rng=np.random.default_rng(0),  # keyword-only and required in imitation 1.0.0
    n_envs=1,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
    env_make_kwargs={"config": config},
)
```

I can get it to work with the way the `imitation` library currently calls `gym.make` by registering a factory function instead:

```python
# Define a factory function that returns the custom environment with the desired configuration
def make_traffic_env(**kwargs):
    return MyCustomEnv(config=config, **kwargs)


env_name = "custom/mycustom-env"
gym.register(
    id=env_name,
    entry_point=make_traffic_env,
    max_episode_steps=500,
)

env = make_vec_env(
    env_name=env_name,
    rng=np.random.default_rng(0),
    n_envs=1,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
)
```
This happens because `make_vec_env` first creates a temporary environment with

`tmp_env = gym.make(env_name)`

as implemented here:
https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/util/util.py#L117

instead of

`tmp_env = gym.make(env_name, **env_make_kwargs)`

which would pass the kwargs along to the environment constructor, as supported by `gymnasium.make`:
https://gymnasium.farama.org/api/registry/#gymnasium.make

Implementing this change would let users pass `env_make_kwargs` directly instead of defining a factory function, which simplifies the setup.
Environment
- Operating system and version: Ubuntu 22.04.4 LTS
- Python version: Python 3.8.19
- Output of `pip freeze --all`:
absl-py==0.15.0
aiohttp==3.9.5
aiosignal==1.2.0
ale-py==0.8.1
alembic==1.13.1
anyio==3.7.1
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens==2.4.1
astunparse==1.6.3
async-generator==1.10
async-timeout==4.0.3
attrs==22.2.0
AutoROM==0.4.2
AutoROM.accept-rom-license==0.6.1
backcall==0.2.0
beautifulsoup4==4.12.3
black==22.6.0
bleach==4.1.0
bottle==0.12.25
cached-property==1.5.2
cachetools==4.2.4
certifi==2021.5.30
cffi==1.15.1
cfgv==3.4.0
charset-normalizer==2.0.12
clang==5.0
click==8.0.4
cloudpickle==2.2.1
codecov==2.1.13
codespell==2.1.0
colorama==0.4.5
colorlog==6.8.2
comm==0.1.4
commonmark==0.9.1
conan==1.60.1
coverage==6.4.4
cycler==0.11.0
darglint==1.8.1
datasets==2.19.1
debugpy==1.8.1
decorator==4.4.2
defusedxml==0.7.1
Deprecated==1.2.13
dill==0.3.8
distlib==0.3.7
distro==1.8.0
dm-tree==0.1.8
docker-pycreds==0.4.0
docopt==0.6.2
docstring-parser==0.13
entrypoints==0.4
exceptiongroup==1.2.1
execnet==2.1.1
executing==2.0.1
Farama-Notifications==0.0.4
fasteners==0.19
fastjsonschema==2.19.1
filelock==3.7.1
fire==0.4.0
flake8==4.0.1
flake8-blind-except==0.2.1
flake8-builtins==1.5.3
flake8-commas==2.1.0
flake8-debugger==4.1.2
flake8-docstrings==1.6.0
flake8-isort==4.1.2.post0
flatbuffers==1.12
fonttools==4.34.4
frozenlist==1.2.0
fsspec==2024.3.1
gast==0.3.3
gitdb==4.0.11
GitPython==3.1.43
google-api-python-client==1.12.8
google-auth==2.29.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-crc32c==1.3.0
google-pasta==0.2.0
greenlet==3.0.3
grpcio==1.43.0
gym==0.26.2
gym-notices==0.0.8
gymnasium==0.29.1
gymnasium-notices==0.0.1
h5py==2.10.0
highway-env==1.8.2
httplib2==0.20.2
huggingface-hub==0.23.0
huggingface-sb3==3.0
hypothesis==6.54.6
identify==2.5.36
idna==3.4
imageio==2.15.0
imageio-ffmpeg==0.4.9
imitation==1.0.0
importlab==0.8.1
importlib-resources==5.4.0
importlib_metadata==7.1.0
iniconfig==2.0.0
ipykernel==6.15.3
ipython==8.12.3
ipython-genutils==0.2.0
ipywidgets==7.8.1
isort==5.13.2
jedi==0.17.2
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.0.4
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.1.12
jupyter-console==6.4.2
jupyter-server==1.24.0
jupyter-server-mathjax==0.2.6
jupyter_core==5.7.2
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.1.7
keras==2.6.0
Keras-Preprocessing==1.1.2
kfp==1.8.9
kfp-pipeline-spec==0.1.13
kfp-server-api==1.7.1
kiwisolver==1.3.1
kubernetes==18.20.0
libcst==1.1.0
lz4==3.1.10
Mako==1.3.3
Markdown==3.3.7
MarkupSafe==2.0.1
matplotlib==3.3.4
matplotlib-inline==0.1.7
mccabe==0.6.1
memory-profiler==0.61.0
mistune==3.0.2
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.0.5
multidict==6.0.5
multiprocess==0.70.16
munch==4.0.0
mypy==0.991
mypy-extensions==1.0.0
nbclient==0.5.13
nbconvert==7.16.4
nbdime==4.0.1
nbformat==5.10.4
nest-asyncio==1.5.8
networkx==2.5.1
ninja==1.11.1.1
node-semver==0.6.1
nodeenv==1.8.0
notebook==6.4.10
numpy==1.21.0
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opencv-python==4.9.0.80
opt-einsum==3.3.0
optuna==3.6.1
packaging==21.3
pandas==1.4.4
pandocfilters==1.5.0
parso==0.7.1
patch-ng==1.17.4
pathspec==0.9.0
pathtools==0.1.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.4.0
pip==24.0
platformdirs==2.6.2
plotly==5.18.0
pluggy==1.5.0
pluginbase==1.0.1
pre-commit==3.5.0
proglog==0.1.10
prometheus-client==0.17.1
promise==2.3
prompt-toolkit==3.0.36
protobuf==3.20.3
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
py==1.11.0
py-cpuinfo==9.0.0
pyarrow==16.0.0
pyarrow-hotfix==0.6
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycnite==2023.10.11
pycocotools==2.0.4
pycodestyle==2.8.0
pycparser==2.21
pydantic==1.8.2
pydocstyle==6.3.0
pydot==2.0.0
pyflakes==2.4.0
pygame==2.5.2
Pygments==2.14.0
PyJWT==2.4.0
pyparsing==3.1.1
pypiserver==2.0.1
pyrsistent==0.18.0
pytest==7.1.3
pytest-cov==3.0.0
pytest-forked==1.6.0
pytest-timeout==2.1.0
pytest-xdist==2.5.0
pytest_notebook==0.8.0
python-dateutil==2.8.2
pytype==2023.9.27
pytz==2023.3.post1
PyWavelets==1.1.1
PyYAML==6.0
pyzmq==25.1.1
qtconsole==5.2.2
QtPy==2.0.1
ray==2.0.1
requests==2.27.1
requests-oauthlib==1.3.1
requests-toolbelt==0.9.1
rich==12.6.0
rsa==4.9
sacred==0.8.5
scikit-image==0.17.2
scikit-learn==1.3.2
scipy==1.9.3
seals==0.2.1
Send2Trash==1.8.2
sentry-sdk==2.1.1
setproctitle==1.3.3
setuptools==69.5.1
setuptools-scm==7.0.5
Shimmy==0.2.1
shortuuid==1.0.13
six==1.15.0
smmap==5.0.1
sniffio==1.3.1
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soupsieve==2.5
SQLAlchemy==2.0.30
stable_baselines3==2.3.2
stack-data==0.6.3
strip-hints==0.1.10
sympy==1.12
tabulate==0.8.10
tenacity==8.3.0
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.6.2.2
tensorflow==2.2.0
tensorflow-estimator==2.2.0
termcolor==1.1.0
terminado==0.12.1
terminaltables==3.1.10
testpath==0.6.0
threadpoolctl==3.5.0
tifffile==2020.9.3
tinycss2==1.3.0
tokenize-rt==5.2.0
toml==0.10.2
tomli==1.2.3
torch==2.3.0
tornado==6.1
tqdm==4.64.1
traitlets==5.14.3
triton==2.3.0
typed-ast==1.5.5
typer==0.9.0
typing-inspect==0.9.0
typing_extensions==4.11.0
uritemplate==3.0.1
urllib3==1.26.17
validators==0.20.0
virtualenv==20.17.1
wandb==0.12.21
wasabi==1.1.2
wcwidth==0.2.9
webencodings==0.5.1
websocket-client==1.2.1
Werkzeug==2.0.3
wheel==0.43.0
widgetsnbextension==3.6.6
wrapt==1.12.1
xxhash==3.4.1
yarl==1.9.4
zipp==3.6.0