Bug description
An imitation.policies.base.FeedForward32Policy that is saved using policy.save() cannot be loaded with imitation.policies.base.FeedForward32Policy.load(). The call raises a TypeError; the traceback is included under "Steps to reproduce".
Steps to reproduce
Train a policy using imitation.algorithms.bc.BC, then save the trained policy.
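from imitation.algorithms import bc

# env, transitions, and rng are assumed to be set up beforehand
# (an environment, expert demonstrations, and a NumPy random generator).
bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)
bc_trainer.train(n_epochs=1)
bc_trainer.policy.save("policy.zip")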
This should save a FeedForward32Policy. Then, load the policy.
import imitation.policies.base
imitation.policies.base.FeedForward32Policy.load("policy.zip")
This raises the following exception:
  File "/home/jayden/Desktop/Programs/aithermostat/venv/lib/python3.8/site-packages/imitation/policies/base.py", line 104, in __init__
    super().__init__(*args, **kwargs, net_arch=[32, 32])
TypeError: __init__() got multiple values for keyword argument 'net_arch'
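The traceback suggests a likely mechanism: if load() rebuilds the policy from constructor arguments stored in the file, and those arguments already include net_arch, then the hard-coded net_arch=[32, 32] in __init__ collides with the saved one. The sketch below reproduces that collision without the library. BasePolicy, Fixed32Policy, Tolerant32Policy, and constructor_kwargs are hypothetical stand-ins written for illustration, not imitation or stable_baselines3 code, and the way load() re-instantiates the class is an assumption.

```python
class BasePolicy:
    """Stand-in for a policy base class that can rebuild itself from saved kwargs."""

    def __init__(self, net_arch=None):
        self.net_arch = net_arch

    def constructor_kwargs(self):
        # Assumed behaviour: the saved file records the constructor arguments.
        return {"net_arch": self.net_arch}

    @classmethod
    def load(cls, saved_kwargs):
        # Assumed behaviour: load() re-instantiates the class from those arguments.
        return cls(**saved_kwargs)


class Fixed32Policy(BasePolicy):
    """Mirrors the line in the traceback: net_arch is always forced to [32, 32]."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, net_arch=[32, 32])


class Tolerant32Policy(BasePolicy):
    """Hypothetical variant that discards a saved net_arch before forcing its own."""

    def __init__(self, *args, **kwargs):
        kwargs.pop("net_arch", None)
        super().__init__(*args, **kwargs, net_arch=[32, 32])


# Constructing the policy directly works, and net_arch ends up in the saved kwargs.
saved = Fixed32Policy().constructor_kwargs()

# Reloading passes net_arch twice: once from the saved kwargs, once hard-coded.
try:
    Fixed32Policy.load(saved)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# The tolerant variant drops the saved value first, so reloading succeeds.
reloaded = Tolerant32Policy.load(saved)

print(error_message)
print(reloaded.net_arch)
```

If the real load() works the way this sketch assumes, dropping net_arch from kwargs before calling super().__init__() might be one way to make FeedForward32Policy reloadable, though this has not been checked against the library.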
Environment
- Operating system and version: Linux
- Python version: 3.8.19
- Output of pip freeze --all:
absl-py==2.1.0
aiohappyeyeballs==2.3.4
aiohttp==3.10.0
aiosignal==1.3.1
ale-py==0.8.1
alembic==1.13.2
async-timeout==4.0.3
attrs==23.2.0
AutoROM==0.6.1
AutoROM.accept-rom-license==0.6.1
cachetools==5.4.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
colorlog==6.8.2
contourpy==1.1.1
cycler==0.12.1
datasets==2.20.0
dill==0.3.8
docopt==0.6.2
Farama-Notifications==0.0.4
filelock==3.15.4
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.5.0
gitdb==4.0.11
GitPython==3.1.43
google-auth==2.32.0
google-auth-oauthlib==1.0.0
greenlet==3.0.3
grpcio==1.65.2
gymnasium==0.29.1
huggingface-hub==0.24.5
huggingface-sb3==3.0
idna==3.7
imitation==1.0.0
importlib_metadata==8.2.0
importlib_resources==6.4.0
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.2.2
kiwisolver==1.4.5
Mako==1.3.5
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.5
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
munch==4.0.0
networkx==3.1
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.20
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opencv-python==4.10.0.84
optuna==3.6.1
packaging==24.1
pandas==2.0.3
pillow==10.4.0
pip==23.0.1
protobuf==5.27.3
psutil==6.0.0
py-cpuinfo==9.0.0
pyarrow==17.0.0
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pygame==2.6.0
Pygments==2.18.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
requests==2.32.3
requests-oauthlib==2.0.0
rich==13.7.1
rsa==4.9
sacred==0.8.5
scikit-learn==1.3.2
scipy==1.10.1
seals==0.2.1
setuptools==56.0.0
Shimmy==1.3.0
six==1.16.0
smmap==5.0.1
SQLAlchemy==2.0.31
stable_baselines3==2.3.2
sympy==1.13.1
tensorboard==2.14.0
tensorboard-data-server==0.7.2
threadpoolctl==3.5.0
torch==2.4.0
tqdm==4.66.4
triton==3.0.0
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
wasabi==1.1.3
Werkzeug==3.0.3
wheel==0.43.0
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.19.2