
Commit 8b22a35

[Feat]: Add support for Dynamic Quant 4 bit CPU kleidiai kernels
Description:
1. Add optimized kernel support for Arm 4 bit matmul kernels

Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
1 parent 6e4bef1 commit 8b22a35

6 files changed, +278 -11 lines

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 29 additions & 3 deletions
@@ -1,3 +1,4 @@
+# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate open-source-office@arm.com
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
@@ -25,7 +26,7 @@
     CompressedTensorsMoEMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
-    CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
+    CompressedTensorsScheme, CompressedTensorsW4A4Fp4, CompressedTensorsW4A8,
     CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
@@ -74,7 +75,7 @@ def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
 
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
-        return [torch.float16, torch.bfloat16]
+        return [torch.float32, torch.float16, torch.bfloat16]
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -299,6 +300,22 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
 
+    def _is_dynamic_token_w4a8(self, weight_quant: BaseModel,
+                               input_quant: BaseModel) -> bool:
+        is_weight_4_bits = weight_quant.num_bits == 4
+        is_activation_8_bits = input_quant.num_bits == 8
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.GROUP.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_token = (weight_strategy and input_quant.strategy
+                    == QuantizationStrategy.TOKEN.value)
+        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
+
+        # Both symmetric and asymmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return (is_weight_4_bits and is_activation_8_bits and is_token
+                and weight_quant.symmetric and is_dynamic)
+
     def _is_fp8_w8a8(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
         # Confirm weights and activations quantized.
@@ -368,7 +385,6 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
     def _get_scheme_from_parts(
             self, weight_quant: BaseModel,
             input_quant: BaseModel) -> "CompressedTensorsScheme":
-
         # Detect If Mixed Precision
         if self._is_fp4a16_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A16Fp4()
@@ -437,6 +453,16 @@ def _get_scheme_from_parts(
                 is_static_input_scheme=False,
                 input_symmetric=input_quant.symmetric)
 
+        if self._is_dynamic_token_w4a8(weight_quant, input_quant):
+            is_static_input_scheme = (input_quant
+                                      and not input_quant.dynamic)
+            return CompressedTensorsW4A8(
+                num_bits=weight_quant.num_bits,
+                strategy=weight_quant.strategy,
+                group_size=weight_quant.group_size,
+                is_static_input_scheme=is_static_input_scheme,
+                input_symmetric=input_quant.symmetric)
+
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
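
For reference, the new `_is_dynamic_token_w4a8` check only inspects a handful of fields on the parsed weight and activation quantization args. Below is a minimal standalone sketch of the same predicate, with plain objects standing in for the compressed-tensors `BaseModel` classes; the lowercase strategy strings are assumed to match the `QuantizationStrategy` enum values:

```python
from types import SimpleNamespace

def is_dynamic_token_w4a8(weight_quant, input_quant) -> bool:
    # Mirrors the check added in compressed_tensors.py: 4-bit symmetric
    # weights (group- or channel-wise), 8-bit dynamic per-token activations.
    weight_strategy = weight_quant.strategy in ("group", "channel")
    return (weight_quant.num_bits == 4 and input_quant.num_bits == 8
            and weight_strategy and input_quant.strategy == "token"
            and not weight_quant.dynamic and input_quant.dynamic
            and weight_quant.symmetric)

# Illustrative config fragment that would select CompressedTensorsW4A8.
weight_quant = SimpleNamespace(num_bits=4, strategy="group", group_size=32,
                               symmetric=True, dynamic=False)
input_quant = SimpleNamespace(num_bits=8, strategy="token",
                              symmetric=True, dynamic=True)
assert is_dynamic_token_w4a8(weight_quant, input_quant)
```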

vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,8 +1,10 @@
+# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate open-source-office@arm.com
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
+from .compressed_tensors_w4a8 import CompressedTensorsW4A8
 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
                                           CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
@@ -20,5 +22,5 @@
     "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
     "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
     "CompressedTensors24", "CompressedTensorsW4A16Fp4",
-    "CompressedTensorsW4A4Fp4"
+    "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8"
 ]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8.py

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate open-source-office@arm.com
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Callable, List, Optional, Set
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
+    MPLinearLayerConfig, choose_mp_linear_kernel)
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           ModelWeightParameter)
+from vllm.scalar_type import scalar_types
+
+logger = init_logger(__name__)
+
+__all__ = ["CompressedTensorsW4A8"]
+W4A8_SUPPORTED_TYPES_MAP = {
+    4: scalar_types.int4,
+}
+W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
+
+
+class CompressedTensorsW4A8(CompressedTensorsScheme):
+    _kernel_backends_being_used: Set[str] = set()
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None,
+                 is_static_input_scheme: bool = False,
+                 input_symmetric: bool = True):
+        self.strategy = strategy
+        self.group_size = -1 if group_size is None else group_size
+        self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric
+
+        if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
+            raise ValueError(
+                f"Unsupported num_bits = {num_bits}."
+                f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
+        self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 1
+
+    def create_weights(self, layer: torch.nn.Module, output_size: int,
+                       input_size: int, output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        row_parallel = (input_size != input_size_per_partition)
+
+        # Compute effective group_size
+        if self.group_size == -1:
+            effective_group_size = (input_size_per_partition
+                                    if row_parallel else input_size)
+        else:
+            effective_group_size = self.group_size
+
+        # Ensure group_size divides input_size_per_partition
+        assert input_size_per_partition % effective_group_size == 0, (
+            f"input_size_per_partition {input_size_per_partition}"
+            f" not divisible by group_size {effective_group_size}")
+
+        # Determine scale partitioning
+        is_channelwise = (self.group_size == -1)
+        repeat_scales = (is_channelwise and row_parallel)
+        partition_scales = not repeat_scales
+
+        mp_linear_kernel_config = MPLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=(input_size_per_partition,
+                                    output_size_per_partition),
+            weight_type=self.quant_type,
+            act_type=params_dtype,
+            group_size=effective_group_size,
+            zero_points=False,
+            has_g_idx=False,
+        )
+
+        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for CompressedTensorsW4A8",
+                        kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
+
+        scales_and_zp_size = input_size_per_partition // effective_group_size
+
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=torch.int8),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
+        layer.register_parameter("weight", weight)
+
+        weight_scale_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.empty(output_size_per_partition,
+                        scales_and_zp_size,
+                        dtype=params_dtype)
+        }
+
+        if partition_scales:
+            weight_scale = GroupQuantScaleParameter(output_dim=0,
+                                                    input_dim=1,
+                                                    **weight_scale_args)
+        else:
+            weight_scale = ChannelQuantScaleParameter(output_dim=0,
+                                                      **weight_scale_args)
+
+        layer.register_parameter("weight_packed", weight)
+        layer.register_parameter("weight_scale", weight_scale)
+
+        self.kernel = kernel_type(mp_linear_kernel_config,
+                                  w_q_param_name="weight_packed",
+                                  w_s_param_name="weight_scale",
+                                  w_zp_param_name=None,
+                                  w_gidx_param_name=None)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        self.kernel.process_weights_after_loading(layer)
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        return self.kernel.apply_weights(layer, x, bias)
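
To illustrate what `create_weights` above registers, here is a small sketch of the resulting tensor shapes for made-up layer dimensions (the sizes are illustrative, not taken from the commit):

```python
# Hypothetical partition sizes for a column-parallel linear layer.
input_size = input_size_per_partition = 4096
output_partition_sizes = [11008]
group_size = 32

output_size_per_partition = sum(output_partition_sizes)
scales_and_zp_size = input_size_per_partition // group_size

# "weight_packed": one int8 value per 4-bit weight, repacked later by the kernel.
weight_shape = (output_size_per_partition, input_size_per_partition)  # (11008, 4096)
# "weight_scale": one scale per group of 32 input features.
weight_scale_shape = (output_size_per_partition, scales_and_zp_size)  # (11008, 128)
print(weight_shape, weight_scale_shape)
```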

vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py

Lines changed: 11 additions & 6 deletions
@@ -1,3 +1,4 @@
+# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate open-source-office@arm.com
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
@@ -8,6 +9,8 @@
     AllSparkLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import (  # noqa: E501
     BitBLASLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import (  # noqa: E501
+    Dynamic4bitLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import (  # noqa: E501
     ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import (  # noqa: E501
@@ -20,6 +23,7 @@
 
 # in priority/performance order (when available)
 _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
+    Dynamic4bitLinearKernel,
     MacheteLinearKernel,
     AllSparkLinearKernel,
     MarlinLinearKernel,
@@ -53,20 +57,21 @@ def choose_mp_linear_kernel(
         if current_platform is None:
             raise ValueError("Cannot determine compute capability")
         _cc = current_platform.get_device_capability()
-        compute_capability = _cc[0] * 10 + _cc[1]
+        if _cc is not None:
+            compute_capability = _cc[0] * 10 + _cc[1]
 
     failure_reasons = []
     for kernel in _POSSIBLE_KERNELS:
         if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
             failure_reasons.append(
                 f' {kernel.__name__} disabled by environment variable')
             continue
-
-        if kernel.get_min_capability() > compute_capability:
+        if (compute_capability is not None
+                and kernel.get_min_capability() > compute_capability):
             failure_reasons.append(
                 f"{kernel.__name__} requires capability "
-                f"{kernel.get_min_capability()}, current compute capability "
-                f"is {compute_capability}")
+                f"{kernel.get_min_capability()}, current compute "
+                f" capability is {compute_capability}")
             continue
 
         can_implement, failure_reason = kernel.can_implement(config)
@@ -80,4 +85,4 @@ def choose_mp_linear_kernel(
     raise ValueError(
         "Failed to find a kernel that can implement the "\
         "WNA16 linear layer. Reasons: \n"
-            + '\n'.join(failure_reasons))
+        + '\n'.join(failure_reasons))
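
Because `Dynamic4bitLinearKernel` now sits first in `_POSSIBLE_KERNELS`, a config like the one `CompressedTensorsW4A8.create_weights` builds should resolve to it on a supported Arm CPU. A hedged sketch of that selection (field values are illustrative; the actual result depends on the platform and PyTorch build):

```python
import torch

from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
    MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.scalar_type import scalar_types

config = MPLinearLayerConfig(
    full_weight_shape=(4096, 11008),      # (input_size, output_size)
    partition_weight_shape=(4096, 11008),
    weight_type=scalar_types.int4,
    act_type=torch.float32,               # Arm path requires fp32 activations
    group_size=32,
    zero_points=False,
    has_g_idx=False,
)

# On an Arm CPU build with a recent enough PyTorch this is expected to return
# Dynamic4bitLinearKernel; otherwise the next viable kernel in the list wins.
kernel_cls = choose_mp_linear_kernel(config)
print(kernel_cls.__name__)
```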
vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate open-source-office@arm.com
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional, Tuple
+
+import torch
+
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.platforms import CpuArchEnum, current_platform
+from vllm.scalar_type import scalar_types
+
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+
+class Dynamic4bitLinearKernel(MPLinearKernel):
+    SUPPORTED_QUANT_TYPES = [scalar_types.int4]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 1
+
+    @classmethod
+    def can_implement(cls,
+                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        if not current_platform.is_cpu():
+            return False, "Only CPU is supported"
+        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
+            return False, f"Unsupported quant type {c.weight_type}"
+        if current_platform.get_cpu_architecture(
+        ) == CpuArchEnum.ARM and c.act_type not in [
+                torch.float32,
+        ]:
+            return False, "Dynamic4bitLinearKernel on Arm requires"\
+                " Float32 activations"
+        if c.full_weight_shape[0] % c.group_size != 0:
+            return False, f"Group size ({c.group_size}) does not evenly divide"\
+                " the number of input features "\
+                f"({c.full_weight_shape[0]})"
+        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
+            try:
+                # Attempt to retrieve the operation
+                _ = torch.ops.aten._dyn_quant_matmul_4bit
+            except AttributeError:
+                return False, f"PyTorch {torch.__version__} does not support"\
+                    " _dyn_quant_matmul_4bit. Install a newer version"
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        c = self.config
+        packed_weight = getattr(layer, self.w_q_name)
+        packed_weight = packed_weight.add(8)
+        uint8_packed = (packed_weight[::, 1::2] << 4
+                        | packed_weight[::, ::2]).to(torch.uint8)
+
+        scales = getattr(layer, self.w_s_name)
+        block_size = c.group_size
+
+        # Handle scaling factors for partitioned weights
+        if block_size == c.partition_weight_shape[0]:
+            scales = scales.to(
+                torch.float32
+            )  # Float32 & Bfloat16 variants requires float32 scales
+            scales = scales.view(-1, 1)  # Channel-wise scales
+            if layer.bias is not None:
+                layer.bias = layer.bias.to(
+                    torch.float32
+                )  # Float32 & Bfloat16 variants requires float32 bias
+        else:
+            # KleidiAI kernel requires bfloat16 scales with groupwise scheme
+            scales = scales.to(torch.bfloat16)
+
+        # Repack weights as per kernel requirement
+        w = torch.ops.aten._dyn_quant_pack_4bit_weight(
+            uint8_packed, scales, layer.bias, block_size,
+            c.partition_weight_shape[0], c.partition_weight_shape[1])
+        replace_parameter(layer, self.w_q_name,
+                          torch.nn.Parameter(w, requires_grad=False))
+        setattr(layer, self.w_s_name, None)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        c = self.config
+        x_2d = x.reshape(-1, x.shape[-1])
+        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
+
+        w_q = getattr(layer, self.w_q_name)
+        output = torch.ops.aten._dyn_quant_matmul_4bit(
+            x_2d, w_q, c.group_size, c.partition_weight_shape[0],
+            c.partition_weight_shape[1])
+        return output.reshape(out_shape)
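
To make the repacking in `process_weights_after_loading` concrete, the following standalone snippet reproduces just the nibble-packing step that precedes the `torch.ops.aten._dyn_quant_pack_4bit_weight` call (toy values, no vLLM imports):

```python
import torch

# Two int4 weights per byte: signed [-8, 7] values are shifted to [0, 15],
# then odd columns go to the high nibble and even columns to the low nibble,
# mirroring process_weights_after_loading above.
int4_weights = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int8)  # shape (1, 4)
shifted = int4_weights.add(8)                                    # [0, 15, 8, 7]
uint8_packed = (shifted[:, 1::2] << 4 | shifted[:, ::2]).to(torch.uint8)
print(uint8_packed)  # tensor([[240, 120]]) -> 0xF0 = (15<<4)|0, 0x78 = (7<<4)|8
```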
