
Problem in Passing Along Custom Gym Env Constructor Parameters in make_vec_env #850

Open
@christianjcc


Bug description

I was testing the imitation library with a custom Gym environment and ran into a shortcoming in `imitation/util/util.py` that produces the error below.

```
Traceback (most recent call last):
  File "/home/anaconda3/envs/env/lib/python3.8/site-packages/gymnasium/envs/registration.py", line 802, in make
    env = env_creator(**env_spec_kwargs)
TypeError: __init__() missing 1 required positional argument: 'config'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "runner.py", line 118, in <module>
    main()
  File "runner.py", line 58, in main
    env = make_vec_env(
  File "/home/anaconda3/envs/educagent_libtraffic/lib/python3.8/site-packages/imitation/util/util.py", line 117, in make_vec_env
    tmp_env = gym.make(env_name)
  File "/home/anaconda3/envs/env/lib/python3.8/site-packages/gymnasium/envs/registration.py", line 814, in make
    raise type(e)(
TypeError: __init__() missing 1 required positional argument: 'config' was raised from the environment creator for custom/mycustom-env with kwargs ({})
```

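Below is how I define and register the custom Gym environment and pass the `config` dictionary to the env instance:

```python
import gymnasium as gym
import numpy as np
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.util.util import make_vec_env

# MyCustomEnv is my own gymnasium.Env subclass whose __init__ requires `config`

# Define your config dictionary
config = {
    "parameters": "example",
    "render_mode": "human",
}

env_name = "custom/mycustom-env"
gym.register(
    id=env_name,
    entry_point=MyCustomEnv,
    max_episode_steps=500,
)

env = make_vec_env(
    env_name=env_name,
    rng=np.random.default_rng(0),  # make_vec_env requires an `rng` keyword argument
    n_envs=1,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
    env_make_kwargs={"config": config},
)
```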

I can get it to work with the way the imitation library currently calls `gym.make` by registering a factory function instead:

```python
# Same imports and `config` as in the snippet above.

# Define a factory function that returns an instance of the custom environment with the desired configuration
def make_traffic_env(**kwargs):
    return MyCustomEnv(config=config, **kwargs)


env_name = "custom/mycustom-env"

gym.register(
    id=env_name,
    entry_point=make_traffic_env,
    max_episode_steps=500,
)

env = make_vec_env(
    env_name=env_name,
    rng=np.random.default_rng(0),
    n_envs=1,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
)
```

This happens because `make_vec_env` creates its temporary environment without forwarding `env_make_kwargs`:

`tmp_env = gym.make(env_name)`

as implemented here:
https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/util/util.py#L117

instead of:

`tmp_env = gym.make(env_name, **env_make_kwargs)`

which would pass the kwargs through, as documented in
https://gymnasium.farama.org/api/registry/#gymnasium.make

Implementing this change would let users pass constructor arguments directly through `env_make_kwargs` instead of defining a factory function, which simplifies setup.
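The proposed change can be illustrated with a dependency-free sketch. The toy `register`/`make` pair below only mimics how `gym.make` merges registered kwargs with call-time kwargs, and `make_vec_env_sketch` and `NeedsConfig` are hypothetical simplifications, not imitation's or Gymnasium's actual code. The sketch also normalizes `env_make_kwargs` from `None` to `{}`, since unpacking `**None` raises a `TypeError`.

```python
import copy
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple

# Hypothetical toy registry: env id -> (entry point, default constructor kwargs).
_REGISTRY: Dict[str, Tuple[Callable[..., Any], Dict[str, Any]]] = {}


def register(env_id: str, entry_point: Callable[..., Any],
             kwargs: Optional[Mapping[str, Any]] = None) -> None:
    _REGISTRY[env_id] = (entry_point, dict(kwargs or {}))


def make(env_id: str, **kwargs: Any) -> Any:
    """Mimic gym.make: registered defaults, overridden by call-time kwargs."""
    entry_point, spec_kwargs = _REGISTRY[env_id]
    merged = copy.deepcopy(spec_kwargs)
    merged.update(kwargs)
    return entry_point(**merged)


def make_vec_env_sketch(env_id: str, n_envs: int = 1,
                        env_make_kwargs: Optional[Mapping[str, Any]] = None) -> List[Any]:
    """Hypothetical simplification of make_vec_env with the proposed fix."""
    env_make_kwargs = dict(env_make_kwargs or {})  # tolerate the default None
    # Proposed fix: forward env_make_kwargs to the temporary env as well,
    # instead of calling make(env_id) with no kwargs.
    tmp_env = make(env_id, **env_make_kwargs)
    del tmp_env  # a real implementation would inspect and close it here
    return [make(env_id, **env_make_kwargs) for _ in range(n_envs)]


class NeedsConfig:
    """Hypothetical env whose constructor requires `config`, like MyCustomEnv."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config


register("custom/mycustom-env", NeedsConfig)

try:
    make("custom/mycustom-env")  # the unpatched call: no kwargs forwarded
except TypeError as err:
    print("unpatched:", err)

envs = make_vec_env_sketch(
    "custom/mycustom-env",
    n_envs=2,
    env_make_kwargs={"config": {"parameters": "example"}},
)
print([env.config for env in envs])
```

Passing `kwargs={"config": ...}` to the toy `register` would likewise make the bare `make(env_id)` call succeed.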

Environment

  • Operating system and version: Ubuntu 22.04.4 LTS
  • Python version: Python 3.8.19
  • Output of pip freeze --all:
    absl-py==0.15.0
    aiohttp==3.9.5
    aiosignal==1.2.0
    ale-py==0.8.1
    alembic==1.13.1
    anyio==3.7.1
    argon2-cffi==21.3.0
    argon2-cffi-bindings==21.2.0
    asttokens==2.4.1
    astunparse==1.6.3
    async-generator==1.10
    async-timeout==4.0.3
    attrs==22.2.0
    AutoROM==0.4.2
    AutoROM.accept-rom-license==0.6.1
    backcall==0.2.0
    beautifulsoup4==4.12.3
    black==22.6.0
    bleach==4.1.0
    bottle==0.12.25
    cached-property==1.5.2
    cachetools==4.2.4
    certifi==2021.5.30
    cffi==1.15.1
    cfgv==3.4.0
    charset-normalizer==2.0.12
    clang==5.0
    click==8.0.4
    cloudpickle==2.2.1
    codecov==2.1.13
    codespell==2.1.0
    colorama==0.4.5
    colorlog==6.8.2
    comm==0.1.4
    commonmark==0.9.1
    conan==1.60.1
    coverage==6.4.4
    cycler==0.11.0
    darglint==1.8.1
    datasets==2.19.1
    debugpy==1.8.1
    decorator==4.4.2
    defusedxml==0.7.1
    Deprecated==1.2.13
    dill==0.3.8
    distlib==0.3.7
    distro==1.8.0
    dm-tree==0.1.8
    docker-pycreds==0.4.0
    docopt==0.6.2
    docstring-parser==0.13
    entrypoints==0.4
    exceptiongroup==1.2.1
    execnet==2.1.1
    executing==2.0.1
    Farama-Notifications==0.0.4
    fasteners==0.19
    fastjsonschema==2.19.1
    filelock==3.7.1
    fire==0.4.0
    flake8==4.0.1
    flake8-blind-except==0.2.1
    flake8-builtins==1.5.3
    flake8-commas==2.1.0
    flake8-debugger==4.1.2
    flake8-docstrings==1.6.0
    flake8-isort==4.1.2.post0
    flatbuffers==1.12
    fonttools==4.34.4
    frozenlist==1.2.0
    fsspec==2024.3.1
    gast==0.3.3
    gitdb==4.0.11
    GitPython==3.1.43
    google-api-python-client==1.12.8
    google-auth==2.29.0
    google-auth-httplib2==0.1.0
    google-auth-oauthlib==0.4.6
    google-crc32c==1.3.0
    google-pasta==0.2.0
    greenlet==3.0.3
    grpcio==1.43.0
    gym==0.26.2
    gym-notices==0.0.8
    gymnasium==0.29.1
    gymnasium-notices==0.0.1
    h5py==2.10.0
    highway-env==1.8.2
    httplib2==0.20.2
    huggingface-hub==0.23.0
    huggingface-sb3==3.0
    hypothesis==6.54.6
    identify==2.5.36
    idna==3.4
    imageio==2.15.0
    imageio-ffmpeg==0.4.9
    imitation==1.0.0
    importlab==0.8.1
    importlib-resources==5.4.0
    importlib_metadata==7.1.0
    iniconfig==2.0.0
    ipykernel==6.15.3
    ipython==8.12.3
    ipython-genutils==0.2.0
    ipywidgets==7.8.1
    isort==5.13.2
    jedi==0.17.2
    Jinja2==3.1.4
    joblib==1.4.2
    jsonpickle==3.0.4
    jsonschema==3.2.0
    jupyter==1.0.0
    jupyter-client==6.1.12
    jupyter-console==6.4.2
    jupyter-server==1.24.0
    jupyter-server-mathjax==0.2.6
    jupyter_core==5.7.2
    jupyterlab-pygments==0.1.2
    jupyterlab-widgets==1.1.7
    keras==2.6.0
    Keras-Preprocessing==1.1.2
    kfp==1.8.9
    kfp-pipeline-spec==0.1.13
    kfp-server-api==1.7.1
    kiwisolver==1.3.1
    kubernetes==18.20.0
    libcst==1.1.0
    lz4==3.1.10
    Mako==1.3.3
    Markdown==3.3.7
    MarkupSafe==2.0.1
    matplotlib==3.3.4
    matplotlib-inline==0.1.7
    mccabe==0.6.1
    memory-profiler==0.61.0
    mistune==3.0.2
    moviepy==1.0.3
    mpmath==1.3.0
    msgpack==1.0.5
    multidict==6.0.5
    multiprocess==0.70.16
    munch==4.0.0
    mypy==0.991
    mypy-extensions==1.0.0
    nbclient==0.5.13
    nbconvert==7.16.4
    nbdime==4.0.1
    nbformat==5.10.4
    nest-asyncio==1.5.8
    networkx==2.5.1
    ninja==1.11.1.1
    node-semver==0.6.1
    nodeenv==1.8.0
    notebook==6.4.10
    numpy==1.21.0
    nvidia-cublas-cu12==12.1.3.1
    nvidia-cuda-cupti-cu12==12.1.105
    nvidia-cuda-nvrtc-cu12==12.1.105
    nvidia-cuda-runtime-cu12==12.1.105
    nvidia-cudnn-cu12==8.9.2.26
    nvidia-cufft-cu12==11.0.2.54
    nvidia-curand-cu12==10.3.2.106
    nvidia-cusolver-cu12==11.4.5.107
    nvidia-cusparse-cu12==12.1.0.106
    nvidia-nccl-cu12==2.20.5
    nvidia-nvjitlink-cu12==12.4.127
    nvidia-nvtx-cu12==12.1.105
    oauthlib==3.2.2
    opencv-python==4.9.0.80
    opt-einsum==3.3.0
    optuna==3.6.1
    packaging==21.3
    pandas==1.4.4
    pandocfilters==1.5.0
    parso==0.7.1
    patch-ng==1.17.4
    pathspec==0.9.0
    pathtools==0.1.2
    pexpect==4.8.0
    pickleshare==0.7.5
    Pillow==8.4.0
    pip==24.0
    platformdirs==2.6.2
    plotly==5.18.0
    pluggy==1.5.0
    pluginbase==1.0.1
    pre-commit==3.5.0
    proglog==0.1.10
    prometheus-client==0.17.1
    promise==2.3
    prompt-toolkit==3.0.36
    protobuf==3.20.3
    psutil==5.9.8
    ptyprocess==0.7.0
    pure-eval==0.2.2
    py==1.11.0
    py-cpuinfo==9.0.0
    pyarrow==16.0.0
    pyarrow-hotfix==0.6
    pyasn1==0.5.0
    pyasn1-modules==0.3.0
    pycnite==2023.10.11
    pycocotools==2.0.4
    pycodestyle==2.8.0
    pycparser==2.21
    pydantic==1.8.2
    pydocstyle==6.3.0
    pydot==2.0.0
    pyflakes==2.4.0
    pygame==2.5.2
    Pygments==2.14.0
    PyJWT==2.4.0
    pyparsing==3.1.1
    pypiserver==2.0.1
    pyrsistent==0.18.0
    pytest==7.1.3
    pytest-cov==3.0.0
    pytest-forked==1.6.0
    pytest-timeout==2.1.0
    pytest-xdist==2.5.0
    pytest_notebook==0.8.0
    python-dateutil==2.8.2
    pytype==2023.9.27
    pytz==2023.3.post1
    PyWavelets==1.1.1
    PyYAML==6.0
    pyzmq==25.1.1
    qtconsole==5.2.2
    QtPy==2.0.1
    ray==2.0.1
    requests==2.27.1
    requests-oauthlib==1.3.1
    requests-toolbelt==0.9.1
    rich==12.6.0
    rsa==4.9
    sacred==0.8.5
    scikit-image==0.17.2
    scikit-learn==1.3.2
    scipy==1.9.3
    seals==0.2.1
    Send2Trash==1.8.2
    sentry-sdk==2.1.1
    setproctitle==1.3.3
    setuptools==69.5.1
    setuptools-scm==7.0.5
    Shimmy==0.2.1
    shortuuid==1.0.13
    six==1.15.0
    smmap==5.0.1
    sniffio==1.3.1
    snowballstemmer==2.2.0
    sortedcontainers==2.4.0
    soupsieve==2.5
    SQLAlchemy==2.0.30
    stable_baselines3==2.3.2
    stack-data==0.6.3
    strip-hints==0.1.10
    sympy==1.12
    tabulate==0.8.10
    tenacity==8.3.0
    tensorboard==2.11.2
    tensorboard-data-server==0.6.1
    tensorboard-plugin-wit==1.8.1
    tensorboardX==2.6.2.2
    tensorflow==2.2.0
    tensorflow-estimator==2.2.0
    termcolor==1.1.0
    terminado==0.12.1
    terminaltables==3.1.10
    testpath==0.6.0
    threadpoolctl==3.5.0
    tifffile==2020.9.3
    tinycss2==1.3.0
    tokenize-rt==5.2.0
    toml==0.10.2
    tomli==1.2.3
    torch==2.3.0
    tornado==6.1
    tqdm==4.64.1
    traitlets==5.14.3
    triton==2.3.0
    typed-ast==1.5.5
    typer==0.9.0
    typing-inspect==0.9.0
    typing_extensions==4.11.0
    uritemplate==3.0.1
    urllib3==1.26.17
    validators==0.20.0
    virtualenv==20.17.1
    wandb==0.12.21
    wasabi==1.1.2
    wcwidth==0.2.9
    webencodings==0.5.1
    websocket-client==1.2.1
    Werkzeug==2.0.3
    wheel==0.43.0
    widgetsnbextension==3.6.6
    wrapt==1.12.1
    xxhash==3.4.1
    yarl==1.9.4
    zipp==3.6.0

Metadata

  • Assignees: No one assigned
  • Labels: bug (Something isn't working)
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests