Significantly increased VRAM usage for Mixtral qlora training compared to 4.36.2? #28339
Comments
Hey! Thanks for the report, here are potential PRs that I would suspect:
This may not be relevant to you, but I found this recent change to Axolotl has made a significant difference to VRAM usage. Previously I could just squeeze in a LoRA on a 34B model on my 3x3090s at batch size 2, seq length 4096; now it OOMs immediately. I undid the change and it fits again.
Hmm, it's certainly possible, since that commit landed between my initial train and the run where I had to drop the batch size. Unfortunately I don't have a training instance up right now, so I'd have to test it the next time I try to train.
I've determined that the cause of the increased VRAM usage was indeed axolotl changing the default for use_reentrant to False for gradient checkpointing. Going to go ahead and close the issue.
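For anyone hitting the same regression: recent transformers releases (4.35 and later) let you select the checkpointing variant explicitly through gradient_checkpointing_kwargs. A minimal sketch of forcing the reentrant path back on (the lower-VRAM behavior described above); the output directory and surrounding setup are placeholders, not taken from the actual training setup:

```python
# Minimal sketch (not the axolotl code path): explicitly selecting the
# reentrant gradient-checkpointing variant via transformers' TrainingArguments.
# output_dir is a placeholder.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    # use_reentrant=True restores the older, lower-VRAM behavior discussed above;
    # use_reentrant=False is the default axolotl had switched to.
    gradient_checkpointing_kwargs={"use_reentrant": True},
)

# The same option can also be set directly on a loaded model:
# model.gradient_checkpointing_enable(
#     gradient_checkpointing_kwargs={"use_reentrant": True}
# )
```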
thanks for sharing the solution! 🤗
System Info
The environment is a Runpod container with Python 3.10, a single A100 80GB, and transformers 4.37.0dev (3cefac1), using the axolotl training script (https://github.com/OpenAccess-AI-Collective/axolotl).
Who can help?
No response
Reproduction
Hello, just tried doing a training run on the dev version of transformers (as of 3cefac1) via the common training repository axolotl (https://github.com/OpenAccess-AI-Collective/axolotl) and noticed that I went OOM using the same configuration that I had previously used successfully with transformers 4.36.2 stable. And not even just a small difference - I had to reduce my batch size by 4x to make the training fit in VRAM.
I was previously able to fit 8192 ctx, batch size 4, grad accum steps 2 without difficulty, but I found that I now had to reduce my batch size to 1 to avoid OOM. The relevant training hyperparameters are the ones just noted: 8192 sequence length, batch size 4, grad accum steps 2.
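A rough, purely illustrative sketch of a qlora setup with those values is below; the model id, LoRA rank/alpha, target modules, and optimizer are assumptions made for the sake of the example, not taken from the actual axolotl config:

```python
# Illustrative qlora sketch only, not the actual axolotl config used for this run.
# Values carried over from the description above: 8192 ctx, batch size 4, grad accum 2.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "mistralai/Mixtral-8x7B-v0.1"  # assumed base model

# 4-bit NF4 quantization, the usual qlora setup
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

# LoRA adapter; rank, alpha, and target modules are illustrative guesses
lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=4,   # the batch size that previously fit
    gradient_accumulation_steps=2,   # grad accum steps 2
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},  # the lower-VRAM path
    bf16=True,
    optim="paged_adamw_8bit",        # illustrative optimizer choice
)
# The 8192-token context would be enforced when tokenizing / packing the dataset.
```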
Would appreciate any insights into what caused the massive increase in memory usage. I noticed that ehartford's latest dolphin 2.7 qlora used a batch size of 3 per device at 16k ctx on an A100 80GB, so surely I'm missing something here?
Expected behavior
The training run should use roughly the same amount of VRAM as it did previously with the same config.