
Commit 95db455

[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization (#5542)
1 parent 7879f24 commit 95db455
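For context, here is a toy sketch of what "w8a8 dynamic per token activation quantization" computes (symmetric int8; the names and the clamp guard are illustrative, not vLLM's kernel code). Each token row of the activation gets its own scale at runtime, while the int8 weights keep static scales, which this commit allows to be either one per tensor or one per output channel.

import torch

x = torch.randn(3, 8)  # 3 tokens, hidden size 8

# Dynamic per-token activation scales: one scale per row, computed at runtime.
x_scale = (x.abs().amax(dim=1, keepdim=True) / 127.0).clamp_min(1e-8)
x_q = torch.clamp((x / x_scale).round(), -128, 127).to(torch.int8)

# Dequantized roundtrip: the error per element is at most half a quantization step.
x_dq = x_q.float() * x_scale
assert (x - x_dq).abs().max() <= 0.5 * x_scale.max()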

File tree

4 files changed: +45 −32 lines changed

- tests/quantization/test_compressed_tensors.py
- vllm/model_executor/layers/linear.py
- vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
- vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py

tests/quantization/test_compressed_tensors.py

Lines changed: 9 additions & 4 deletions

@@ -14,7 +14,7 @@


 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -43,15 +43,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):


 def test_compressed_tensors_no_enforce_eager(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
     with vllm_runner(model_path) as llm:
         sampling_params = SamplingParams()
         output = llm.generate("Hello world!", sampling_params=sampling_params)
     assert output


-def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
+@pytest.mark.parametrize("model_args", [
+    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
+    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
+])
+def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
+    model_path, strategy = model_args
     with vllm_runner(model_path, dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -60,6 +64,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):

     assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
     assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
+    assert qkv_proj.scheme.strategy == strategy
     assert qkv_proj.weight.dtype is torch.int8
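The two parametrized checkpoints can also be exercised directly, outside the test harness. A minimal sketch (assuming the nm-testing repositories named above are reachable on the Hugging Face Hub):

from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
          dtype="float16")
outputs = llm.generate("Hello world!", SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)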

vllm/model_executor/layers/linear.py

Lines changed: 0 additions & 13 deletions

@@ -468,13 +468,6 @@ def weight_loader(self,
                 "MergedColumnParallelLinear, assume the weight is "
                 "the same for all partitions.")

-        if fp8_scales_shard_indexer is None:
-            if len(param_data.shape) == 0:
-                param_data = param_data.reshape(1)
-
-            if len(loaded_weight.shape) == 0:
-                loaded_weight = loaded_weight.reshape(1)
-
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

@@ -686,12 +679,6 @@ def weight_loader(self,
                 "QKVParallelLinear, assume the weight is the same "
                 "for all partitions.")

-        if len(param_data.shape) == 0:
-            param_data = param_data.reshape(1)
-
-        if len(loaded_weight.shape) == 0:
-            loaded_weight = loaded_weight.reshape(1)
-
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
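The deleted branches handled 0-dim scale tensors loaded from checkpoints; presumably they are no longer needed because scales are now always created with an explicit 1-D or 2-D shape (see the scheme change below). A standalone illustration of the corner case the old code papered over, not vLLM code:

import torch

param_data = torch.empty(1)        # parameter registered with shape (1,)
loaded_weight = torch.tensor(0.5)  # scalar (0-dim) tensor from a checkpoint

# Shapes differ, so the assert in weight_loader would fail without a reshape.
assert param_data.shape != loaded_weight.shape  # torch.Size([1]) vs torch.Size([])
param_data.copy_(loaded_weight.reshape(1))      # the workaround the commit removes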

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 8 additions & 6 deletions

@@ -95,14 +95,15 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
     def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
-        is_token_tensor = (weight_quant.strategy
-                           == QuantizationStrategy.TENSOR.value) and (
-                               input_quant.strategy
-                               == QuantizationStrategy.TOKEN.value)
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_token = (weight_strategy and input_quant.strategy
+                    == QuantizationStrategy.TOKEN.value)
         is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_dynamic = not weight_quant.dynamic and input_quant.dynamic

-        return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
+        return is_8_bits and is_token and is_symmetric and is_dynamic

     def _is_w4a16(self, weight_quant: BaseModel,
                   input_quant: BaseModel) -> bool:
@@ -133,7 +134,8 @@ def _get_schema(self, weight_quant: BaseModel,
             return CompressedTensorsW8A8StaticTensor()

         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8DynamicToken()
+            return CompressedTensorsW8A8DynamicToken(
+                strategy=weight_quant.strategy)

         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py

Lines changed: 28 additions & 9 deletions

@@ -6,13 +6,18 @@
 from vllm import _custom_ops as custom_ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
 from vllm.model_executor.utils import set_weight_attrs

 __all__ = ["CompressedTensorsW8A8DynamicToken"]


 class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

+    def __init__(self, strategy: str):
+        self.strategy = strategy
+
     def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
         if isinstance(shard_id, int):
             return shard_id
@@ -45,11 +50,17 @@ def create_weights(self, layer: torch.nn.Module,
         # CompressedTensorsW8A8StaticTensor::create_weights for further
         # information.
         is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(
-            output_partition_sizes) if is_tensor_partitioned else 1
+        # when doing channel-wise quantization, number of scales
+        # is equal to output_dim
+        weight_scale_dim = sum(output_partition_sizes) if (
+            is_tensor_partitioned
+            or self.strategy == QuantizationStrategy.CHANNEL) else 1
+
+        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            shape = (weight_scale_dim, 1)

-        weight_scale = Parameter(torch.empty(weight_scale_dim,
-                                             dtype=torch.float32),
+        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
                                  requires_grad=False)

         weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -67,12 +78,20 @@ def create_weights(self, layer: torch.nn.Module,
             })

         layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(
-            weight_scale, {
-                "weight_loader": weight_loader,
-                "shard_splitter": self.scales_shard_splitter,
-                "logical_widths": output_partition_sizes
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+
+        # Don't need a shard_splitter for channel-wise quantization
+        # Use the default loading method
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "output_dim": 0,
             })
+        else:
+            set_weight_attrs(
+                weight_scale, {
+                    "logical_widths": output_partition_sizes,
+                    "shard_splitter": self.scales_shard_splitter,
+                })

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
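Why channel-wise scales get shape (num_output_channels, 1) rather than (1,): the trailing singleton dimension lets the scale broadcast across each output row during dequantization. A toy sketch in plain PyTorch (not the fused kernel that apply_weights actually calls):

import torch

out_features, in_features = 4, 8
w_q = torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8)

scale_tensor = torch.rand(1)                 # strategy == "tensor": one scale
scale_channel = torch.rand(out_features, 1)  # strategy == "channel": one per row

w_dq_tensor = w_q.float() * scale_tensor     # broadcasts over the whole matrix
w_dq_channel = w_q.float() * scale_channel   # row i scaled by scale_channel[i]
assert w_dq_channel.shape == (out_features, in_features)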
