-
Notifications
You must be signed in to change notification settings - Fork 606
Add GPTQ Quantization #1216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add GPTQ Quantization #1216
Changes from all commits
Commits
Show all changes
57 commits
Select commit
Hold shift + click to select a range
a1a33ef
v1 test draft
younesbelkada 6ca0f12
code runs but outputs gibberish.
younesbelkada 0fe030b
draft v1.1
SunMarc 71129bb
remove duplicate
SunMarc ac7023c
remove dep to transformers and cleaning
SunMarc 4abb9b8
Add serialization and loading
SunMarc 7150a97
Clean code and doc
SunMarc 2472804
add flexibility
SunMarc 88dbe0e
remove triton
SunMarc 90ec342
remove some dep with transformers
SunMarc c7e49a0
add testing
SunMarc 110c8c1
make style
SunMarc f64632f
add accelerate flag
SunMarc ed9b743
handle device placement
SunMarc f65a979
make style
SunMarc 7720b36
Apply suggestions
SunMarc 437329a
add doc in data.py
SunMarc cfe6239
apply suggestion for utils file
SunMarc 3254d6e
remove multiple output
SunMarc 939e4ab
fix Optional
SunMarc e39f5b7
Apply suggestions from code review
SunMarc f8a25e2
remove useless check
SunMarc 9afdbb4
fix doc and style
SunMarc e404bde
fix name
SunMarc 89d18d6
replace catcher by prefoward hook
SunMarc 7ac898a
update doctstring for true_sequential
SunMarc e34d960
apply suggestion
SunMarc d18226a
Fix import
SunMarc 754cd01
Add docstring for tests
SunMarc 6d10f73
move args
SunMarc bba3516
fix typo
SunMarc e662240
fix cpu offload and tokenizer
SunMarc 58e3e7b
fix typo
SunMarc 3633d43
fix offload cpu
SunMarc 1df19a1
modify attribute
SunMarc 28f4ce4
more explicit error
SunMarc a019885
dataset optional
SunMarc d272099
add tqdm bar instead
SunMarc 28acd3c
style
SunMarc ae77ffa
add doc
SunMarc c745309
replace by tqdm.auto
SunMarc 98591ab
Merge remote-tracking branch 'upstream/main' into add-gptq-marc
SunMarc 088f56f
change model
SunMarc 4b019ea
add CI
SunMarc 49362ac
Apply suggestions from code review
SunMarc 9de8918
Update .github/workflows/test_gptq.yml
SunMarc ba9b2c9
add peft compatibility
SunMarc e255ca9
Apply suggestions from code review doc
SunMarc b01bbfd
merge examples
SunMarc 62ac8bb
code review
SunMarc b0007fc
fix test
SunMarc 19dff00
make style
SunMarc 15727f7
change var
SunMarc c506947
fix doc
SunMarc 744c249
add exllama
SunMarc 66d7104
change naming
SunMarc b43d6e0
more doc
SunMarc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| name: GPTQ Quantization / Test GPU | ||
|
|
||
| on: | ||
| workflow_dispatch: | ||
| schedule: | ||
| - cron: 0 1 */3 * * # at 1am every 3 days | ||
| pull_request: | ||
| types: [opened, synchronize, reopened, labeled] | ||
| # uncomment to enable on PR merge on main branch: | ||
| #push: | ||
| # branches: | ||
| # - main | ||
|
|
||
| jobs: | ||
| start-runner: | ||
| if: ${{ (github.event_name == 'workflow_dispatch') || (github.event_name == 'schedule') || contains( github.event.pull_request.labels.*.name, 'gpu-test') }} | ||
| name: Start self-hosted EC2 runner | ||
| runs-on: ubuntu-latest | ||
| env: | ||
| AWS_REGION: us-east-1 | ||
| EC2_AMI_ID: ami-0dc1c26161f869ed1 | ||
| EC2_INSTANCE_TYPE: g4dn.xlarge | ||
| EC2_SUBNET_ID: subnet-859322b4,subnet-b7533b96,subnet-47cfad21,subnet-a396b2ad,subnet-06576a4b,subnet-df0f6180 | ||
| EC2_SECURITY_GROUP: sg-0bb210cd3ec725a13 | ||
| EC2_IAM_ROLE: optimum-ec2-github-actions-role | ||
| outputs: | ||
| label: ${{ steps.start-ec2-runner.outputs.label }} | ||
| ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} | ||
| steps: | ||
| - name: Configure AWS credentials | ||
| uses: aws-actions/configure-aws-credentials@v1 | ||
| with: | ||
| aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} | ||
| aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} | ||
| aws-region: ${{ env.AWS_REGION }} | ||
| - name: Start EC2 runner | ||
| id: start-ec2-runner | ||
| uses: philschmid/philschmid-ec2-github-runner@main | ||
| with: | ||
| mode: start | ||
| github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} | ||
| ec2-image-id: ${{ env.EC2_AMI_ID }} | ||
| ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }} | ||
| subnet-id: ${{ env.EC2_SUBNET_ID }} | ||
| security-group-id: ${{ env.EC2_SECURITY_GROUP }} | ||
| iam-role-name: ${{ env.EC2_IAM_ROLE }} | ||
| aws-resource-tags: > # optional, requires additional permissions | ||
| [ | ||
| {"Key": "Name", "Value": "ec2-optimum-github-runner"}, | ||
| {"Key": "GitHubRepository", "Value": "${{ github.repository }}"} | ||
| ] | ||
| do-the-job: | ||
| name: Setup | ||
| needs: start-runner # required to start the main job when the runner is ready | ||
| runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner | ||
| env: | ||
| AWS_REGION: us-east-1 | ||
| steps: | ||
| - name: Checkout | ||
| uses: actions/checkout@v2 | ||
| - name: Build image | ||
| run: | | ||
| docker build -f tests/gptq/docker/Dockerfile_quantization_gpu -t gptq-gpu . | ||
| - name: Test with unittest within docker container | ||
| run: | | ||
| docker run --rm --gpus all -v $(pwd)/hf_cache:/root/.cache/huggingface --workdir=/workspace/optimum/tests gptq-gpu:latest | ||
|
|
||
| stop-runner: | ||
| name: Stop self-hosted EC2 runner | ||
| needs: | ||
| - start-runner # required to get output from the start-runner job | ||
| - do-the-job # required to wait when the main job is done | ||
| runs-on: ubuntu-latest | ||
| env: | ||
| AWS_REGION: us-east-1 | ||
| if: ${{ always() && !(needs.start-runner.result == 'skipped' && needs.do-the-job.result == 'skipped') }} # required to stop the runner even if the error happened in the previous jobs are all skipped | ||
| steps: | ||
| - name: Configure AWS credentials | ||
| uses: aws-actions/configure-aws-credentials@v1 | ||
| with: | ||
| aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} | ||
| aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} | ||
| aws-region: ${{ env.AWS_REGION }} | ||
| - name: Stop EC2 runner | ||
| uses: philschmid/philschmid-ec2-github-runner@main | ||
| with: | ||
| mode: stop | ||
| github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} | ||
| label: ${{ needs.start-runner.outputs.label }} | ||
| ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
docs/source/llm_quantization/usage_guides/quantization.mdx
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # Quantization | ||
|
|
||
| ## AutoGPTQ Integration | ||
|
|
||
| 🤗 Optimum collaborated with [AutoGPTQ library](https://github.com/PanQiWei/AutoGPTQ) to provide a simple API that apply GPTQ quantization on language models. With GPTQ quantization, you can quantize your favorite language model to 8, 6, 4 or even 2 bits. This comes without a big drop of performance and with faster inference speed. This is supported by most GPU hardwares. | ||
|
|
||
| If you want to quantize 🤗 Transformers models with GPTQ, follow this [documentation](https://huggingface.co/docs/transformers/main_classes/quantization). | ||
|
|
||
| To learn more about the quantization technique used in GPTQ, please refer to: | ||
| - the [GPTQ](https://arxiv.org/pdf/2210.17323.pdf) paper | ||
| - the [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) library used as the backend | ||
| Note that the AutoGPTQ library provides more advanced usage (triton backend, fused attention, fused MLP) that are not integrated with Optimum. For now, we leverage only the CUDA kernel for GPTQ. | ||
|
|
||
| ### Requirements | ||
|
|
||
| You need to have the following requirements installed to run the code below: | ||
|
|
||
| - AutoGPTQ library: | ||
| `pip install auto-gptq` | ||
|
|
||
| - Optimum library: | ||
| `pip install --upgrade optimum` | ||
|
|
||
| - Install latest `transformers` library from source: | ||
| `pip install --upgrade git+https://github.com/huggingface/transformers.git` | ||
|
|
||
| - Install latest `accelerate` library: | ||
| `pip install --upgrade accelerate` | ||
|
|
||
| ### Load and quantize a model | ||
|
|
||
| The [`~optimum.gptq.GPTQQuantizer`] class is used to quantize your model. In order to quantize your model, you need to provide a few arguemnts: | ||
| - the number of bits: `bits` | ||
| - the dataset used to calibrate the quantization: `dataset` | ||
| - the model sequence length used to process the dataset: `model_seqlen` | ||
| - the block name to quantize: `block_name_to_quantize` | ||
|
|
||
| With 🤗 Transformers integration, you don't need to pass the `block_name_to_quantize` and `model_seqlen` as we can retrieve them. However, for custom model, you need to specify them. Also, make sure that your model is converted to `torch.float16` before quantization. | ||
|
|
||
| ```py | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
| from optimum.gptq import GPTQQuantizer, load_quantized_model | ||
| import torch | ||
| model_name = "facebook/opt-125m" | ||
| tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
| model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) | ||
|
|
||
| quantizer = GPTQQuantizer(bits=4, dataset="c4", block_name_to_quantize = "model.decoder.layers", model_seqlen = 2048) | ||
| quantized_model = quantizer.quantize_model(model, tokenizer) | ||
| ``` | ||
|
|
||
| <Tip warning={true}> | ||
| GPTQ quantization only works for text model for now. Futhermore, the quantization process can take a lot of time depending on one's hardware (175B model = 4 gpu hours using NVIDIA A100). Please check on the Hugging Face Hub if there is not already a GPTQ quantized version of the model you would like to quantize. | ||
| </Tip> | ||
|
|
||
| ### Save the model | ||
|
|
||
| To save your model, use the save method from [`~optimum.gptq.GPTQQuantizer`] class. It will create a folder with your model state dict along with the quantization config. | ||
| ```python | ||
| save_folder = "/path/to/save_folder/" | ||
| quantizer.save(model,save_folder) | ||
| ``` | ||
|
|
||
| ### Load quantized weights | ||
|
|
||
| You can load your quantized weights by using the [`~optimum.gptq.load_quantized_model`] function. | ||
| Through the Accelerate library, it is possible to load a model faster with a lower memory usage. The model needs to be initialized using empty weights, with weights loaded as a next step. | ||
| ```python | ||
| from accelerate import init_empty_weights | ||
| with init_empty_weights(): | ||
| empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) | ||
| empty_model.tie_weights() | ||
| quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto") | ||
| ``` | ||
|
|
||
| ### Exllama kernels for faster inference | ||
|
|
||
| For 4-bit model, you can use the exllama kernels in order to a faster inference speed. It is activated by default. If you want to change its value, you just need to pass `disable_exllama` in [`~optimum.gptq.load_quantized_model`]. In order to use these kernels, you need to have the entire model on gpus. | ||
|
|
||
| ```py | ||
| from optimum.gptq import GPTQQuantizer, load_quantized_model | ||
| import torch | ||
|
|
||
| from accelerate import init_empty_weights | ||
| with init_empty_weights(): | ||
| empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) | ||
| empty_model.tie_weights() | ||
| quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllama=False) | ||
| ``` | ||
|
|
||
| Note that only 4-bit models are supported with exllama kernels for now. Furthermore, it is recommended to disable the exllama kernel when you are finetuning your model with peft. | ||
|
|
||
| #### Fine-tune a quantized model | ||
|
|
||
| With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ. | ||
| Please have a look at [`peft`](https://github.com/huggingface/peft) library for more details. | ||
|
|
||
| ### References | ||
|
|
||
| [[autodoc]] gtpq.GPTQQuantizer | ||
| - all | ||
|
|
||
| [[autodoc]] gtpq.load_quantized_model | ||
| - all |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # coding=utf-8 | ||
| # Copyright 2023 HuggingFace Inc. team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from .quantizer import GPTQQuantizer, load_quantized_model |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # Copyright 2023 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| SEQLEN_KEYS_TRANFORMERS = ["max_position_embeddings", "seq_length", "n_positions"] | ||
| BLOCK_PATTERNS = [ | ||
| "transformer.h", | ||
| "model.decoder.layers", | ||
| "gpt_neox.layers", | ||
| "model.layers", | ||
| ] | ||
|
|
||
| GPTQ_CONFIG = "quantization_config.json" | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.