HFQuantizer implementation for compressed-tensors library (huggingface#31704)

* Add compressed-tensors HFQuantizer implementation
* flag serializable as False
* run
* revive lines deleted by ruff
* fixes to load+save from sparseml, edit config to quantization_config, and load back
* address satrat comment
* compressed_tensors to compressed-tensors and revert back is_serializable
* rename quant_method from sparseml to compressed-tensors
* tests
* edit tests
* clean up tests
* make style
* cleanup
* cleanup
* add test skip for when compressed tensors is not installed
* remove pydantic import + style
* delay torch import in test
* initial docs
* update main init for compressed tensors config
* make fix-copies
* docstring
* remove fill_docstring
* Apply suggestions from code review
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* review comments
* review comments
* comments - suppress warnings on state dict load, tests, fixes
* bug-fix - remove unnecessary call to apply quant lifecycle
* run_compressed compatability
* revert changes not needed for compression
* no longer need unexpected keys fn
* unexpected keys not needed either
* Apply suggestions from code review
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* add to_diff_dict
* update docs and expand testing
* Update _toctree.yml with compressed-tensors
* Update src/transformers/utils/quantization_config.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* update doc
* add note about saving a loaded model

---------

Co-authored-by: George Ohashi <george@neuralmagic.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Sara Adkins <sara@neuralmagic.com>
Co-authored-by: Sara Adkins <sara.adkins65@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
Co-authored-by: Dipika <dipikasikka1@gmail.com>
1 parent 34a9142 · commit 6aeec65 · Showing 13 changed files with 546 additions and 1 deletion.
@@ -0,0 +1,230 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Compressed Tensors

The [`compressed-tensors`](https://github.com/neuralmagic/compressed-tensors) library provides a versatile and efficient way to store and manage compressed model checkpoints. This library supports various quantization and sparsity schemes, making it a unified format for handling different model optimizations like GPTQ, AWQ, SmoothQuant, INT8, FP8, SparseGPT, and more.

Some of the supported formats include:
1. `dense`
2. `int-quantized`: INT8 quantized models
    - sample [model/config](https://huggingface.co/nm-testing/tinyllama-w8a8-compressed-hf-quantizer)
3. `float-quantized`: FP8 quantized models; currently supports E4M3
    - sample [model/config](https://huggingface.co/nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat/tree/main)
4. `pack-quantized`: INT4 or INT8 weight-quantized models, packed into INT32. For INT4, the weights have an INT4 range but are stored as INT8 and then packed into INT32 (see the toy packing sketch below).
    - sample [model/config](https://huggingface.co/nm-testing/tinyllama-w4a16-compressed-hf-quantizer)

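To make the `pack-quantized` idea concrete, the toy sketch below packs eight signed INT4 values into a single 32-bit word and unpacks them again. This only illustrates the bit-level idea; the actual on-disk layout and kernels are defined by compressed-tensors, and the variable names here are made up for the example.

```python
# Toy illustration only: pack eight signed INT4 values (range [-8, 7]) into one 32-bit word.
int4_values = [-8, -3, -1, 0, 1, 2, 5, 7]

packed = 0
for i, v in enumerate(int4_values):
    packed |= (v & 0xF) << (4 * i)  # keep the low 4 bits and shift them into the i-th nibble

# Unpack: extract each nibble and sign-extend it back to a signed INT4 value
unpacked = [((packed >> (4 * i)) & 0xF) for i in range(8)]
unpacked = [v - 16 if v >= 8 else v for v in unpacked]

print(hex(packed))  # 0x75210fd8
print(unpacked)     # [-8, -3, -1, 0, 1, 2, 5, 7]
```
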
Compressed models can be easily created using [llm-compressor](https://github.com/vllm-project/llm-compressor).
Alternatively, models can be created independently and serialized with a compressed-tensors config.

To find existing models on the Hugging Face Model Hub, search for the [`compressed-tensors` tag](https://huggingface.co/models?other=compressed-tensors).

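The same search can be done programmatically. A minimal sketch using the `huggingface_hub` client (assuming it is installed; the `limit` value is arbitrary):

```python
from huggingface_hub import list_models

# List a few Hub models tagged with "compressed-tensors"
for model in list_models(filter="compressed-tensors", limit=5):
    print(model.id)
```
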
#### Features:
- Weight and activation precisions: FP8, INT4, INT8 (for Q/DQ arbitrary precision is allowed for INT)
- Quantization scales and zero-points strategies: [tensor, channel, group, block, token](https://github.com/neuralmagic/compressed-tensors/blob/83b2e7a969d70606421a76b9a3d112646077c8de/src/compressed_tensors/quantization/quant_args.py#L43-L52)
- Dynamic per-token activation quantization (or any static strategy)
- Sparsity can be applied alongside quantization for further compression
- Supports quantization of arbitrary modules, not just Linear modules
- Targeted support or ignoring of modules by name or class (see the sketch after this list)

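As a rough sketch of how these options fit together, the snippet below builds a static per-tensor FP8 scheme that targets `Linear` modules while ignoring `lm_head`, using the `QuantizationArgs`, `QuantizationScheme`, and `QuantizationConfig` classes from compressed-tensors. Field names and accepted values follow the library's quantization args and may differ slightly between versions, so treat this as an illustration rather than a reference.

```python
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
)

# Static per-tensor FP8 quantization for both weights and input activations
fp8_args = QuantizationArgs(num_bits=8, type="float", strategy="tensor")

scheme = QuantizationScheme(
    targets=["Linear"],        # apply to every Linear module...
    weights=fp8_args,
    input_activations=fp8_args,
)

config = QuantizationConfig(
    config_groups={"group_0": scheme},
    ignore=["lm_head"],        # ...except lm_head, which stays unquantized
)
print(config)
```
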
## Installation

It is recommended to install stable releases of compressed-tensors from [PyPI](https://pypi.org/project/compressed-tensors):

```bash
pip install compressed-tensors
```

Developers who want to experiment with the latest features can also install the package from source:

```bash
git clone https://github.com/neuralmagic/compressed-tensors
cd compressed-tensors
pip install -e .
```

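Either way, you can confirm which version ended up installed; a quick check via the package metadata (`compressed-tensors` is the distribution name on PyPI):

```python
from importlib.metadata import version

print(version("compressed-tensors"))
```
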
## Quickstart Model Load

Quantized models can be easily loaded for inference as shown below. Only models that have already been quantized can be loaded at the moment; to quantize a model into the compressed-tensors format, see [llm-compressor](https://github.com/vllm-project/llm-compressor).

```python
from transformers import AutoModelForCausalLM

# Load the model in compressed-tensors format
ct_model = AutoModelForCausalLM.from_pretrained("nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf")

# Measure memory usage
mem_params = sum([param.nelement() * param.element_size() for param in ct_model.parameters()])
print(f"{mem_params / 2**30:.4f} GB")
# 8.4575 GB
```

As shown above, the compressed-tensors FP8 checkpoint of Llama 3.1 8B can be loaded for inference using roughly half the memory of the unquantized reference checkpoint.

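A rough back-of-the-envelope check of that number. The figures below are assumptions for illustration: roughly 8 billion total parameters, with the quantized Linear weights stored in 1-byte FP8 and the embeddings and `lm_head` kept in 2-byte BF16 (norms and scales are small enough to ignore here).

```python
# Approximate parameter counts for Llama 3.1 8B (assumed values, for illustration only)
total_params = 8_030_261_248
bf16_params = 2 * 128_256 * 4_096        # embed_tokens + lm_head stay in BF16 (2 bytes each)
fp8_params = total_params - bf16_params  # quantized Linear weights stored in FP8 (1 byte each)

approx_bytes = fp8_params * 1 + bf16_params * 2
print(f"{approx_bytes / 2**30:.4f} GiB")  # ~8.46 GiB, close to the measured 8.4575 GB
```
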
## Sample Use Cases - Load and run an FP8 model

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

prompt = [
    "Hello, my name is",
    "The capital of France is",
    "The future of AI is"
]

model_name = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat"

quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer(prompt, return_tensors="pt")
generated_ids = quantized_model.generate(**inputs, max_length=50, do_sample=False)
outputs = tokenizer.batch_decode(generated_ids)

print(outputs)

"""
['<|begin_of_text|>Hello, my name is [Name]. I am a [Your Profession/Student] and I am here to learn about the [Course/Program] at [University/Institution]. I am excited to be here and I am looking forward to', '<|begin_of_text|>The capital of France is Paris, which is located in the north-central part of the country. Paris is the most populous city in France and is known for its stunning architecture, art museums, fashion, and romantic atmosphere. The city is home to', "<|begin_of_text|>The future of AI is here, and it's already changing the way we live and work. From virtual assistants to self-driving cars, AI is transforming industries and revolutionizing the way we interact with technology. But what does the future of AI hold"]
"""
```

The above shows a quick example of running generation with a `compressed-tensors` model. Currently, the model cannot be saved again once it has been loaded.

## Deep dive into a compressed-tensors model checkpoint

In this example we will examine how the compressed-tensors model `nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf` is defined through its configuration entry, and see how this translates to the loaded model representation.

First, let us look at the [`quantization_config` of the model](https://huggingface.co/nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf/blob/main/config.json). At a glance the number of entries looks overwhelming, but that is because compressed-tensors is a format that allows for flexible expression both during and after model compression.

In practice, for checkpoint loading and inference, the configuration can be simplified to exclude the default or empty entries, so we do that here to focus on the compression that is actually represented.

```yaml
"quantization_config": {
  "config_groups": {
    "group_0": {
      "input_activations": {
        "num_bits": 8,
        "strategy": "tensor",
        "type": "float"
      },
      "targets": ["Linear"],
      "weights": {
        "num_bits": 8,
        "strategy": "tensor",
        "type": "float"
      }
    }
  },
  "format": "naive-quantized",
  "ignore": ["lm_head"],
  "quant_method": "compressed-tensors",
  "quantization_status": "frozen"
},
```

We can see from the above configuration that it specifies one config group that quantizes weights and activations to FP8 with a static per-tensor strategy. It is also worth noting that the `ignore` list contains an entry to skip quantization of the `lm_head` module, so that module should be untouched in the checkpoint.

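To reproduce this view yourself, here is a small sketch that downloads only the `config.json` (via `hf_hub_download` from `huggingface_hub`) and prints the entries discussed above:

```python
import json

from huggingface_hub import hf_hub_download

config_path = hf_hub_download(
    repo_id="nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf",
    filename="config.json",
)
with open(config_path) as f:
    quant_config = json.load(f)["quantization_config"]

# Print only the entries shown in the simplified configuration above
print(quant_config["quant_method"])  # compressed-tensors
print(quant_config["format"])        # naive-quantized
print(quant_config["ignore"])        # ["lm_head"]
print(json.dumps(quant_config["config_groups"], indent=2))
```
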
To see the result of the configuration in practice, we can simply use the [safetensors viewer](https://huggingface.co/nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf?show_file_info=model.safetensors.index.json) on the model card to see the quantized weights, input_scale, and weight_scale for all of the Linear modules in the first model layer (and so on for the rest of the layers).

| Tensors | Shape | Precision |
| ------- | ----- | --------- |
| model.layers.0.input_layernorm.weight | [4 096] | BF16 |
| model.layers.0.mlp.down_proj.input_scale | [1] | BF16 |
| model.layers.0.mlp.down_proj.weight | [4 096, 14 336] | F8_E4M3 |
| model.layers.0.mlp.down_proj.weight_scale | [1] | BF16 |
| model.layers.0.mlp.gate_proj.input_scale | [1] | BF16 |
| model.layers.0.mlp.gate_proj.weight | [14 336, 4 096] | F8_E4M3 |
| model.layers.0.mlp.gate_proj.weight_scale | [1] | BF16 |
| model.layers.0.mlp.up_proj.input_scale | [1] | BF16 |
| model.layers.0.mlp.up_proj.weight | [14 336, 4 096] | F8_E4M3 |
| model.layers.0.mlp.up_proj.weight_scale | [1] | BF16 |
| model.layers.0.post_attention_layernorm.weight | [4 096] | BF16 |
| model.layers.0.self_attn.k_proj.input_scale | [1] | BF16 |
| model.layers.0.self_attn.k_proj.weight | [1 024, 4 096] | F8_E4M3 |
| model.layers.0.self_attn.k_proj.weight_scale | [1] | BF16 |
| model.layers.0.self_attn.o_proj.input_scale | [1] | BF16 |
| model.layers.0.self_attn.o_proj.weight | [4 096, 4 096] | F8_E4M3 |
| model.layers.0.self_attn.o_proj.weight_scale | [1] | BF16 |
| model.layers.0.self_attn.q_proj.input_scale | [1] | BF16 |
| model.layers.0.self_attn.q_proj.weight | [4 096, 4 096] | F8_E4M3 |
| model.layers.0.self_attn.q_proj.weight_scale | [1] | BF16 |
| model.layers.0.self_attn.v_proj.input_scale | [1] | BF16 |
| model.layers.0.self_attn.v_proj.weight | [1 024, 4 096] | F8_E4M3 |
| model.layers.0.self_attn.v_proj.weight_scale | [1] | BF16 |

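The same information can be pulled programmatically. A sketch using `get_safetensors_metadata` from `huggingface_hub` (available in recent versions of the library; it reads only the safetensors headers rather than the full weights, so the exact attribute names here depend on your huggingface_hub version):

```python
from huggingface_hub import get_safetensors_metadata

metadata = get_safetensors_metadata("nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf")

# Print dtype and shape for every tensor in the first decoder layer
for file_metadata in metadata.files_metadata.values():
    for name, info in sorted(file_metadata.tensors.items()):
        if name.startswith("model.layers.0."):
            print(f"{name}: {info.shape} {info.dtype}")
```
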
When we load the model with the compressed-tensors HFQuantizer integration, we can see that all of the Linear modules specified in the quantization configuration have been replaced by `CompressedLinear` modules that manage the compressed weights and the forward pass for inference. Note that the `lm_head` listed in `ignore` is kept as an unquantized Linear module, as expected.

```python
from transformers import AutoModelForCausalLM

ct_model = AutoModelForCausalLM.from_pretrained("nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf")
print(ct_model)
"""
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): CompressedLinear(
            in_features=4096, out_features=4096, bias=False
            (input_observer): MovingAverageMinMaxObserver()
            (weight_observer): MovingAverageMinMaxObserver()
          )
          (k_proj): CompressedLinear(
            in_features=4096, out_features=1024, bias=False
            (input_observer): MovingAverageMinMaxObserver()
            (weight_observer): MovingAverageMinMaxObserver()
          )
          (v_proj): CompressedLinear(
            in_features=4096, out_features=1024, bias=False
            (input_observer): MovingAverageMinMaxObserver()
            (weight_observer): MovingAverageMinMaxObserver()
          )
          (o_proj): CompressedLinear(
            in_features=4096, out_features=4096, bias=False
            (input_observer): MovingAverageMinMaxObserver()
            (weight_observer): MovingAverageMinMaxObserver()
          )
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): CompressedLinear(
            in_features=4096, out_features=14336, bias=False
            (input_observer): MovingAverageMinMaxObserver()
            (weight_observer): MovingAverageMinMaxObserver()
          )
          (up_proj): CompressedLinear(
            in_features=4096, out_features=14336, bias=False
            (input_observer): MovingAverageMinMaxObserver()
            (weight_observer): MovingAverageMinMaxObserver()
          )
          (down_proj): CompressedLinear(
            in_features=14336, out_features=4096, bias=False
            (input_observer): MovingAverageMinMaxObserver()
            (weight_observer): MovingAverageMinMaxObserver()
          )
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
"""
```
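
As a quick programmatic check of the same structure, here is a sketch that counts the replaced modules and confirms `lm_head` stays a plain Linear; it identifies modules by class name to avoid depending on compressed-tensors import paths:

```python
import torch.nn as nn
from transformers import AutoModelForCausalLM

ct_model = AutoModelForCausalLM.from_pretrained("nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf")

compressed = [name for name, module in ct_model.named_modules()
              if type(module).__name__ == "CompressedLinear"]
print(len(compressed))                      # 7 projections x 32 layers = 224
print(type(ct_model.lm_head) is nn.Linear)  # True: lm_head was ignored and stays unquantized
```
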
src/transformers/quantizers/quantizer_compressed_tensors.py
77 changes: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_compressed_tensors_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin
from .base import HfQuantizer


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class CompressedTensorsHfQuantizer(HfQuantizer):
    """
    Quantizer for the compressed_tensors package. Loads and restores models to
    quantized state with compressed_tensors
    """

    requires_calibration = True
    required_packages = ["compressed_tensors"]

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)

        from compressed_tensors.compressors import ModelCompressor

        self.compressor = ModelCompressor.from_compression_config(quantization_config)

    def validate_environment(self, *args, **kwargs):
        if not is_compressed_tensors_available():
            raise ImportError(
                "Using `compressed_tensors` quantized models requires the compressed-tensors library: "
                "`pip install compressed-tensors`"
            )
        if not is_torch_available():
            # torch already should be installed as part of compressed tensors
            raise ImportError("torch is required for using compressed-tensors quantization")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            logger.info("Loading model using torch.float16 for compressed-tensors quantization")
            torch_dtype = torch.float16
        elif torch_dtype != torch.float16:
            logger.info(
                "We suggest you to set `torch_dtype=torch.float16` for better efficiency with compressed_tensors."
            )
        return torch_dtype

    def _process_model_before_weight_loading(self, model, **kwargs):
        from compressed_tensors.quantization import apply_quantization_config

        ct_quantization_config = self.compressor.quantization_config
        apply_quantization_config(model, ct_quantization_config, run_compressed=True)

    def _process_model_after_weight_loading(self, model, **kwargs):
        pass

    @property
    def is_trainable(self):
        return False

    @property
    def is_serializable(self):
        return False
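
For context on when this quantizer is used: transformers selects it based on the `quant_method` recorded in the checkpoint's configuration. A small sketch of how to inspect that field before loading; the printed value is what routes loading to `CompressedTensorsHfQuantizer`, and the dict-vs-object handling below is a defensive assumption since `quantization_config` may be stored either way depending on version:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf")
quant_config = config.quantization_config  # typically a plain dict loaded from config.json

quant_method = quant_config["quant_method"] if isinstance(quant_config, dict) else quant_config.quant_method
print(quant_method)  # "compressed-tensors"
```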