This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 8f865f6

dsikka authored and mgoin committed
[Misc] Update to comply with the new compressed-tensors config (vllm-project#5350)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
1 parent 27e68e9 commit 8f865f6

4 files changed: 19 additions & 20 deletions

tests/quantization/test_compressed_tensors.py

Lines changed: 13 additions & 7 deletions
@@ -5,15 +5,15 @@
 
 import torch
 
+from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
     CompressedTensorsW8A8StaticTensor)
 
 
 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
-    model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed"
-    with vllm_runner(model_path, quantization="sparseml",
-                     enforce_eager=True) as llm:
+    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
 
@@ -40,11 +40,17 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
     assert qkv_proj.input_scale.dtype is torch.float32
 
 
+def test_compressed_tensors_no_enforce_eager(vllm_runner):
+    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    with vllm_runner(model_path) as llm:
+        sampling_params = SamplingParams()
+        output = llm.generate("Hello world!", sampling_params=sampling_params)
+        assert output
+
+
 def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
-    model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
-    with vllm_runner(model_path,
-                     quantization="sparseml",
-                     enforce_eager=True,
+    model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
+    with vllm_runner(model_path, enforce_eager=True,
                      dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
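
For context, a minimal sketch of what the updated tests exercise: after this change, a compressed-tensors checkpoint loads without an explicit quantization= argument, since the method is detected from the checkpoint's config. The model stub comes from the diff above; the max_tokens value is an arbitrary choice for illustration.

from vllm import LLM, SamplingParams

# Load a compressed-tensors quantized checkpoint. No quantization="sparseml"
# flag is needed anymore; the method is inferred from the checkpoint config.
llm = LLM(model="nm-testing/tinyllama-oneshot-w8a8-static-v2",
          enforce_eager=True)

# max_tokens=16 is arbitrary, purely for illustration.
outputs = llm.generate(["Hello world!"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)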

vllm/config.py

Lines changed: 2 additions & 6 deletions
@@ -196,12 +196,8 @@ def _verify_embedding_mode(self) -> None:
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
         if quant_cfg is None:
-            # SparseML uses a "compression_config" with a "quantization_config".
-            compression_cfg = getattr(self.hf_config, "compression_config",
-                                      None)
-            if compression_cfg is not None:
-                quant_cfg = compression_cfg.get("quantization_config", None)
-
+            # compress-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
         return quant_cfg
 
     def _verify_quantization(self) -> None:
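
A self-contained sketch of the simplified fallback, using types.SimpleNamespace as a stand-in for hf_config (the stand-in object and the sample dict are illustrative, not part of the commit):

from types import SimpleNamespace

def parse_quant_hf_config(hf_config):
    # Prefer the standard "quantization_config" attribute; otherwise fall
    # back to the "compression_config" that compressed-tensors models carry.
    quant_cfg = getattr(hf_config, "quantization_config", None)
    if quant_cfg is None:
        quant_cfg = getattr(hf_config, "compression_config", None)
    return quant_cfg

# Illustrative config object: the whole compression_config dict is returned,
# rather than a nested "quantization_config" entry as before.
cfg = SimpleNamespace(compression_config={"quant_method": "compressed-tensors"})
assert parse_quant_hf_config(cfg) == {"quant_method": "compressed-tensors"}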

vllm/model_executor/layers/quantization/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
     "gptq_marlin": GPTQMarlinConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
-    "sparseml": CompressedTensorsConfig,
+    "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
 }
 
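
The visible effect of the rename is that the old registry key stops resolving; callers that passed quantization="sparseml" must switch to "compressed-tensors". A small sketch, assuming the module's get_quantization_config lookup helper:

from vllm.model_executor.layers.quantization import get_quantization_config

# The new key resolves to the compressed-tensors config class ...
print(get_quantization_config("compressed-tensors").__name__)
# CompressedTensorsConfig

# ... while the removed "sparseml" key is now rejected as invalid.
try:
    get_quantization_config("sparseml")
except ValueError as exc:
    print(exc)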

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 3 additions & 6 deletions
@@ -137,12 +137,9 @@ def get_quant_config(model_config: ModelConfig,
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)
     if hf_quant_config is None:
-        compression_config = getattr(model_config.hf_config,
-                                     "compression_config", None)
-        if compression_config is not None:
-            hf_quant_config = compression_config.get("quantization_config",
-                                                     None)
-
+        # compressed-tensors uses a compressions_config
+        hf_quant_config = getattr(model_config.hf_config, "compression_config",
+                                  None)
     if hf_quant_config is not None:
         return quant_cls.from_config(hf_quant_config)
     # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
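
A condensed sketch of the resulting control flow in get_quant_config (hf_config and quant_cls are stand-ins introduced here; the real function goes on to handle bitsandbytes/QLoRA adapter configs):

def resolve_quant_config(hf_config, quant_cls):
    hf_quant_config = getattr(hf_config, "quantization_config", None)
    if hf_quant_config is None:
        # Old format: details nested under compression_config["quantization_config"].
        # New format: the whole compression_config dict is handed directly to
        # the quantization class's from_config().
        hf_quant_config = getattr(hf_config, "compression_config", None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)
    return None  # the real code falls through to adapter-based lookup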
