6 changes: 4 additions & 2 deletions src/transformers/modeling_utils.py
@@ -3038,11 +3038,13 @@ def _init_weights(self, module):
std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "weight"):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "weight"):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
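The `hasattr` guard above matters once compressed-tensors rewrites modules in place: a run-compressed linear can drop its floating-point `weight` parameter in favor of packed buffers, and the unguarded code would raise an `AttributeError` when `_init_weights` visited such a module. A minimal sketch of the guarded behavior, using a hypothetical `PackedLinear` stand-in rather than the real compressed-tensors class:

```python
from torch import nn


class PackedLinear(nn.Linear):
    """Hypothetical stand-in for a run-compressed linear whose fp weight was removed."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=True)
        del self.weight  # packed/quantized buffers would live here instead


def init_weights(module, std=0.02):
    # Mirrors the guarded initialization from the hunk above.
    if isinstance(module, nn.Linear):
        if hasattr(module, "weight"):
            module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()


init_weights(nn.Linear(4, 4))     # weight re-initialized and bias zeroed, as before
init_weights(PackedLinear(4, 4))  # weight is skipped instead of raising AttributeError
```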
65 changes: 57 additions & 8 deletions src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -13,13 +13,17 @@
# limitations under the License.


from typing import List, Tuple, TYPE_CHECKING
from collections import defaultdict
import torch

from ..utils import is_compressed_tensors_available, is_torch_available, logging
from ..utils.quantization_config import CompressedTensorsConfig
from .base import HfQuantizer

if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel

if is_torch_available():
import torch

logger = logging.get_logger(__name__)

@@ -61,7 +65,7 @@ def validate_environment(self, *args, **kwargs):
if not is_torch_available():
# torch already should be installed as part of compressed tensors
raise ImportError("torch is required for using compressed-tensors quantization")

def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
if dtype is None:
logger.info("Loading model using torch.float16 for compressed-tensors quantization")
@@ -70,26 +74,67 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with compressed_tensors.")
return dtype

def _process_model_before_weight_loading(self, model, **kwargs):
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
from compressed_tensors.quantization import apply_quantization_config
from compressed_tensors.transform import apply_transform_config

ct_quantization_config = self.compressor.quantization_config
ct_transform_config = self.quantization_config.transform_config

# Always initialize compressed wrappers to match the checkpoint
# apply transform config first (before modules are converted to run compressed)
if ct_transform_config is not None:
apply_transform_config(model, ct_transform_config)

# apply quantization config (potentially convert to run compressed)
apply_quantization_config(model, ct_quantization_config, self.run_compressed)

# compress meta model to match checkpoint format
if (
self.quantization_config.is_quantization_compressed
or self.quantization_config.is_sparsification_compressed
):
self.compressor.compress_model(model=model)

def _process_model_after_weight_loading(self, model, **kwargs):
"""Decompress loaded model if necessary - need for qat"""
# update tied weights to include added transforms (_dynamic_tied_weights_keys)
self.patch_tie_weights_fn(model)

def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
"""Decompress loaded model if necessary - need for qat"""
if (
self.quantization_config.is_quantization_compressed and not self.run_compressed
) or self.quantization_config.is_sparsification_compressed:
self.compressor.decompress_model(model=model)
self.dequantize(model)

def patch_tie_weights_fn(self, model: "PreTrainedModel"):
# record shared tensors before weight loading
shared_weights: defaultdict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list)
for module in model.modules():
shared_keys = getattr(module, "_dynamic_tied_weights_keys", list())
for key in shared_keys:
weight = getattr(module, key, None)
if weight is not None:
shared_weights[id(weight)].append((module, key))

original_fn = model.tie_weights.__func__

# this function is called after weight loading but before dispatch
def tie_weights(self: "PreTrainedModel"):
# broadcast loaded weight to other shared weights
for modules_keys in shared_weights.values():
weights = [getattr(module, key) for module, key in modules_keys]
loaded_weights = [weight for weight in weights if weight.device.type != "meta"]
if len(loaded_weights) <= 0:
raise ValueError("Failed to load shared weight")
if len(loaded_weights) >= 2:
raise ValueError("Loaded too many shared weights")

loaded_weight = loaded_weights[0]
for module, key in modules_keys:
module.load_state_dict({key: loaded_weight}, strict=False, assign=True)

original_fn(self)

model.tie_weights = tie_weights.__get__(model)

def update_tp_plan(self, config):
additional_plan = {
@@ -116,3 +161,7 @@ def is_qat_trainable(self) -> bool:
def is_serializable(self, safe_serialization=None) -> bool:
"""Models quantized using compressed tensors can be saved to disk"""
return True

def dequantize(self, model: "PreTrainedModel"):
"""Decompress model"""
self.compressor.decompress_model(model=model)
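The tie-weights patch relies on two small Python mechanisms: `model.tie_weights.__func__` recovers the plain function behind the bound method, and `tie_weights.__get__(model)` binds the replacement closure to that single instance via the descriptor protocol. A minimal, transformers-free sketch of the pattern (the class and print statements are illustrative only):

```python
class Model:
    def tie_weights(self):
        print("original tie_weights")


model = Model()
original_fn = model.tie_weights.__func__  # plain function behind the bound method


def tie_weights(self):
    # Work injected before the original behavior; in the quantizer this is where a
    # loaded shared tensor is broadcast to its meta-device copies.
    print("patched pre-step")
    original_fn(self)


# __get__ binds the function to this instance only; other instances are untouched.
model.tie_weights = tie_weights.__get__(model)
model.tie_weights()    # -> patched pre-step, then original tie_weights
Model().tie_weights()  # -> original tie_weights only
```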
10 changes: 6 additions & 4 deletions src/transformers/utils/import_utils.py
@@ -122,6 +122,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
TORCHAO_MIN_VERSION = "0.4.0"
AUTOROUND_MIN_VERSION = "0.5.0"
TRITON_MIN_VERSION = "1.0.0"
COMPRESSED_TENSORS_MIN_VERSION = "0.11.0"

_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
@@ -191,8 +192,9 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_is_optimum_quanto_available = True
except importlib.metadata.PackageNotFoundError:
_is_optimum_quanto_available = False
# For compressed_tensors, only check spec to allow compressed_tensors-nightly package
_compressed_tensors_available = importlib.util.find_spec("compressed_tensors") is not None
_compressed_tensors_available, _compressed_tensors_version = _is_package_available(
"compressed_tensors", return_version=True
)
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
@@ -1364,8 +1366,8 @@ def is_qutlass_available() -> Union[tuple[bool, str], bool]:
return _qutlass_available


def is_compressed_tensors_available() -> bool:
return _compressed_tensors_available
def is_compressed_tensors_available(min_version: str = COMPRESSED_TENSORS_MIN_VERSION) -> bool:
return _compressed_tensors_available and version.parse(_compressed_tensors_version) >= version.parse(min_version)


def is_auto_gptq_available() -> Union[tuple[bool, str], bool]:
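The import-utils change replaces the bare `find_spec` check with a version-gated one, so `is_compressed_tensors_available()` now also requires at least `COMPRESSED_TENSORS_MIN_VERSION`. A standalone sketch of the same pattern (not the transformers helper itself), assuming `packaging` is installed:

```python
import importlib.metadata
import importlib.util

from packaging import version

COMPRESSED_TENSORS_MIN_VERSION = "0.11.0"


def is_available(pkg_name: str, min_version: str) -> bool:
    """Return True only if the package is importable and at least min_version."""
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        installed = importlib.metadata.version(pkg_name)
    except importlib.metadata.PackageNotFoundError:
        # e.g. a package importable by this name but distributed under another one
        return False
    return version.parse(installed) >= version.parse(min_version)


print(is_available("compressed_tensors", COMPRESSED_TENSORS_MIN_VERSION))
```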
18 changes: 16 additions & 2 deletions src/transformers/utils/quantization_config.py
@@ -1306,6 +1306,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
do not override, should be compressed-tensors
run_compressed (`bool`, *optional*, defaults to `True`): alter submodules (usually linear) in order to
emulate compressed model execution if True, otherwise use default submodule
transform_config (`dict[str, Any]`, *optional*):
configuration for online and offline transforms to improve accuracy recovery
"""

def __init__(
@@ -1319,21 +1321,22 @@ def __init__(
sparsity_config: Optional[dict[str, Any]] = None,
quant_method: str = "compressed-tensors",
run_compressed: bool = True,
transform_config: Optional[dict[str, Any]] = None,
**kwargs,
):
if is_compressed_tensors_available():
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationConfig
from compressed_tensors.transform import TransformConfig
else:
raise ImportError(
"compressed_tensors is not installed and is required for compressed-tensors quantization. Please install it with `pip install compressed-tensors`."
)
self.quantization_config = None
self.sparsity_config = None

self.run_compressed = run_compressed

# parse from dict to load nested QuantizationScheme objects
self.quantization_config = None
if config_groups or kv_cache_scheme:
self.quantization_config = QuantizationConfig.model_validate(
{
@@ -1348,11 +1351,16 @@
}
)

self.sparsity_config = None
if sparsity_config:
self.sparsity_config = SparsityCompressionConfig.load_from_registry(
sparsity_config.get("format"), **sparsity_config
)

self.transform_config = None
if transform_config:
self.transform_config = TransformConfig.model_validate(transform_config)

self.quant_method = QuantizationMethod.COMPRESSED_TENSORS

def post_init(self):
@@ -1392,6 +1400,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
if "quantization_config" in config_dict:
config_dict = dict(
sparsity_config=config_dict.get("sparsity_config"),
transform_config=config_dict.get("transform_config"),
**config_dict["quantization_config"],
)

@@ -1415,6 +1424,11 @@ def to_dict(self) -> dict[str, Any]:
else:
quantization_config["sparsity_config"] = {}

if self.transform_config is not None:
quantization_config["transform_config"] = self.transform_config.model_dump()
else:
quantization_config["transform_config"] = {}

return quantization_config

def to_diff_dict(self) -> dict[str, Any]:
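On the serialization side, `from_dict` pulls `sparsity_config` and `transform_config` out of the top level of the serialized dict and flattens the nested `quantization_config` into constructor kwargs. A pure-Python sketch of that reshaping, with an illustrative payload (the field contents are made up, not a real compressed-tensors schema):

```python
# Serialized shape handled by CompressedTensorsConfig.from_dict (contents illustrative).
config_dict = {
    "quantization_config": {
        "config_groups": {"group_0": {"targets": ["Linear"]}},
        "quant_method": "compressed-tensors",
    },
    "sparsity_config": {"format": "sparse-24"},
    "transform_config": {"config_groups": {}},  # hypothetical transform payload
}

# Same reshaping as in the diff above: the sibling configs are re-attached, then the
# nested quantization_config is flattened into keyword arguments.
if "quantization_config" in config_dict:
    config_dict = dict(
        sparsity_config=config_dict.get("sparsity_config"),
        transform_config=config_dict.get("transform_config"),
        **config_dict["quantization_config"],
    )

# config_dict can now be passed as CompressedTensorsConfig(**config_dict).
print(sorted(config_dict))
```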
@@ -17,8 +17,9 @@ class CompressedTensorsTest(unittest.TestCase):
tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"
tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"
llama3_8b_fp8 = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat"
quip_w4a16 = "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16"

prompt = "Paris is the capital of which country?"
prompt = "The capital of France is Paris, the capital of Germany is Berlin"

def tearDown(self):
gc.collect()
@@ -47,30 +48,32 @@ def test_config_to_from_dict(self):
self.assertIsInstance(config_from_dict.sparsity_config, SparsityCompressionConfig)

def test_tinyllama_w8a8(self):
expected_out = "<s> Paris is the capital of which country?\n\n 1. Paris is the capital of which country?\n\n 1. Paris is the capital of which country?\n\n 1. Paris is the capital of which country?\n\n"
self._test_quantized_model(self.tinyllama_w8a8, expected_out)
self._test_quantized_model(self.tinyllama_w8a8, 30.0)

def test_tinyllama_w4a16(self):
expected_out = "<s> Paris is the capital of which country?\nAnswer: Paris is the capital of France.\nQuestion: Which country is the capital of which city?\nAnswer: The capital of the city of New York is New York.\nQuestion: Which"
self._test_quantized_model(self.tinyllama_w4a16, expected_out)
self._test_quantized_model(self.tinyllama_w4a16, 20.0)

def test_tinyllama_w8a16(self):
expected_out = "<s> Paris is the capital of which country?\nA. France\nB. Germany\nC. Spain\nD. Italy\nE. Switzerland\nQ10. Which of the following is not a country in the European Union?\nA."
self._test_quantized_model(self.tinyllama_w8a16, expected_out)
self._test_quantized_model(self.tinyllama_w8a16, 20.0)

def test_llama_8b_fp8(self):
expected_out = "<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous museum in Paris that is home to the Mona Lisa? The Louvre\nWhat is the name of the famous bridge in Paris that is often associated with the city"
self._test_quantized_model(self.llama3_8b_fp8, expected_out)
self._test_quantized_model(self.llama3_8b_fp8, 10.0)

def _test_quantized_model(self, model_name: str, expected_output: str):
"""Carry out generation"""
def test_quip_w4a16(self):
self._test_quantized_model(self.quip_w4a16, 10.0)

def _test_quantized_model(self, model_name: str, expected_perplexity: float):
# load model
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = quantized_model.device

# check config
self.assertIsNotNone(
quantized_model.config.quantization_config,
"quantization_config should not be None",
)
# check scales
self.assertTrue(
any(
key
@@ -79,9 +82,13 @@ def _test_quantized_model(self, model_name: str, expected_output: str):
),
"quantized model should load a non-trivial scale into the state dict",
)

# compute outputs with loss
inputs = tokenizer(self.prompt, return_tensors="pt").to(device)
generated_ids = quantized_model.generate(**inputs, max_length=50, do_sample=False)
outputs = tokenizer.batch_decode(generated_ids)
labels = inputs["input_ids"]
with torch.no_grad():
outputs = quantized_model(**inputs, labels=labels)

self.assertIsNotNone(outputs)
self.assertEqual(outputs[0], expected_output)
# check perplexity
perplexity = torch.exp(outputs.loss)
self.assertLessEqual(perplexity, expected_perplexity)
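The rewritten assertions treat `torch.exp(outputs.loss)` as perplexity, which holds because passing `labels=input_ids` makes the returned loss the mean next-token cross-entropy. A small self-contained sketch of the same computation on toy tensors (no model or checkpoint needed):

```python
import torch
import torch.nn.functional as F

# Toy "model output": batch of 1 sequence, 6 tokens, vocabulary of 10.
logits = torch.randn(1, 6, 10)
labels = torch.randint(0, 10, (1, 6))

# Causal LM convention: position t predicts token t + 1.
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]

# Mean next-token cross-entropy, i.e. the `loss` a causal LM returns with labels set.
loss = F.cross_entropy(shift_logits.reshape(-1, 10), shift_labels.reshape(-1))
perplexity = torch.exp(loss)
print(perplexity)
```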