
Commit 55db70c

SunMarc, sgugger, and younesbelkada authored
GPTQ integration (#25062)
* GTPQ integration
* Add tests for gptq
* support for more quantization model
* fix style
* typo
* fix method
* Update src/transformers/modeling_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* add dataclass and fix quantization_method
* fix doc
* Update tests/quantization/gptq/test_gptq.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* modify dataclass
* add gtpqconfig import
* fix typo
* fix tests
* remove dataset as req arg
* remove tokenizer import
* add offload cpu quantization test
* fix check dataset
* modify dockerfile
* protect trainer
* style
* test for config
* add more log
* overwrite torch_dtype
* draft doc
* modify quantization_config docstring
* fix class name in docstring
* Apply suggestions from code review
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* more warning
* fix 8bit kwargs tests
* peft compatibility
* remove var
* fix is_gptq_quantized
* remove is_gptq_quantized
* fix wrap
* Update src/transformers/modeling_utils.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* add exllama
* skip test
* overwrite float16
* style
* fix skip test
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* fix docsting formatting
* add doc
* better test

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
1 parent 3470012 commit 55db70c

File tree

16 files changed: +750 / -103 lines

docker/transformers-all-latest-gpu/Dockerfile

Lines changed: 5 additions & 2 deletions
@@ -47,8 +47,11 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/acc
 # Add bitsandbytes for mixed int8 testing
 RUN python3 -m pip install --no-cache-dir bitsandbytes

-# For bettertransformer
-RUN python3 -m pip install --no-cache-dir optimum
+# Add auto-gptq for gtpq quantization testing
+RUN python3 -m pip install --no-cache-dir auto-gptq
+
+# For bettertransformer + gptq
+RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum

 # For video model testing
 RUN python3 -m pip install --no-cache-dir decord av==9.2.0

docs/source/en/main_classes/quantization.md

Lines changed: 132 additions & 1 deletion
@@ -16,6 +16,137 @@ rendered properly in your Markdown viewer.

# Quantize 🤗 Transformers models

## `AutoGPTQ` Integration

🤗 Transformers has integrated the `optimum` API to perform GPTQ quantization on language models. You can load and quantize your model in 8, 6, 4, or even 2 bits without a big drop in performance and with faster inference speed! This is supported by most GPU hardware.

To learn more about the quantization method, check out:
- the [GPTQ](https://arxiv.org/pdf/2210.17323.pdf) paper
<!-- - the `optimum` [guide]() on GPTQ quantization -->
- the [`AutoGPTQ`](https://github.com/PanQiWei/AutoGPTQ) library used as the backend

### Requirements

You need to have the following requirements installed to run the code below:

- Install the latest `AutoGPTQ` library
`pip install auto-gptq`

- Install the latest `optimum` from source
`pip install git+https://github.com/huggingface/optimum.git`

- Install the latest `transformers` from source
`pip install git+https://github.com/huggingface/transformers.git`

- Install the latest `accelerate` library
`pip install --upgrade accelerate`

Note that the GPTQ integration only supports text models for now, and you may encounter unexpected behaviour for vision, speech, or multi-modal models.

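A quick, optional way to confirm which versions ended up in your environment (a minimal sketch; the minimum required versions are not stated here):

```python
from importlib.metadata import version

# Print the installed versions of the libraries listed above.
for pkg in ["auto-gptq", "optimum", "transformers", "accelerate"]:
    print(pkg, version(pkg))
```
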
### Load and quantize a model

GPTQ is a quantization method that requires weight calibration before using the quantized models. If you want to quantize a transformers model from scratch, it might take some time before producing the quantized model (~10 min on a Google Colab for the `facebook/opt-350m` model).

Hence, there are two different scenarios where you want to use GPTQ-quantized models. The first use case is to load models that have already been quantized by other users and are available on the Hub; the second use case is to quantize your model from scratch and save it or push it to the Hub so that other users can also use it.

#### GPTQ Configuration

In order to load and quantize a model, you need to create a [`GPTQConfig`]. You need to pass the number of `bits`, a `dataset` in order to calibrate the quantization, and the `tokenizer` of the model in order to prepare the dataset.

```python
from transformers import AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
```

Note that you can pass your own dataset as a list of strings. However, it is highly recommended to use the dataset from the GPTQ paper.
```python
dataset = ["auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer)
```

#### Quantization

You can quantize a model by using `from_pretrained` and setting the `quantization_config`.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config)
```

Note that you will need a GPU to quantize a model. We will put the model on the CPU and move the modules back and forth to the GPU in order to quantize them.

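Once the calibration has finished, the quantized model can be used like any other causal LM. A minimal sketch of a quick generation check (the prompt is arbitrary):

```python
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
# Generate a few tokens just to verify the quantized model runs.
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
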
If you want to maximize your GPU usage while using CPU offload, you can set `device_map="auto"`.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)
```

Note that disk offload is not supported. Furthermore, if you run out of memory because of the dataset, you may have to pass `max_memory` in `from_pretrained`. Check out this [guide](https://huggingface.co/docs/accelerate/usage_guides/big_modeling#designing-a-device-map) to learn more about `device_map` and `max_memory`.

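For example, a minimal sketch of passing an explicit `max_memory` budget (the sizes below are placeholders, not recommendations):

```python
from transformers import AutoModelForCausalLM

# Hypothetical memory limits per device; adapt them to your hardware.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    max_memory={0: "20GiB", "cpu": "60GiB"},
    quantization_config=gptq_config,
)
```
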
<Tip warning={true}>
GPTQ quantization only works for text models for now. Furthermore, the quantization process can take a lot of time depending on your hardware (quantizing a 175B model takes about 4 GPU-hours on an NVIDIA A100). Please check on the Hub whether a GPTQ-quantized version of the model already exists. If not, you can submit a request on GitHub.
</Tip>

### Push quantized model to 🤗 Hub

You can push the quantized model to the Hub like any 🤗 model with `push_to_hub`. The quantization config will be saved and pushed along with the model.

```python
quantized_model.push_to_hub("opt-125m-gptq")
tokenizer.push_to_hub("opt-125m-gptq")
```

If you want to save your quantized model on your local machine, you can also do it with `save_pretrained`:
```python
quantized_model.save_pretrained("opt-125m-gptq")
tokenizer.save_pretrained("opt-125m-gptq")
```

Note that if you have quantized your model with a `device_map`, make sure to move the entire model to one of your GPUs or to the `cpu` before saving it.
```python
quantized_model.to("cpu")
quantized_model.save_pretrained("opt-125m-gptq")
```

### Load a quantized model from the 🤗 Hub

You can load a quantized model from the Hub by using `from_pretrained`.
Make sure that the pushed weights are quantized by checking that the attribute `quantization_config` is present in the model configuration object.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq")
```

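As a quick sanity check, you could inspect the checkpoint's configuration before downloading the weights (a minimal sketch using the same placeholder repo name):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("{your_username}/opt-125m-gptq")
# A GPTQ-quantized checkpoint stores its quantization settings in the config.
print(getattr(config, "quantization_config", None))
```
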
If you want to load a model faster and without allocating more memory than needed, the `device_map` argument also works with quantized models. Make sure that you have the `accelerate` library installed.
```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto")
```

### Exllama kernels for faster inference

For 4-bit models, you can use the exllama kernels to get faster inference speed. They are activated by default. You can change that behavior by passing `disable_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the model's config. Note that you can only overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on GPUs if you want to use the exllama kernels.

```py
from transformers import AutoModelForCausalLM, GPTQConfig

gptq_config = GPTQConfig(bits=4, disable_exllama=False)
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config=gptq_config)
```

Note that only 4-bit models are supported for now. Furthermore, it is recommended to deactivate the exllama kernels if you are fine-tuning a quantized model with peft.

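For instance, a minimal sketch of disabling the kernels before fine-tuning (same placeholder repo name as above):

```python
from transformers import AutoModelForCausalLM, GPTQConfig

# Turn the exllama kernels off when the quantized model will be fine-tuned with adapters.
gptq_config = GPTQConfig(bits=4, disable_exllama=True)
model = AutoModelForCausalLM.from_pretrained(
    "{your_username}/opt-125m-gptq", device_map="auto", quantization_config=gptq_config
)
```
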
#### Fine-tune a quantized model

With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.
Please have a look at the [`peft`](https://github.com/huggingface/peft) library for more details.

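As a rough illustration only, here is a minimal sketch using `peft`'s LoRA API; the target module names are assumptions that depend on the model architecture (OPT uses `q_proj`/`v_proj`):

```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Prepare the quantized model for training (gradient checkpointing, input grads, ...).
model = prepare_model_for_kbit_training(model)

# Attach small trainable LoRA adapters on top of the frozen quantized weights.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```
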
### Example demo

Check out the Google Colab [notebook](https://colab.research.google.com/drive/1_TIrmuKOFhuRRiTWN94iLKUFu6ZX4ceb?usp=sharing) to learn how to quantize your model with GPTQ and how to fine-tune the quantized model with peft.

### GPTQConfig

[[autodoc]] GPTQConfig

## `bitsandbytes` Integration

🤗 Transformers is closely integrated with most used modules on `bitsandbytes`. You can load your model in 8-bit precision with few lines of code.

@@ -215,7 +346,7 @@ This section is intended to advanced users, that want to explore what it is poss

 One of the advanced use case of this is being able to load a model and dispatch the weights between `CPU` and `GPU`. Note that the weights that will be dispatched on CPU **will not** be converted in 8-bit, thus kept in `float32`. This feature is intended for users that want to fit a very large model and dispatch the model between GPU and CPU.

-First, load a `BitsAndBytesConfig` from `transformers` and set the attribute `llm_int8_enable_fp32_cpu_offload` to `True`:
+First, load a [`BitsAndBytesConfig`] from `transformers` and set the attribute `llm_int8_enable_fp32_cpu_offload` to `True`:

 ```python
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

src/transformers/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -731,7 +731,7 @@
         "logging",
     ],
     "utils.bitsandbytes": [],
-    "utils.quantization_config": ["BitsAndBytesConfig"],
+    "utils.quantization_config": ["BitsAndBytesConfig", "GPTQConfig"],
 }

 # sentencepiece-backed objects
@@ -4703,7 +4703,7 @@
     )

     # bitsandbytes config
-    from .utils.quantization_config import BitsAndBytesConfig
+    from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig

     try:
         if not is_sentencepiece_available():
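With this change, `GPTQConfig` is exposed at the top level of `transformers`, alongside `BitsAndBytesConfig`. A minimal sketch of the new public import path:

```python
# GPTQConfig is now importable directly from transformers.
from transformers import BitsAndBytesConfig, GPTQConfig

gptq_config = GPTQConfig(bits=4)
print(gptq_config.bits)
```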
