Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LoRA and QLoRA example #1889

Merged
merged 5 commits into from
Aug 14, 2024
Merged

Conversation

james77777778
Copy link
Contributor

This example borrows a lot of description from https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/
Therefore, I have added the original authors as this example's authors.

Basically, this example demonstrates the fine-tuning of Gemma with LoRA and QLoRA on a French-to-English translation task.

Please let me know if the .py is ready. Then, I will submit the .md and .ipynb.

Note that we need the latest code from KerasNLP since the quantization support hasn't been released yet.

cc @fchollet

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good -- thanks for the PR!

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TF_CPP_MIN_LOG_LEVEL=3 is used to suppress verbose logging from TF

@mattdangerw
Copy link
Member

Still need to read over the guide, but a general note...

We are working on adding pre quantized versions of Gemma 2. When we do I believe that we will be able to show lora fine-tuning on the 9b model on a free tier colab GPU. 9b * 1 byte per weight + a relatively small number of trainable parameters = ~9GB of vRAM. Does that track with your understanding @james77777778 ?

Once we have those up, we should definitely show that too. That's when things get really exiting -- fine-tuning a model that normally wouldn't even fit on an accelerator.

@james77777778
Copy link
Contributor Author

9b * 1 byte per weight + a relatively small number of trainable parameters = ~9GB of vRAM. Does that track with your understanding @james77777778 ?

Yes, the number should be correct for loading the model. However, the vRAM requirement might be (much) larger than 9GB if we want to do backpropagation.

With my 10GB vRAM rig, I can barely run this example (Gemma 2B) using JAX after this patch keras-team/keras#19954.

@fchollet
Copy link
Contributor

Once we have those up, we should definitely show that too.

Should we wait for those additions before merging?

@mattdangerw
Copy link
Member

@fchollet I think we can merge with what we have an extend later. No need to block!

@mattdangerw
Copy link
Member

@james77777778 I though that with "qlora" like flow, the back propagation requirements should actually be about the same as inference. There's definitely some overhead for intermediate activations and gradients (using a short sequence length will help here), but we only need to keep around the gradients for the lora trainable parameters, which will add up to just a few MB. Overall the dominant memory requirements should just be 1 byte per frozen quantized parameter.

That's the principal that allows a bitsandbytes colab like this, training a 20b parameter model on free tier resources...
https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing

For us we only go down to 8 bits a weight (that colab show 4 bits per weight), so we won't be able to tune a 20b parameter model on free hardward. But I think we could tune a pre-quantized model of up to 10b on a T4.

@james77777778
Copy link
Contributor Author

we only need to keep around the gradients for the lora trainable parameters, which will add up to just a few MB. Overall the dominant memory requirements should just be 1 byte per frozen quantized parameter.

I'm unsure how JAX, TF and Torch calculate the gradients, but it is true that we only need a small amount of memory for the qlora-like technique.

However, we might also need a quantized optimizer to achieve this optimal reduction in memory usage.
Please see https://huggingface.co/docs/bitsandbytes/main/en/explanations/optimizers
It involves some low-level ops that might be hard to implement in Keras

For us we only go down to 8 bits a weight (that colab show 4 bits per weight), so we won't be able to tune a 20b parameter model on free hardward. But I think we could tune a pre-quantized model of up to 10b on a T4.

I can try a 7B/8B model on colab T4 once the quantized models are uploaded keras-team/keras-hub#1720

@mattdangerw
Copy link
Member

However, we might also need a quantized optimizer to achieve this optimal reduction in memory usage.
Please see https://huggingface.co/docs/bitsandbytes/main/en/explanations/optimizers
It involves some low-level ops that might be hard to implement in Keras

thanks for the pointer! i'll have a look

@james77777778
Copy link
Contributor Author

@fchollet @mattdangerw
Please let me know if this is ready. If so, I will push .md and .ipynb.

Note that I can only run this example on a CPU using jupyter notebook due to the GPU OOM issue.
This issue doesn't occur when I run the script on the same machine without using the notebook.

@mattdangerw
Copy link
Member

@james77777778 ready I think!

@james77777778
Copy link
Contributor Author

james77777778 commented Aug 6, 2024

@fchollet @mattdangerw

This PR should be ready now.
(Also fixed the link error.)

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for the contribution!

@fchollet fchollet merged commit 4c89b46 into keras-team:master Aug 14, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants