Add LoRA and QLoRA example #1889
Conversation
Force-pushed from d8c76d8 to 38be789.
Looking good -- thanks for the PR!
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
Why is this necessary?
`TF_CPP_MIN_LOG_LEVEL=3` is used to suppress verbose logging from TF (see the sketch below).
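For reference, here is how the levels behave (a quick annotated sketch; note the variable must be set before TensorFlow is imported to take effect):

```python
import os

# TF_CPP_MIN_LOG_LEVEL controls the verbosity of TensorFlow's C++ logger.
#   "0" - show all messages (default)
#   "1" - filter out INFO messages
#   "2" - also filter out WARNING messages
#   "3" - also filter out ERROR messages (quietest)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf  # import only after the env var is set
```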
(Outdated review thread on examples/keras_recipes/parameter_efficient_finetuning_of_gemma_with_lora_and_qlora.py, resolved.)
Still need to read over the guide, but a general note... We are working on adding pre-quantized versions of Gemma 2. When we do, I believe we will be able to show LoRA fine-tuning of the 9B model on a free-tier Colab GPU. 9B weights * 1 byte per weight + a relatively small number of trainable parameters = ~9GB of vRAM. Does that track with your understanding @james77777778? Once we have those up, we should definitely show that too. That's when things get really exciting: fine-tuning a model that normally wouldn't even fit on an accelerator.
Yes, that number should be correct for loading the model. However, the vRAM requirement might be (much) larger than 9GB if we want to do backpropagation. With my 10GB vRAM rig, I can barely run this example (Gemma 2B) using JAX after this patch: keras-team/keras#19954.
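A rough back-of-the-envelope version of that estimate (my arithmetic, with an assumed adapter size; not measured numbers):

```python
# Ballpark vRAM for LoRA fine-tuning an int8-quantized 9B model.
frozen_params = 9e9            # ~9B frozen backbone weights
lora_params = 20e6             # assumed LoRA adapter size; depends on rank

base = frozen_params * 1             # int8: 1 byte per weight -> ~9.0 GB
adapters = lora_params * 4           # fp32 LoRA weights       -> ~0.08 GB
grads = lora_params * 4              # gradients for LoRA params only
adam_states = lora_params * 4 * 2    # Adam's m and v slots

total_gb = (base + adapters + grads + adam_states) / 1e9
print(f"~{total_gb:.1f} GB, plus intermediate activations")  # ~9.2 GB
```

The "plus intermediate activations" term is where the backpropagation overhead shows up, which is why the practical requirement can exceed the 9GB load-time figure.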
Should we wait for those additions before merging?
@fchollet I think we can merge with what we have and extend later. No need to block!
@james77777778 I thought that with a QLoRA-like flow, the backpropagation requirements should actually be about the same as inference. There's definitely some overhead for intermediate activations and gradients (using a short sequence length will help here), but we only need to keep around the gradients for the LoRA trainable parameters, which will add up to just a few MB. Overall, the dominant memory requirement should just be 1 byte per frozen quantized parameter. That's the principle that allows a bitsandbytes Colab like this, training a 20B-parameter model on free-tier resources... For us, we only go down to 8 bits per weight (that Colab shows 4 bits per weight), so we won't be able to tune a 20B-parameter model on free hardware. But I think we could tune a pre-quantized model of up to 10B on a T4.
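For concreteness, the shape of that flow in Keras looks roughly like this (a sketch; the preset name and rank are illustrative, built on the `quantize`/`enable_lora` APIs this PR's example uses):

```python
import keras_nlp

# Load Gemma, quantize the frozen backbone to int8, then attach small
# trainable LoRA adapters. Gradients exist only for the adapters.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
gemma_lm.quantize("int8")              # ~1 byte per frozen weight
gemma_lm.backbone.enable_lora(rank=4)  # rank chosen arbitrarily here

gemma_lm.summary()  # trainable params should be a tiny fraction of the total
```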
I'm unsure how JAX, TF, and Torch calculate the gradients, but it is true that we only need a small amount of memory for the QLoRA-like technique. However, we might also need a quantized optimizer to achieve this optimal reduction in memory usage.

I can try a 7B/8B model on a Colab T4 once the quantized models are uploaded: keras-team/keras-hub#1720
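To put a number on the optimizer point (a rough sketch; parameter counts are assumptions): Adam keeps two fp32 state slots per trainable parameter, so its footprint depends entirely on what is trainable.

```python
# Adam holds two fp32 slots (m and v) per trainable parameter.
def adam_state_gb(trainable_params, slots=2, bytes_per_slot=4):
    return trainable_params * slots * bytes_per_slot / 1e9

print(adam_state_gb(2.5e9))  # full fine-tuning of a ~2.5B model: ~20 GB
print(adam_state_gb(20e6))   # LoRA adapters only (assumed ~20M): ~0.16 GB
```

With LoRA the optimizer state is already small in absolute terms, though a quantized optimizer would shrink it further.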
Thanks for the pointer! I'll have a look.
@fchollet @mattdangerw Note that I can only run this example on a CPU using a Jupyter notebook, due to the GPU OOM issue.
@james77777778 ready I think!
Force-pushed from 75609d0 to 958c38e.
This PR should be ready now.
LGTM, thank you for the contribution!
This example borrows a lot of description from https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/
Therefore, I have added the original authors as this example's authors.
This example demonstrates fine-tuning Gemma with LoRA and QLoRA on a French-to-English translation task; a condensed sketch of the flow follows below.
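Roughly, the flow the example follows (a condensed sketch; the preset name, prompt format, and hyperparameters here are illustrative, and the data pipeline is simplified to a single hard-coded pair):

```python
import keras
import keras_nlp

# LoRA variant; for QLoRA, call gemma_lm.quantize("int8") before enable_lora.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
gemma_lm.backbone.enable_lora(rank=4)

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(learning_rate=5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# A French-to-English pair formatted as an instruction-style prompt.
train_data = [
    "Translate French into English:\nBonjour\nResponse:\nHello",
]
gemma_lm.fit(train_data, epochs=1, batch_size=1)
```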
Please let me know if the `.py` is ready. Then, I will submit the `.md` and `.ipynb`.

Note that we need the latest code from KerasNLP since the quantization support hasn't been released yet.
cc @fchollet