Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HFQuantizer implementation for compressed-tensors library #31704

Merged
merged 42 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
d695ec3
Add compressed-tensors HFQuantizer implementation
Jun 5, 2024
f468964
flag serializable as False
Jun 5, 2024
41224d3
run
horheynm Jun 10, 2024
b61bfb9
revive lines deleted by ruff
horheynm Jun 10, 2024
ff8f1c5
fixes to load+save from sparseml, edit config to quantization_config,…
horheynm Jun 11, 2024
c1cb55d
address satrat comment
horheynm Jun 11, 2024
ef9d3f1
compressed_tensors to compressed-tensors and revert back is_serializable
horheynm Jun 12, 2024
117d050
rename quant_method from sparseml to compressed-tensors
horheynm Jun 12, 2024
1901c3e
tests
horheynm Jun 12, 2024
3ca270d
edit tests
horheynm Jun 13, 2024
9a14b09
clean up tests
Jun 28, 2024
ec59052
make style
Jun 28, 2024
520ded8
cleanup
Jun 28, 2024
7dec8fc
cleanup
Jun 28, 2024
afb550d
Merge branch 'main' into compressed-tensors-quantizer
bfineran Jul 25, 2024
d9b3660
add test skip for when compressed tensors is not installed
Jul 25, 2024
e51ac59
remove pydantic import + style
Jul 25, 2024
ccb5442
delay torch import in test
Jul 25, 2024
bfd9220
initial docs
Jul 30, 2024
71a80f9
update main init for compressed tensors config
Jul 30, 2024
547f9cc
make fix-copies
Jul 30, 2024
8acbc09
docstring
Jul 31, 2024
eaa5f20
remove fill_docstring
Jul 31, 2024
4ba75fb
Apply suggestions from code review
bfineran Aug 6, 2024
94ea0d3
review comments
Aug 6, 2024
c48840d
review comments
Aug 6, 2024
ab74d26
Merge branch 'main' into compressed-tensors-quantizer
bfineran Aug 19, 2024
2ecf711
comments - suppress warnings on state dict load, tests, fixes
Aug 20, 2024
e1ae504
bug-fix - remove unnecessary call to apply quant lifecycle
Aug 22, 2024
ea9e927
run_compressed compatability
Aug 30, 2024
1c3ad5c
revert changes not needed for compression
Sep 3, 2024
aa1a4f9
no longer need unexpected keys fn
Sep 3, 2024
81a13dd
unexpected keys not needed either
Sep 3, 2024
f53d7b9
Apply suggestions from code review
Satrat Sep 9, 2024
d8f7073
add to_diff_dict
Sep 9, 2024
c4fbf70
update docs and expand testing
Sep 11, 2024
1992a88
Merge remote-tracking branch 'upstream/main' into compressed-tensors-…
Sep 17, 2024
298a638
Update _toctree.yml with compressed-tensors
Satrat Sep 18, 2024
3cb4415
Update src/transformers/utils/quantization_config.py
Satrat Sep 23, 2024
a943157
Merge branch 'main' into compressed-tensors-quantizer
dsikka Sep 24, 2024
64f475a
update doc
dsikka Sep 24, 2024
fabe8a3
add note about saving a loaded model
dsikka Sep 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@
title: Optimum
- local: quantization/torchao
title: TorchAO
- local: quantization/compressed_tensors
title: compressed-tensors
- local: quantization/contribute
title: Contribute new quantization method
title: Quantization Methods
Expand Down
5 changes: 4 additions & 1 deletion docs/source/en/main_classes/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

[[autodoc]] FbgemmFp8Config

## CompressedTensorsConfig

[[autodoc]] CompressedTensorsConfig

## TorchAoConfig

[[autodoc]] TorchAoConfig

230 changes: 230 additions & 0 deletions docs/source/en/quantization/compressed_tensors.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Compressed Tensors

The [`compressed-tensors`](https://github.com/neuralmagic/compressed-tensors) library provides a versatile and efficient way to store and manage compressed model checkpoints. This library supports various quantization and sparsity schemes, making it a unified format for handling different model optimizations like GPTQ, AWQ, SmoothQuant, INT8, FP8, SparseGPT, and more.

Some of the supported formats include:
1. `dense`
2. `int-quantized`: INT8 quantized models
- sample [model/config](https://huggingface.co/nm-testing/tinyllama-w8a8-compressed-hf-quantizer)
3. `float-quantized`: FP8 quantized models; currently support E4M3
- sample [model/config](https://huggingface.co/nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat/tree/main)
4. `pack-quantized`: INT4 or INT8 weight-quantized models, packed into INT32. For INT4, the weights have an INT4 range but are stored as INT8 and then packed into INT32.
- sample [model/config](nm-testing/tinyllama-w4a16-compressed-hf-quantizer)

Compressed models can be easily created using [llm-compressor](https://github.com/vllm-project/llm-compressor).
Alternatively models can be created indepedenty and serialized with a compressed tensors config.

To find existing models on the Hugging Face Model Hub, search for the [`compressed-tensors` tag](https://huggingface.co/models?other=compressed-tensors).

#### Features:
- Weight and activation precisions: FP8, INT4, INT8 (for Q/DQ arbitrary precision is allowed for INT)
- Quantization scales and zero-points strategies: [tensor, channel, group, block, token](https://github.com/neuralmagic/compressed-tensors/blob/83b2e7a969d70606421a76b9a3d112646077c8de/src/compressed_tensors/quantization/quant_args.py#L43-L52)
- Dynamic per-token activation quantization (or any static strategy)
- Sparsity can be
- Supports quantization of arbitrary modules, not just Linear modules
- Targeted support or ignoring of modules by name or class

## Installation

It is recommended to install stable releases of compressed-tensors from [PyPI](https://pypi.org/project/compressed-tensors):
```bash
pip install compressed-tensors
```

Developers who want to experiment with the latest features can also install the package from source:
```bash
git clone https://github.com/neuralmagic/compressed-tensors
cd compressed-tensors
pip install -e .
```

## Quickstart Model Load
Quantized models can be easily loaded for inference as shown below. Only models that have already been quantized can be loaded at the moment. To quantize a model into the compressed-tensors format see [llm-compressor](https://github.com/vllm-project/llm-compressor).

```python
from transformers import AutoModelForCausalLM

# Load the model in compressed-tensors format
ct_model = AutoModelForCausalLM.from_pretrained("nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf")

# Measure memory usage
mem_params = sum([param.nelement()*param.element_size() for param in ct_model.parameters()])
print(f"{mem/2**30:.4f} GB")
# 8.4575 GB
```

We can see just above that the compressed-tensors FP8 checkpoint of Llama 3.1 8B is able to be loaded for inference using half of the memory of the unquantized reference checkpoint.

## Sample Use Cases - Load and run an FP8 model

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

prompt = [
"Hello, my name is",
"The capital of France is",
"The future of AI is"
]

model_name = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat"

quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer(prompt, return_tensors="pt")
generated_ids = quantized_model.generate(**inputs, max_length=50, do_sample=False)
outputs = tokenizer.batch_decode(generated_ids)

print(outputs)

"""
['<|begin_of_text|>Hello, my name is [Name]. I am a [Your Profession/Student] and I am here to learn about the [Course/Program] at [University/Institution]. I am excited to be here and I am looking forward to', '<|begin_of_text|>The capital of France is Paris, which is located in the north-central part of the country. Paris is the most populous city in France and is known for its stunning architecture, art museums, fashion, and romantic atmosphere. The city is home to', "<|begin_of_text|>The future of AI is here, and it's already changing the way we live and work. From virtual assistants to self-driving cars, AI is transforming industries and revolutionizing the way we interact with technology. But what does the future of AI hold"]
"""

```

The above shows a quick example for running generation using a `compressed-tensors`
model. Currently, once loaded the model cannot be saved.

## Deep dive into a compressed-tensors model checkpoint

In this example we will examine how the compressed-tensors model nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf is defined through its configuration entry and see how this translates to the loaded model representation.

First, let us look at the [`quantization_config` of the model](https://huggingface.co/nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf/blob/main/config.json). At a glance it looks overwhelming with the number of entries but this is because compressed-tensors is a format that allows for flexible expression both during and after model compression.

In practice for checkpoint loading and inference the configuration can be simplified to not include all the default or empty entries, so we will do that here to focus on what compression is actually represented.

```yaml
"quantization_config": {
"config_groups": {
"group_0": {
"input_activations": {
"num_bits": 8,
"strategy": "tensor",
"type": "float"
},
"targets": ["Linear"],
"weights": {
"num_bits": 8,
"strategy": "tensor",
"type": "float"
}
}
},
"format": "naive-quantized",
"ignore": ["lm_head"],
"quant_method": "compressed-tensors",
"quantization_status": "frozen"
},
```

We can see from the above configuration that it is specifying one config group that includes weight and activation quantization to FP8 with a static per-tensor strategy. It is also worth noting that in the `ignore` list there is an entry to skip quantization of the `lm_head` module, so that module should be untouched in the checkpoint.

To see the result of the configuration in practice, we can simply use the [safetensors viewer](https://huggingface.co/nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf?show_file_info=model.safetensors.index.json) on the model card to see the quantized weights, input_scale, and weight_scale for all of the Linear modules in the first model layer (and so on for the rest of the layers).

| Tensors | Shape | Precision |
| ------- | ----- | --------- |
model.layers.0.input_layernorm.weight | [4 096] | BF16
model.layers.0.mlp.down_proj.input_scale | [1] | BF16
model.layers.0.mlp.down_proj.weight | [4 096, 14 336] | F8_E4M3
model.layers.0.mlp.down_proj.weight_scale | [1] | BF16
model.layers.0.mlp.gate_proj.input_scale | [1] | BF16
model.layers.0.mlp.gate_proj.weight | [14 336, 4 096] | F8_E4M3
model.layers.0.mlp.gate_proj.weight_scale | [1] | BF16
model.layers.0.mlp.up_proj.input_scale| [1] |BF16
model.layers.0.mlp.up_proj.weight | [14 336, 4 096] | F8_E4M3
model.layers.0.mlp.up_proj.weight_scale | [1] | BF16
model.layers.0.post_attention_layernorm.weight | [4 096] |BF16
model.layers.0.self_attn.k_proj.input_scale | [1] | BF16
model.layers.0.self_attn.k_proj.weight | [1 024, 4 096]| F8_E4M3
model.layers.0.self_attn.k_proj.weight_scale |[1] | BF16
model.layers.0.self_attn.o_proj.input_scale | [1] | BF16
model.layers.0.self_attn.o_proj.weight | [4 096, 4 096] | F8_E4M3
model.layers.0.self_attn.o_proj.weight_scale | [1] | BF16
model.layers.0.self_attn.q_proj.input_scale | [1] | BF16
model.layers.0.self_attn.q_proj.weight | [4 096, 4 096] | F8_E4M3
model.layers.0.self_attn.q_proj.weight_scale | [1] | BF16
model.layers.0.self_attn.v_proj.input_scale | [1] | BF16
model.layers.0.self_attn.v_proj.weight | [1 024, 4 096] | F8_E4M3
model.layers.0.self_attn.v_proj.weight_scale | [1] | BF16

When we load the model with the compressed-tensors HFQuantizer integration, we can see that all of the Linear modules that are specified within the quantization configuration have been replaced by `CompressedLinear` modules that manage the compressed weights and forward pass for inference. Note that the `lm_head` mentioned before in the ignore list is still kept as an unquantized Linear module.

```python
from transformers import AutoModelForCausalLM

ct_model = AutoModelForCausalLM.from_pretrained("nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf")
print(ct_model)
"""
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): CompressedLinear(
in_features=4096, out_features=4096, bias=False
(input_observer): MovingAverageMinMaxObserver()
(weight_observer): MovingAverageMinMaxObserver()
)
(k_proj): CompressedLinear(
in_features=4096, out_features=1024, bias=False
(input_observer): MovingAverageMinMaxObserver()
(weight_observer): MovingAverageMinMaxObserver()
)
(v_proj): CompressedLinear(
in_features=4096, out_features=1024, bias=False
(input_observer): MovingAverageMinMaxObserver()
(weight_observer): MovingAverageMinMaxObserver()
)
(o_proj): CompressedLinear(
in_features=4096, out_features=4096, bias=False
(input_observer): MovingAverageMinMaxObserver()
(weight_observer): MovingAverageMinMaxObserver()
)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): CompressedLinear(
in_features=4096, out_features=14336, bias=False
(input_observer): MovingAverageMinMaxObserver()
(weight_observer): MovingAverageMinMaxObserver()
)
(up_proj): CompressedLinear(
in_features=4096, out_features=14336, bias=False
(input_observer): MovingAverageMinMaxObserver()
(weight_observer): MovingAverageMinMaxObserver()
)
(down_proj): CompressedLinear(
in_features=14336, out_features=4096, bias=False
(input_observer): MovingAverageMinMaxObserver()
(weight_observer): MovingAverageMinMaxObserver()
)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((4096,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
"""
```
1 change: 1 addition & 0 deletions docs/source/en/quantization/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Use the table below to help you decide which quantization method to use.
| [AQLM](./aqlm) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 1 / 2 | 🟢 | 🟢 | 🟢 | https://github.com/Vahe1994/AQLM |
| [AWQ](./awq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | ? | 4 | 🟢 | 🟢 | 🟢 | https://github.com/casper-hansen/AutoAWQ |
| [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 * | 🟢 | 🟡 * | 🔴 ** | 🔴 (soon!) | 4 / 8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes |
| [compressed-tensors](./compressed_tensors) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 1 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors |
| [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
| GGUF / GGML (llama.cpp) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 1 - 8 | 🔴 | [See GGUF section](../gguf) | [See GGUF section](../gguf) | https://github.com/ggerganov/llama.cpp |
| [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@
"AqlmConfig",
"AwqConfig",
"BitsAndBytesConfig",
"CompressedTensorsConfig",
"EetqConfig",
"FbgemmFp8Config",
"GPTQConfig",
Expand Down Expand Up @@ -5802,6 +5803,7 @@
AqlmConfig,
AwqConfig,
BitsAndBytesConfig,
CompressedTensorsConfig,
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AqlmConfig,
AwqConfig,
BitsAndBytesConfig,
CompressedTensorsConfig,
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
Expand All @@ -32,6 +33,7 @@
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
Expand All @@ -49,6 +51,7 @@
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer,
"compressed-tensors": CompressedTensorsHfQuantizer,
SunMarc marked this conversation as resolved.
Show resolved Hide resolved
"fbgemm_fp8": FbgemmFp8HfQuantizer,
"torchao": TorchAoHfQuantizer,
}
Expand All @@ -62,6 +65,7 @@
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"hqq": HqqConfig,
"compressed-tensors": CompressedTensorsConfig,
bfineran marked this conversation as resolved.
Show resolved Hide resolved
"fbgemm_fp8": FbgemmFp8Config,
"torchao": TorchAoConfig,
}
Expand Down
77 changes: 77 additions & 0 deletions src/transformers/quantizers/quantizer_compressed_tensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_compressed_tensors_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin
from .base import HfQuantizer


if is_torch_available():
import torch

logger = logging.get_logger(__name__)


class CompressedTensorsHfQuantizer(HfQuantizer):
"""
Quantizer for the compressed_tensors package. Loads and restores models to
quantized state with compressed_tensors
"""

requires_calibration = True
required_packages = ["compressed_tensors"]

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)

from compressed_tensors.compressors import ModelCompressor

self.compressor = ModelCompressor.from_compression_config(quantization_config)

def validate_environment(self, *args, **kwargs):
if not is_compressed_tensors_available():
raise ImportError(
"Using `compressed_tensors` quantized models requires the compressed-tensors library: "
"`pip install compressed-tensors`"
)
if not is_torch_available():
# torch already should be installed as part of compressed tensors
raise ImportError("torch is required for using compressed-tensors quantization")

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
logger.info("Loading model using torch.float16 for compressed-tensors quantization")
torch_dtype = torch.float16
elif torch_dtype != torch.float16:
logger.info(
"We suggest you to set `torch_dtype=torch.float16` for better efficiency with compressed_tensors."
)
return torch_dtype
Comment on lines +53 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an issue with bfloat16? We should try to allow this for llama models

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No issue with bfloat16, we just recommend float16 as a default since that is what vLLM expects for the scale/zp


def _process_model_before_weight_loading(self, model, **kwargs):
from compressed_tensors.quantization import apply_quantization_config

ct_quantization_config = self.compressor.quantization_config
apply_quantization_config(model, ct_quantization_config, run_compressed=True)
SunMarc marked this conversation as resolved.
Show resolved Hide resolved

def _process_model_after_weight_loading(self, model, **kwargs):
pass

@property
def is_trainable(self):
return False

@property
def is_serializable(self):
return False
Loading
Loading