Skip to content

Commit c677099

Browse files
dsikka authored and prashantgupta24 committed
[Misc] Update w4a16 compressed-tensors support to include w8a16 (vllm-project#5794)
1 parent 0439542 commit c677099

File tree

5 files changed

+36
-26
lines changed

5 files changed

+36
-26
lines changed

tests/quantization/test_compressed_tensors.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
from vllm import SamplingParams
1010
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
11-
CompressedTensorsLinearMethod, CompressedTensorsW4A16,
12-
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
13-
CompressedTensorsW8A8StaticTensor)
11+
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
12+
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
13+
CompressedTensorsWNA16)
1414

1515

1616
@pytest.mark.parametrize("model_args", [
@@ -74,26 +74,27 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
7474
assert qkv_proj.weight.dtype is torch.int8
7575

7676

77-
@pytest.mark.parametrize("w4a16_args", [
78-
("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None),
79-
("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128),
80-
])
81-
def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
82-
model, strategy, group = w4a16_args
77+
@pytest.mark.parametrize(
78+
"wNa16_args",
79+
[("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
80+
("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
81+
("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
82+
def test_compressed_tensors_w4a16(vllm_runner, wNa16_args):
83+
model, strategy, group, pack_factor = wNa16_args
8384
with vllm_runner(model) as llm:
8485
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
8586
layer = model.model.layers[0]
8687

8788
qkv_proj = layer.self_attn.qkv_proj
8889
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
89-
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16)
90+
assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
9091

9192
assert qkv_proj.scheme.strategy == strategy
9293
assert qkv_proj.scheme.group_size == group
9394

9495
assert qkv_proj.weight_packed.dtype is torch.int32
9596
assert qkv_proj.weight_scale.dtype is torch.float16
96-
assert qkv_proj.weight_packed.pack_factor == 8
97+
assert qkv_proj.weight_packed.pack_factor == pack_factor
9798

9899

99100
def test_compressed_tensors_w4a16_marlin24(vllm_runner):

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
88
QuantizationConfig)
99
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
10-
CompressedTensorsScheme, CompressedTensorsW4A16,
11-
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
12-
CompressedTensorsW8A8StaticTensor)
10+
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
11+
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
12+
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
13+
CompressedTensorsWNA16)
1314
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
1415
CompressionFormat, QuantizationArgs, QuantizationStrategy,
1516
find_first_name_or_class_match)
@@ -108,26 +109,31 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
108109

109110
return is_8_bits and is_token and is_symmetric and is_dynamic
110111

111-
def _is_w4a16(self, weight_quant: BaseModel,
112-
input_quant: BaseModel) -> bool:
112+
def _is_wNa16_group_channel(self, weight_quant: BaseModel,
113+
input_quant: BaseModel) -> bool:
113114
input_quant_none = input_quant is None
114-
is_4_bits = weight_quant.num_bits == 4
115115
is_symmetric = weight_quant.symmetric
116+
is_channel_group = (
117+
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
118+
or weight_quant.strategy == QuantizationStrategy.GROUP.value)
116119
is_static = not weight_quant.dynamic
117120

118-
return is_4_bits and input_quant_none and is_symmetric and is_static
121+
return (is_channel_group and input_quant_none and is_symmetric
122+
and is_static)
119123

120124
def _get_schema(self, weight_quant: BaseModel,
121125
input_quant: BaseModel) -> "CompressedTensorsScheme":
122126

123-
if self._is_w4a16(weight_quant, input_quant):
124-
if self.quant_format == CompressionFormat.marlin_24.value:
127+
if self._is_wNa16_group_channel(weight_quant, input_quant):
128+
if (self.quant_format == CompressionFormat.marlin_24.value
129+
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
125130
return CompressedTensorsW4A16Sparse24(
126131
strategy=weight_quant.strategy,
127132
num_bits=weight_quant.num_bits,
128133
group_size=weight_quant.group_size)
129-
if self.quant_format == CompressionFormat.pack_quantized.value:
130-
return CompressedTensorsW4A16(
134+
if (self.quant_format == CompressionFormat.pack_quantized.value
135+
and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
136+
return CompressedTensorsWNA16(
131137
num_bits=weight_quant.num_bits,
132138
strategy=weight_quant.strategy,
133139
group_size=weight_quant.group_size)
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
22
from .compressed_tensors_unquantized import ( # noqa: F401
33
CompressedTensorsUnquantized)
4-
from .compressed_tensors_w4a16 import CompressedTensorsW4A16 # noqa: F401
54
from .compressed_tensors_w4a16_24 import ( # noqa: F401
6-
CompressedTensorsW4A16Sparse24)
5+
W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
76
from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501
87
CompressedTensorsW8A8DynamicToken)
98
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
109
CompressedTensorsW8A8StaticTensor)
10+
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS # noqa: F401
11+
from .compressed_tensors_wNa16 import CompressedTensorsWNA16 # noqa: F401

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from vllm.model_executor.utils import set_weight_attrs
1212

1313
__all__ = ["CompressedTensorsW4A16Sparse24"]
14+
W4A16SPARSE24_SUPPORTED_BITS = [4]
1415

1516

1617
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py renamed to vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
marlin_permute_scales)
1212
from vllm.model_executor.utils import set_weight_attrs
1313

14-
__all__ = ["CompressedTensorsW4A16"]
14+
__all__ = ["CompressedTensorsWNA16"]
15+
WNA16_SUPPORTED_BITS = [4, 8]
1516

1617

17-
class CompressedTensorsW4A16(CompressedTensorsScheme):
18+
class CompressedTensorsWNA16(CompressedTensorsScheme):
1819

1920
def __init__(self,
2021
strategy: str,

0 commit comments

Comments
 (0)