Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HFQuantizer implementation for compressed-tensors library #31704

Merged
merged 42 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
d695ec3
Add compressed-tensors HFQuantizer implementation
Jun 5, 2024
f468964
flag serializable as False
Jun 5, 2024
41224d3
run
horheynm Jun 10, 2024
b61bfb9
revive lines deleted by ruff
horheynm Jun 10, 2024
ff8f1c5
fixes to load+save from sparseml, edit config to quantization_config,…
horheynm Jun 11, 2024
c1cb55d
address satrat comment
horheynm Jun 11, 2024
ef9d3f1
compressed_tensors to compressed-tensors and revert back is_serializable
horheynm Jun 12, 2024
117d050
rename quant_method from sparseml to compressed-tensors
horheynm Jun 12, 2024
1901c3e
tests
horheynm Jun 12, 2024
3ca270d
edit tests
horheynm Jun 13, 2024
9a14b09
clean up tests
Jun 28, 2024
ec59052
make style
Jun 28, 2024
520ded8
cleanup
Jun 28, 2024
7dec8fc
cleanup
Jun 28, 2024
afb550d
Merge branch 'main' into compressed-tensors-quantizer
bfineran Jul 25, 2024
d9b3660
add test skip for when compressed tensors is not installed
Jul 25, 2024
e51ac59
remove pydantic import + style
Jul 25, 2024
ccb5442
delay torch import in test
Jul 25, 2024
bfd9220
initial docs
Jul 30, 2024
71a80f9
update main init for compressed tensors config
Jul 30, 2024
547f9cc
make fix-copies
Jul 30, 2024
8acbc09
docstring
Jul 31, 2024
eaa5f20
remove fill_docstring
Jul 31, 2024
4ba75fb
Apply suggestions from code review
bfineran Aug 6, 2024
94ea0d3
review comments
Aug 6, 2024
c48840d
review comments
Aug 6, 2024
ab74d26
Merge branch 'main' into compressed-tensors-quantizer
bfineran Aug 19, 2024
2ecf711
comments - suppress warnings on state dict load, tests, fixes
Aug 20, 2024
e1ae504
bug-fix - remove unnecessary call to apply quant lifecycle
Aug 22, 2024
ea9e927
run_compressed compatability
Aug 30, 2024
1c3ad5c
revert changes not needed for compression
Sep 3, 2024
aa1a4f9
no longer need unexpected keys fn
Sep 3, 2024
81a13dd
unexpected keys not needed either
Sep 3, 2024
f53d7b9
Apply suggestions from code review
Satrat Sep 9, 2024
d8f7073
add to_diff_dict
Sep 9, 2024
c4fbf70
update docs and expand testing
Sep 11, 2024
1992a88
Merge remote-tracking branch 'upstream/main' into compressed-tensors-…
Sep 17, 2024
298a638
Update _toctree.yml with compressed-tensors
Satrat Sep 18, 2024
3cb4415
Update src/transformers/utils/quantization_config.py
Satrat Sep 23, 2024
a943157
Merge branch 'main' into compressed-tensors-quantizer
dsikka Sep 24, 2024
64f475a
update doc
dsikka Sep 24, 2024
fabe8a3
add note about saving a loaded model
dsikka Sep 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/source/en/main_classes/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

[[autodoc]] FbgemmFp8Config

## CompressedTensorsConfig

[[autodoc]] CompressedTensorsConfig

## TorchAoConfig

[[autodoc]] TorchAoConfig

47 changes: 47 additions & 0 deletions docs/source/en/quantization/compressed_tensors.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Compressed Tensors

Compressed tensors supports the quantization of models to a variety of formats and provides an extensible
framework for adding new formats and strategies.

Compressed models can be easily created using [llm-compressor](https://github.com/vllm-project/llm-compressor).
Alternatively models can be created indepedenty and serialized with a compressed tensors config.

Supported formats include:

- FP8, INT4, INT8 (for Q/DQ arbitrary precision is allowed for INT)
- Activation quantization (static)
- Dynamic per-token activation quantization
- Supports quantization of arbitrary layer types
SunMarc marked this conversation as resolved.
Show resolved Hide resolved
- Targeted support or ignoring of layers by name or class

## Installation

```bash
pip install compressed-tensors
```


## Sample Model Load
```python
from transformers import AutoModelForCausalLM
compressed_tensors_model = AutoModelForCausalLM.from_pretrained("nm-testing/tinyllama-oneshot-w4a16-group128-v3")
```

SunMarc marked this conversation as resolved.
Show resolved Hide resolved

## More Coming Soon!
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have example use-cases of the config's different parameters and why they should be use would be awesome here as well!

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,7 @@
"AqlmConfig",
"AwqConfig",
"BitsAndBytesConfig",
"CompressedTensorsConfig",
"EetqConfig",
"FbgemmFp8Config",
"GPTQConfig",
Expand Down Expand Up @@ -5742,6 +5743,7 @@
AqlmConfig,
AwqConfig,
BitsAndBytesConfig,
CompressedTensorsConfig,
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4035,7 +4035,7 @@ def from_pretrained(
dispatch_model(model, **device_map_kwargs)

if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
hf_quantizer.postprocess_model(model, resolved_archive_file=resolved_archive_file)
SunMarc marked this conversation as resolved.
Show resolved Hide resolved
model.hf_quantizer = hf_quantizer

if _adapter_model_path is not None:
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AqlmConfig,
AwqConfig,
BitsAndBytesConfig,
CompressedTensorsConfig,
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
Expand All @@ -32,6 +33,7 @@
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
Expand All @@ -49,6 +51,7 @@
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer,
"compressed_tensors": CompressedTensorsHfQuantizer,
"fbgemm_fp8": FbgemmFp8HfQuantizer,
"torchao": TorchAoHfQuantizer,
}
Expand All @@ -62,6 +65,7 @@
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"hqq": HqqConfig,
"compressed_tensors": CompressedTensorsConfig,
"fbgemm_fp8": FbgemmFp8Config,
"torchao": TorchAoConfig,
}
Expand Down
76 changes: 76 additions & 0 deletions src/transformers/quantizers/quantizer_compressed_tensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_compressed_tensors_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin
from .base import HfQuantizer


if is_torch_available():
import torch

logger = logging.get_logger(__name__)


class CompressedTensorsHfQuantizer(HfQuantizer):
"""
Quantizer for the compressed_tensors package. Loads and restores models to
quantized state with compressed_tensors
"""

requires_calibration = False
Satrat marked this conversation as resolved.
Show resolved Hide resolved
required_packages = ["compressed_tensors"]

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)

from compressed_tensors.compressors import ModelCompressor

self.compressor = ModelCompressor.from_compression_config(quantization_config)

def validate_environment(self, *args, **kwargs):
if not is_compressed_tensors_available():
raise ImportError(
"Using `compressed_tensors` quantized models requires the compressed-tensors library: "
"`pip install compressed-tensors`"
)
if not is_torch_available():
# torch already should be installed as part of compressed tensors
raise ImportError("torch is required for using compressed-tensors quantization")

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
logger.info("Loading model using torch.float16 for compressed-tensors quantization")
torch_dtype = torch.float16
elif torch_dtype != torch.float16:
logger.info(
"We suggest you to set `torch_dtype=torch.float16` for better efficiency with compressed_tensors."
)
return torch_dtype
Comment on lines +53 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an issue with bfloat16? We should try to allow this for llama models

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No issue with bfloat16, we just recommend float16 as a default since that is what vLLM expects for the scale/zp


def _process_model_before_weight_loading(self, model, **kwargs):
if self.quantization_config.quantization_config is not None:
from compressed_tensors.quantization import apply_quantization_config

apply_quantization_config(model, self.quantization_config.quantization_config)

def _process_model_after_weight_loading(self, model, resolved_archive_file, **kwargs):
self.compressor.decompress(model_path=resolved_archive_file, model=model)

bfineran marked this conversation as resolved.
Show resolved Hide resolved
@property
def is_trainable(self):
return False

@property
def is_serializable(self):
return False
8 changes: 8 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
is_av_available,
is_bitsandbytes_available,
is_bs4_available,
is_compressed_tensors_available,
is_cv2_available,
is_cython_available,
is_decord_available,
Expand Down Expand Up @@ -1134,6 +1135,13 @@ def require_quanto(test_case):
return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case)


def require_compressed_tensors(test_case):
"""
Decorator for compressed_tensors dependency
"""
return unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed_tensors")(test_case)


def require_fbgemm_gpu(test_case):
"""
Decorator for fbgemm_gpu dependency
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
is_bitsandbytes_available,
is_bs4_available,
is_coloredlogs_available,
is_compressed_tensors_available,
is_cv2_available,
is_cython_available,
is_datasets_available,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
_compressed_tensors_available = _is_package_available("compressed_tensors")
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
Expand Down Expand Up @@ -938,6 +939,10 @@ def is_quanto_available():
return _quanto_available


def is_compressed_tensors_available():
return _compressed_tensors_available


def is_auto_gptq_available():
return _auto_gptq_available

Expand Down
74 changes: 73 additions & 1 deletion src/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class QuantizationMethod(str, Enum):
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
COMPRESSED_TENSORS = "compressed-tensors"
FBGEMM_FP8 = "fbgemm_fp8"
TORCHAO = "torchao"

Expand Down Expand Up @@ -1051,7 +1052,78 @@ def post_init(self):
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")


@dataclass
class CompressedTensorsConfig(QuantizationConfigMixin):
"""
This is a wrapper class that handles compressed-tensors quantization config options.
It is a wrapper around `compressed_tensors.QuantizationConfig`
Args:
config_groups (`typing.Dict[str, typing.Union[ForwardRef('QuantizationScheme'), typing.List[str]]]`, *optional*):
dictionary mapping group name to a quantization scheme definition
format (`str`, *optional*, defaults to `"dense"`):
format the model is represented as
Comment on lines +1062 to +1063
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the available formats?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArthurZucker this includes the different compression formats, depending on how the model is quantized/saved on disk, including:

  1. dense
  2. int-quantized
  3. float-quantized
  4. pack-quantized
  5. marlin-24

quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`):
status of model in the quantization lifecycle, ie 'initialized', 'calibration', 'frozen'
global_compression_ratio (`typing.Union[float, NoneType]`, *optional*):
0-1 float percentage of model compression
ignore (`typing.Union[typing.List[str], NoneType]`, *optional*):
layer names or types to not quantize, supports regex prefixed by 're:'
sparsity_config (`typing.Dict[str, typing.Any]`, *optional*):
configuration for sparsity compression
quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
do not override, should be compressed-tensors
"""

def __init__(
self,
config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None, # noqa: F821
format: str = "dense",
quantization_status: "QuantizationStatus" = "initialized", # noqa: F821
global_compression_ratio: Optional[float] = None,
ignore: Optional[List[str]] = None,
sparsity_config: Dict[str, Any] = None,
quant_method: str = "compressed-tensors",
**kwargs,
):
from compressed_tensors import QuantizationConfig
from compressed_tensors.config import SparsityCompressionConfig

self.quantization_config = None
self.sparsity_config = None

# parse from dict to load nested QuantizationScheme objects
if config_groups:
self.quantization_config = QuantizationConfig.parse_obj(
{
"config_groups": config_groups,
"quant_method": quant_method,
"format": format,
"quantization_status": quantization_status,
"global_compression_ratio": global_compression_ratio,
"ignore": ignore,
}
)
bfineran marked this conversation as resolved.
Show resolved Hide resolved

if sparsity_config:
self.sparsity_config = SparsityCompressionConfig.load_from_registry(
sparsity_config.get("format"), **sparsity_config
)

super().__init__(quant_method=QuantizationMethod.COMPRESSED_TENSORS)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
quantization_config = self.quantization_config.dict() if self.quantization_config is not None else None
sparsity_config = self.sparsity_config.dict() if self.sparsity_config is not None else None

return {
"quantization_config": quantization_config,
"sparsity_config": sparsity_config,
}


class FbgemmFp8Config(QuantizationConfigMixin):
Satrat marked this conversation as resolved.
Show resolved Hide resolved
"""
This is a wrapper class about all possible attributes and features that you can play with a model that has been
Expand Down
Empty file.
52 changes: 52 additions & 0 deletions tests/quantization/compressed_tensor/test_compressed_tensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import gc
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_compressed_tensors, require_torch
from transformers.utils import is_torch_available


if is_torch_available():
import torch


@require_compressed_tensors
@require_torch
class CompressedTensorsTest(unittest.TestCase):
quantized_model_name = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"
bfineran marked this conversation as resolved.
Show resolved Hide resolved

prompt = "Paris is the capital of which country?"

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()

@classmethod
def setUpClass(self):
"""
Setup quantized model
"""
self.tokenizer = AutoTokenizer.from_pretrained(self.quantized_model_name)
self.quantized_model = AutoModelForCausalLM.from_pretrained(self.quantized_model_name)
self.device = self.quantized_model.device

def test_quantized_model(self):
"""Carry out generation"""
self.assertIsNotNone(
self.quantized_model.config.quantization_config,
"quantization_config should not be None",
)
self.assertTrue(
any(
key
for key, tensor in self.quantized_model.state_dict().items()
if "scale" in key and not torch.all(tensor == 1.0)
),
"quantized model should load a non-trivail scale into the state dict",
bfineran marked this conversation as resolved.
Show resolved Hide resolved
)
inputs = self.tokenizer(self.prompt, return_tensors="pt").to(self.device)
generated_ids = self.quantized_model.generate(**inputs, max_length=50)
outputs = self.tokenizer.batch_decode(generated_ids)

self.assertIsNotNone(outputs)
SunMarc marked this conversation as resolved.
Show resolved Hide resolved