diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index aaa366335d196..6eb7ff72fb11d 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -8,9 +8,9 @@
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW4A16,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
+    CompressedTensorsWNA16)
 
 
 @pytest.mark.parametrize("model_args", [
@@ -74,26 +74,27 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
 
         assert qkv_proj.weight.dtype is torch.int8
 
 
-@pytest.mark.parametrize("w4a16_args", [
-    ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None),
-    ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128),
-])
-def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
-    model, strategy, group = w4a16_args
+@pytest.mark.parametrize(
+    "wNa16_args",
+    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
+     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
+     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
+def test_compressed_tensors_w4a16(vllm_runner, wNa16_args):
+    model, strategy, group, pack_factor = wNa16_args
     with vllm_runner(model) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
 
         qkv_proj = layer.self_attn.qkv_proj
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
 
         assert qkv_proj.scheme.strategy == strategy
         assert qkv_proj.scheme.group_size == group
 
         assert qkv_proj.weight_packed.dtype is torch.int32
         assert qkv_proj.weight_scale.dtype is torch.float16
-        assert qkv_proj.weight_packed.pack_factor == 8
+        assert qkv_proj.weight_packed.pack_factor == pack_factor
 
 
 def test_compressed_tensors_w4a16_marlin24(vllm_runner):
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 44dd024afe74d..c69e2f3bcf9fa 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -7,9 +7,10 @@
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW4A16,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
+    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
@@ -108,26 +109,31 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
 
         return is_8_bits and is_token and is_symmetric and is_dynamic
 
-    def _is_w4a16(self, weight_quant: BaseModel,
-                  input_quant: BaseModel) -> bool:
+    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
+                                input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
-        is_4_bits = weight_quant.num_bits == 4
         is_symmetric = weight_quant.symmetric
+        is_channel_group = (
+            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+            or weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_static = not weight_quant.dynamic
 
-        return is_4_bits and input_quant_none and is_symmetric and is_static
+        return (is_channel_group and input_quant_none and is_symmetric
+                and is_static)
 
     def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
-        if self._is_w4a16(weight_quant, input_quant):
-            if self.quant_format == CompressionFormat.marlin_24.value:
+        if self._is_wNa16_group_channel(weight_quant, input_quant):
+            if (self.quant_format == CompressionFormat.marlin_24.value
+                    and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(
                     strategy=weight_quant.strategy,
                     num_bits=weight_quant.num_bits,
                     group_size=weight_quant.group_size)
-            if self.quant_format == CompressionFormat.pack_quantized.value:
-                return CompressedTensorsW4A16(
+            if (self.quant_format == CompressionFormat.pack_quantized.value
+                    and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
+                return CompressedTensorsWNA16(
                     num_bits=weight_quant.num_bits,
                     strategy=weight_quant.strategy,
                     group_size=weight_quant.group_size)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index 3c95aa11fc76c..f6d20ce2c6f77 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -1,10 +1,11 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
-from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
 from .compressed_tensors_w4a16_24 import (  # noqa: F401
-    CompressedTensorsW4A16Sparse24)
+    W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
     CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
     CompressedTensorsW8A8StaticTensor)
+from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS  # noqa: F401
+from .compressed_tensors_wNa16 import CompressedTensorsWNA16  # noqa: F401
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
index d7e04ddb8d94a..607029c819ddb 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -11,6 +11,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 
 __all__ = ["CompressedTensorsW4A16Sparse24"]
+W4A16SPARSE24_SUPPORTED_BITS = [4]
 
 
 class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
similarity index 98%
rename from vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py
rename to vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index 373458cfffe04..7707ea6ee94bc 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -11,10 +11,11 @@
     marlin_permute_scales)
 from vllm.model_executor.utils import set_weight_attrs
 
-__all__ = ["CompressedTensorsW4A16"]
+__all__ = ["CompressedTensorsWNA16"]
+WNA16_SUPPORTED_BITS = [4, 8]
 
 
-class CompressedTensorsW4A16(CompressedTensorsScheme):
+class CompressedTensorsWNA16(CompressedTensorsScheme):
 
     def __init__(self,
                  strategy: str,
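Note on the pack_factor assertions in the updated test: quantized weights are packed into int32 storage (the test asserts weight_packed.dtype is torch.int32), so the pack factor is 32 // num_bits, giving 8 for the two w4a16 models and 4 for the w8a16 per-channel model. A minimal Python sketch of that relationship; the pack_factor helper below is illustrative only and not part of this diff:

def pack_factor(num_bits: int) -> int:
    # Each int32 word holds 32 // num_bits quantized values.
    assert 32 % num_bits == 0, "num_bits must evenly divide 32"
    return 32 // num_bits

assert pack_factor(4) == 8  # w4a16 test cases above
assert pack_factor(8) == 4  # w8a16-per-channel test case above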