21 changes: 18 additions & 3 deletions src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -12,12 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING

from ..utils import is_compressed_tensors_available, is_torch_available, logging
from ..utils.quantization_config import CompressedTensorsConfig
from .base import HfQuantizer


if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel


if is_torch_available():
import torch

@@ -63,13 +69,19 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with compressed_tensors.")
return dtype

def _process_model_before_weight_loading(self, model, **kwargs):
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
from compressed_tensors.quantization import apply_quantization_config
from compressed_tensors.transform import apply_transform_config

ct_quantization_config = self.compressor.quantization_config
ct_transform_config = self.quantization_config.transform_config

# Always initialize compressed wrappers to match the checkpoint
# apply configs
if ct_transform_config is not None:
apply_transform_config(model, ct_transform_config)
apply_quantization_config(model, ct_quantization_config, self.run_compressed)

# compress meta model to match compressed checkpoint
if (
self.quantization_config.is_quantization_compressed
or self.quantization_config.is_sparsification_compressed
@@ -82,7 +94,10 @@ def _process_model_after_weight_loading(self, model, **kwargs):
if (
self.quantization_config.is_quantization_compressed and not self.run_compressed
) or self.quantization_config.is_sparsification_compressed:
self.compressor.decompress_model(model=model)
self.dequantize(model)

def dequantize(self, model: "PreTrainedModel"):
self.compressor.decompress_model(model=model)

def update_tp_plan(self, config):
additional_plan = {
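
For reference, a minimal sketch of the load-time order the quantizer now follows when a checkpoint carries a transform config: the transform config is applied first, then the quantization config, mirroring `_process_model_before_weight_loading` above. The standalone helper name `prepare_model` is illustrative and not part of this PR.

from compressed_tensors.quantization import apply_quantization_config
from compressed_tensors.transform import apply_transform_config


def prepare_model(model, quantization_config, compressor, run_compressed=True):
    # attach (Hadamard) transforms first, if the config declares any
    if quantization_config.transform_config is not None:
        apply_transform_config(model, quantization_config.transform_config)
    # then initialize the quantization wrappers to match the checkpoint layout
    apply_quantization_config(model, compressor.quantization_config, run_compressed)
    return model
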
6 changes: 4 additions & 2 deletions src/transformers/utils/import_utils.py
@@ -105,6 +105,7 @@ def is_env_variable_false(env_variable: str) -> bool:
AUTOROUND_MIN_VERSION = "0.5.0"
TRITON_MIN_VERSION = "1.0.0"
KERNELS_MIN_VERSION = "0.9.0"
COMPRESSED_TENSORS_MIN_VERSION = "0.13.1"


@lru_cache
@@ -1027,8 +1028,9 @@ def is_qutlass_available():


@lru_cache
def is_compressed_tensors_available() -> bool:
return _is_package_available("compressed_tensors")
def is_compressed_tensors_available(min_version: str = COMPRESSED_TENSORS_MIN_VERSION) -> bool:
is_available, ct_version = _is_package_available("compressed_tensors", return_version=True)
return is_available and version.parse(ct_version) >= version.parse(min_version)


@lru_cache
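
The gate now also enforces a minimum compressed-tensors version (0.13.1). A roughly equivalent standalone check, sketched with `importlib.metadata` instead of the internal `_is_package_available` helper:

import importlib.metadata
import importlib.util

from packaging import version


def is_compressed_tensors_available(min_version: str = "0.13.1") -> bool:
    # the package must be importable and at least at the pinned minimum version
    if importlib.util.find_spec("compressed_tensors") is None:
        return False
    installed = importlib.metadata.version("compressed_tensors")
    return version.parse(installed) >= version.parse(min_version)
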
22 changes: 20 additions & 2 deletions src/transformers/utils/quantization_config.py
@@ -1109,6 +1109,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
layer names or types to not quantize, supports regex prefixed by 're:'
sparsity_config (`typing.dict[str, typing.Any]`, *optional*):
configuration for sparsity compression
transform_config (`typing.dict[str, typing.Any]`, *optional*):
configuration for (Hadamard) transforms
quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
do not override, should be compressed-tensors
run_compressed (`bool`, *optional*, defaults to `True`): alter submodules (usually linear) in order to
@@ -1124,23 +1126,25 @@ def __init__(
global_compression_ratio: float | None = None,
ignore: list[str] | None = None,
sparsity_config: dict[str, Any] | None = None,
transform_config: dict[str, Any] | None = None,
quant_method: str = "compressed-tensors",
run_compressed: bool = True,
**kwargs,
):
if is_compressed_tensors_available():
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.transform import TransformConfig
else:
raise ImportError(
"compressed_tensors is not installed and is required for compressed-tensors quantization. Please install it with `pip install compressed-tensors`."
)
self.quantization_config = None
self.sparsity_config = None

self.transform_config = None
self.run_compressed = run_compressed

# parse from dict to load nested QuantizationScheme objects
# quantization
if config_groups or kv_cache_scheme:
self.quantization_config = QuantizationConfig.model_validate(
{
@@ -1155,11 +1159,16 @@
}
)

# sparsity
if sparsity_config:
self.sparsity_config = SparsityCompressionConfig.load_from_registry(
sparsity_config.get("format"), **sparsity_config
)

# transform
if transform_config:
self.transform_config = TransformConfig.model_validate(transform_config)

self.quant_method = QuantizationMethod.COMPRESSED_TENSORS

def post_init(self):
@@ -1199,6 +1208,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
if "quantization_config" in config_dict:
config_dict = dict(
sparsity_config=config_dict.get("sparsity_config"),
transform_config=config_dict.get("transform_config"),
**config_dict["quantization_config"],
)

@@ -1211,17 +1221,25 @@ def to_dict(self) -> dict[str, Any]:
Serializes this instance to a Python dictionary. Returns:
`dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
# quantization
quantization_config = {}
if self.quantization_config is not None:
quantization_config = self.quantization_config.model_dump()
else:
quantization_config["quant_method"] = QuantizationMethod.COMPRESSED_TENSORS

# sparsity
if self.sparsity_config is not None:
quantization_config["sparsity_config"] = self.sparsity_config.model_dump()
else:
quantization_config["sparsity_config"] = {}

# transform
if self.transform_config is not None:
quantization_config["transform_config"] = self.transform_config.model_dump()
else:
quantization_config["transform_config"] = {}

return quantization_config

def to_diff_dict(self) -> dict[str, Any]:
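
A quick way to exercise the serialization behavior added to `to_dict()`: when no transform or sparsity config is supplied, the serialized dict still carries both keys, each as an empty dict. This sketch assumes compressed-tensors is installed, since the constructor raises otherwise.

from transformers import CompressedTensorsConfig

config = CompressedTensorsConfig()
serialized = config.to_dict()
# neither config was provided, so both entries serialize to empty dicts
assert serialized["sparsity_config"] == {}
assert serialized["transform_config"] == {}
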
@@ -1,15 +1,8 @@
import gc
import unittest
from unittest import skip

from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
from transformers.testing_utils import (
backend_empty_cache,
require_compressed_tensors,
require_deterministic_for_xpu,
require_torch,
torch_device,
)
from transformers.testing_utils import backend_empty_cache, require_compressed_tensors, require_torch, torch_device
from transformers.utils import is_torch_available


@@ -20,12 +13,13 @@
@require_compressed_tensors
@require_torch
class CompressedTensorsTest(unittest.TestCase):
tinyllama_w8a16 = "nm-testing/tinyllama-w8a16-dense"
tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed"
tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed"
llama3_8b_fp8 = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat"
tinyllama_w4a16 = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-e2e"
tinyllama_int8 = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-e2e"
tinyllama_fp8 = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
tinyllama_w8a16 = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A16-e2e"
llama_1b_quip_w4a16 = "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16"

prompt = "Paris is the capital of which country?"
prompt = "The capital of France is Paris, the capital of Germany is Berlin"

def tearDown(self):
gc.collect()
@@ -53,43 +47,33 @@ def test_config_to_from_dict(self):
self.assertIsInstance(config_from_dict.quantization_config, QuantizationConfig)
self.assertIsInstance(config_from_dict.sparsity_config, SparsityCompressionConfig)

@skip("Test too flaky, depends on hardware also")
def test_tinyllama_w8a8(self):
expected_out = [
"<s> Paris is the capital of which country?\n\n**A) 10** Paris is the capital of which country?\n\n**B) 11** Paris is the capital of which country?\n\n**C) 1",
"<s> Paris is the capital of which country?\n\n** 10.** Which country is the capital of which country?\n\n** 11.** Which country is the capital of which country?\n\n** 12.", # XPU
]
self._test_quantized_model(self.tinyllama_w8a8, expected_out)

def test_tinyllama_w4a16(self):
expected_out = [
"<s> Paris is the capital of which country?\nAnswer: Paris is the capital of France.\nQuestion: Which country is the capital of which city?\nAnswer: The capital of the city of New York is New York.\nQuestion: Which"
]
self._test_quantized_model(self.tinyllama_w4a16, expected_out)
self._test_quantized_model(self.tinyllama_w4a16, 20.0)

def test_tinyllama_int8(self):
self._test_quantized_model(self.tinyllama_int8, 30.0)

def test_tinyllama_fp8(self):
self._test_quantized_model(self.tinyllama_fp8, 20.0)

def test_tinyllama_w8a16(self):
expected_out = [
"<s> Paris is the capital of which country?\nA. France\nB. Germany\nC. Spain\nD. Italy\nE. Switzerland\nQ10. Which of the following is not a country in the European Union?\nA."
]
self._test_quantized_model(self.tinyllama_w8a16, expected_out)

def test_llama_8b_fp8(self):
expected_out = [
"<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera? ",
"<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera", # XPU
]
self._test_quantized_model(self.llama3_8b_fp8, expected_out)

@require_deterministic_for_xpu
def _test_quantized_model(self, model_name: str, expected_output: list):
"""Carry out generation"""
self._test_quantized_model(self.tinyllama_w8a16, 20.0)

def test_llama_1b_quip_w4a16(self):
self._test_quantized_model(self.llama_1b_quip_w4a16, 10.0)

def _test_quantized_model(self, model_name: str, expected_perplexity: float):
# load model
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = quantized_model.device

# check config
self.assertIsNotNone(
quantized_model.config.quantization_config,
"quantization_config should not be None",
)
# check scales
self.assertTrue(
any(
key
@@ -98,9 +82,13 @@ def _test_quantized_model(self, model_name: str, expected_output: list):
),
"quantized model should load a non-trivial scale into the state dict",
)

# compute outputs with loss
inputs = tokenizer(self.prompt, return_tensors="pt").to(device)
generated_ids = quantized_model.generate(**inputs, max_length=50, do_sample=False)
outputs = tokenizer.batch_decode(generated_ids)
labels = inputs["input_ids"]
with torch.no_grad():
outputs = quantized_model(**inputs, labels=labels)

self.assertIsNotNone(outputs)
self.assertIn(outputs[0], expected_output)
# check perplexity
perplexity = torch.exp(outputs.loss)
self.assertLessEqual(perplexity, expected_perplexity)
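
The rewritten tests drop the brittle exact-string comparisons in favor of a perplexity bound. A standalone sketch of the same check, reusing one checkpoint, the prompt, and the 20.0 threshold from the tests above (everything else is illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "The capital of France is Paris, the capital of Germany is Berlin"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# teacher-forced loss over the prompt, exponentiated to get perplexity
with torch.no_grad():
    outputs = model(**inputs, labels=inputs["input_ids"])
perplexity = torch.exp(outputs.loss)
assert perplexity <= 20.0  # bound used by test_tinyllama_fp8 above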