
Mistral-Finetune creates consolidated.safetensors for Mixtral 8x7B Instruct v0.1, but mistral-chat fails at inference: it complains that the loaded LoRA weights file is missing an expected key for one of the model layers. #75

Open
tensimixt opened this issue Jul 4, 2024 · 5 comments
Labels: bug (Something isn't working)

Comments


tensimixt commented Jul 4, 2024

Python Version

Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]

Pip Freeze

absl-py==2.1.0
annotated-types==0.7.0
anyio==4.0.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.1.0
Babel==2.13.1
beautifulsoup4==4.12.2
bleach==6.1.0
blinker==1.4
certifi==2022.12.7
cffi==1.16.0
charset-normalizer==2.1.1
click==8.1.7
cmake==3.25.0
comm==0.1.4
cryptography==3.4.8
dbus-python==1.2.18
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
distro==1.7.0
docker-pycreds==0.4.0
docstring_parser==0.16
entrypoints==0.4
exceptiongroup==1.1.3
executing==2.0.1
fastjsonschema==2.18.1
filelock==3.9.0
fire==0.6.0
fqdn==1.5.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
grpcio==1.64.1
httplib2==0.20.2
idna==3.4
importlib-metadata==4.6.4
ipykernel==6.26.0
ipython==8.17.2
ipython-genutils==0.2.0
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
jeepney==0.7.1
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.7.1
jupyter-archive==3.4.0
jupyter-contrib-core==0.4.2
jupyter-contrib-nbextensions==0.7.0
jupyter-events==0.8.0
jupyter-highlight-selected-word==0.2.0
jupyter-lsp==2.2.0
jupyter-nbextensions-configurator==0.6.3
jupyter_client==7.4.9
jupyter_core==5.5.0
jupyter_server==2.9.1
jupyter_server_terminals==0.4.4
jupyterlab==4.0.8
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.0
keyring==23.5.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lit==15.0.7
lxml==4.9.3
Markdown==3.6
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistral_common==1.2.1
mistral_inference==1.1.0
mistune==3.0.2
more-itertools==8.10.0
mpmath==1.3.0
nbclassic==1.0.0
nbclient==0.8.0
nbconvert==7.10.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.0
notebook==6.5.5
notebook_shim==0.2.3
numpy==1.24.1
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.0
overrides==7.4.0
packaging==23.2
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
Pillow==9.3.0
platformdirs==3.11.0
prometheus-client==0.18.0
prompt-toolkit==3.0.39
protobuf==4.25.3
psutil==5.9.6
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
pydantic==2.6.1
pydantic_core==2.16.2
Pygments==2.16.1
PyGObject==3.42.1
PyJWT==2.3.0
pyparsing==2.4.7
python-apt==2.4.0+ubuntu3
python-dateutil==2.8.2
python-json-logger==2.0.7
PyYAML==6.0.1
pyzmq==24.0.1
referencing==0.30.2
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.10.6
safetensors==0.4.3
SecretStorage==3.3.1
Send2Trash==1.8.2
sentencepiece==0.1.99
sentry-sdk==2.7.1
setproctitle==1.3.3
simple_parsing==0.1.5
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
tensorboard==2.17.0
tensorboard-data-server==0.7.2
termcolor==2.4.0
terminado==0.17.1
tinycss2==1.2.1
tomli==2.0.1
torch==2.2.0
torchaudio==2.0.2+cu118
torchvision==0.15.2+cu118
tornado==6.3.3
tqdm==4.66.4
traitlets==5.13.0
triton==2.2.0
types-python-dateutil==2.8.19.14
typing_extensions==4.12.2
uri-template==1.3.0
urllib3==1.26.13
wadllib==1.3.6
wandb==0.17.4
wcwidth==0.2.9
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
Werkzeug==3.0.3
widgetsnbextension==4.0.9
xformers==0.0.24
zipp==1.0.0

Reproduction Steps

1. Clone the repo.
2. Download Mixtral 8x7B Instruct v0.1 and put it in /mistral_models.
3. Download the v3 tokenizer and put it in /mistral_models.
4. Run the extend utility, which generates /mistral_models_extended.
5. Put the v3 tokenizer into the /mistral_models_extended directory.
6. Put the data into /data.
7. Run data validation.
8. Train. This generates checkpoints; after 300 steps you get /workspace/mistral-finetune/experiment5/checkpoints/checkpoint_000300/consolidated/lora.safetensors.

Finally run mistral-chat:
torchrun --nproc-per-node 2 --no-python mistral-chat /workspace/mistral_models_extended --max_tokens 256 --temperature 0.7 --instruct --lora_path /workspace/mistral-finetune/experiment5/checkpoints/checkpoint_000300/consolidated/lora.safetensors

This generates the following error:

[2024-07-04 00:29:06,277] torch.distributed.run: [WARNING] 
[2024-07-04 00:29:06,277] torch.distributed.run: [WARNING] *****************************************
[2024-07-04 00:29:06,277] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-07-04 00:29:06,277] torch.distributed.run: [WARNING] *****************************************
Traceback (most recent call last):
  File "/usr/local/bin/mistral-chat", line 8, in <module>
    sys.exit(mistral_chat())
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/main.py", line 179, in mistral_chat
    fire.Fire(interactive)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/main.py", line 70, in interactive
    transformer.load_lora(Path(lora_path))
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/lora.py", line 104, in load_lora
    self._load_lora_state_dict(state_dict, scaling=scaling)
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/lora.py", line 144, in _load_lora_state_dict
    lora_state_dict[name + ".lora_B.weight"]
KeyError: 'layers.16.feed_forward.gate.lora_B.weight'
Traceback (most recent call last):
  File "/usr/local/bin/mistral-chat", line 8, in <module>
    sys.exit(mistral_chat())
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/main.py", line 179, in mistral_chat
    fire.Fire(interactive)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/main.py", line 70, in interactive
    transformer.load_lora(Path(lora_path))
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/lora.py", line 104, in load_lora
    self._load_lora_state_dict(state_dict, scaling=scaling)
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/lora.py", line 144, in _load_lora_state_dict
    lora_state_dict[name + ".lora_B.weight"]
KeyError: 'layers.0.feed_forward.gate.lora_B.weight'
[2024-07-04 00:29:21,301] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 3264) of binary: mistral-chat
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
mistral-chat FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-07-04_00:29:21
  host      : finetuning-latest-2-0
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3265)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-07-04_00:29:21
  host      : finetuning-latest-2-0
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3264)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Expected Behavior

Expected the interactive chat prompt to appear in the terminal, but got the above error instead.
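
A quick way to confirm which LoRA keys the saved file actually contains is to list the tensor names with the safetensors `safe_open` API (a diagnostic sketch; the path is the checkpoint from the reproduction steps above):

```python
from safetensors import safe_open

path = "/workspace/mistral-finetune/experiment5/checkpoints/checkpoint_000300/consolidated/lora.safetensors"

# Open the file lazily and collect the tensor names it contains.
with safe_open(path, framework="pt", device="cpu") as f:
    keys = list(f.keys())

# The traceback expects keys like `layers.16.feed_forward.gate.lora_B.weight`;
# an empty result here means no LoRA weights were saved for the MoE gate layers.
print([k for k in keys if "gate" in k])
```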

tensimixt added the bug label on Jul 4, 2024
tensimixt (Author) commented:

@patrickvonplaten Hi, do you know whether mistral-inference works for LoRA + Mixtral 8x7B Instruct v0.1? It works for LoRA + Mistral-7B v0.3, but for LoRA + Mixtral 8x7B Instruct v0.1 I get an error about the loaded LoRA weights file missing an expected key for one of the model layers.

Is there something else required to make it work?

Thank you

patrickvonplaten (Collaborator) commented:

Nice catch! We should fix this indeed.


tensimixt commented Jul 10, 2024

> Nice catch! We should fix this indeed.

Thank you! Do you think mistral-finetune is creating bad LoRAs when fine-tuning Mixtral 8x7B Instruct v0.1?
Is there a place in the repo worth checking and updating where you think this issue arises?
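
From the traceback, the lookup that fails is `lora_state_dict[name + ".lora_B.weight"]` in `_load_lora_state_dict` (mistral_inference/lora.py): it assumes every candidate module, including Mixtral's MoE router `feed_forward.gate`, has LoRA weights in the file, while the saved lora.safetensors apparently contains none for the gate. A hypothetical sketch of the kind of guard that would tolerate missing modules (illustrative only, not the library's actual code; `apply_lora` and its arguments are made-up names):

```python
import torch

def apply_lora(
    model_state: dict[str, torch.Tensor],
    lora_state: dict[str, torch.Tensor],
    scaling: float,
) -> dict[str, torch.Tensor]:
    """Merge LoRA deltas into base weights, skipping modules (such as the
    MoE router `feed_forward.gate`) that have no saved LoRA weights."""
    for name, weight in model_state.items():
        base = name.removesuffix(".weight")
        lora_a = lora_state.get(base + ".lora_A.weight")
        lora_b = lora_state.get(base + ".lora_B.weight")
        if lora_a is None or lora_b is None:
            continue  # no LoRA trained for this module; keep the base weight
        # Standard LoRA merge: W' = W + scaling * (B @ A)
        model_state[name] = weight + scaling * (lora_b @ lora_a)
    return model_state
```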

patrickvonplaten (Collaborator) commented:

Sorry to reply only now. We'll make an update to mistral-inference to ensure that a LoRA trained with the 8x7B works correctly with mistral-chat. Sorry to have kept you waiting so long.

patrickvonplaten (Collaborator) commented:

Can you check whether you still encounter the problem after running `pip install "mistral_inference>=1.2.0"`?
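
A quick standard-library way to confirm which version ended up installed after the upgrade:

```python
from importlib.metadata import version

# Should print 1.2.0 or later after the upgrade.
print(version("mistral_inference"))
```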
