From fbcd00742a7e1f2e6eca633b8cded41d7628d9d9 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Sun, 9 Jun 2024 23:49:46 -0400
Subject: [PATCH] [Misc] Update to comply with the new `compressed-tensors`
 config (#5350)

Co-authored-by: Michael Goin
---
 tests/quantization/test_compressed_tensors.py | 20 +++++++++++++-------
 vllm/config.py                                |  8 ++------
 .../layers/quantization/__init__.py           |  2 +-
 .../model_loader/weight_utils.py              |  9 +++------
 4 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 9d94d2ecfb222..e6d8218b41372 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -5,15 +5,15 @@
 
 import torch
 
+from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
     CompressedTensorsW8A8StaticTensor)
 
 
 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
-    model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed"
-    with vllm_runner(model_path, quantization="sparseml",
-                     enforce_eager=True) as llm:
+    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
 
@@ -40,11 +40,17 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
         assert qkv_proj.input_scale.dtype is torch.float32
 
 
+def test_compressed_tensors_no_enforce_eager(vllm_runner):
+    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    with vllm_runner(model_path) as llm:
+        sampling_params = SamplingParams()
+        output = llm.generate("Hello world!", sampling_params=sampling_params)
+        assert output
+
+
 def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
-    model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
-    with vllm_runner(model_path,
-                     quantization="sparseml",
-                     enforce_eager=True,
+    model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
+    with vllm_runner(model_path, enforce_eager=True,
                      dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
diff --git a/vllm/config.py b/vllm/config.py
index a980168190adc..fa296cd626f17 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -164,12 +164,8 @@ def _verify_embedding_mode(self) -> None:
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
         if quant_cfg is None:
-            # SparseML uses a "compression_config" with a "quantization_config".
-            compression_cfg = getattr(self.hf_config, "compression_config",
-                                      None)
-            if compression_cfg is not None:
-                quant_cfg = compression_cfg.get("quantization_config", None)
-
+            # compressed-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
         return quant_cfg
 
     def _verify_quantization(self) -> None:
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index 0bc42beb66257..40b0df75a69a6 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -31,7 +31,7 @@
     "gptq_marlin": GPTQMarlinConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
-    "sparseml": CompressedTensorsConfig,
+    "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
 }
 
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 6174f0a974712..827591b227a2b 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -122,12 +122,9 @@ def get_quant_config(model_config: ModelConfig,
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)
     if hf_quant_config is None:
-        compression_config = getattr(model_config.hf_config,
-                                     "compression_config", None)
-        if compression_config is not None:
-            hf_quant_config = compression_config.get("quantization_config",
-                                                     None)
-
+        # compressed-tensors uses a "compression_config" key
+        hf_quant_config = getattr(model_config.hf_config, "compression_config",
+                                  None)
     if hf_quant_config is not None:
         return quant_cls.from_config(hf_quant_config)
     # In case of bitsandbytes/QLoRA, get quant config from the adapter model.