Commit 59323a1

sakogan authored and 0xrushi committed
[Quantization] [Performance] Enable Marlin GEMM kernels for the calibration-free RTN-based quantization (vllm-project#26051)
Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
Signed-off-by: Alex Kogan <82225080+sakogan@users.noreply.github.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
1 parent 50af0fb commit 59323a1

File tree

2 files changed: +233 lines, -56 lines

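For context, the round-to-nearest (RTN) scheme named in the commit title is calibration-free: weights are split into groups of GROUP_SIZE along the input dimension, each group gets one scale, and values are rounded to the nearest representable level (the uint4b8/uint8b128 scalar types in the diff below store them with an implicit zero point). A minimal sketch of the idea, assuming symmetric per-group absmax scaling; the helper name is hypothetical and this is not the code in this commit:

import torch


def rtn_quantize_sketch(w: torch.Tensor, bits: int = 4, group_size: int = 128):
    """Illustrative group-wise round-to-nearest quantization (hypothetical helper)."""
    n, k = w.shape
    grouped = w.reshape(n, k // group_size, group_size)
    qmax = 2 ** (bits - 1) - 1                                    # 7 for 4-bit, 127 for 8-bit
    scale = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(grouped / scale), -qmax - 1, qmax)
    return q.reshape(n, k), scale.squeeze(-1)


w = torch.randn(256, 512)
q, scale = rtn_quantize_sketch(w)
w_hat = (q.reshape(256, 512 // 128, 128) * scale.unsqueeze(-1)).reshape(256, 512)
assert (w - w_hat).abs().max() <= scale.max() / 2 + 1e-6          # error bounded by half a step

What this commit changes is not the quantization itself but the execution path: instead of dequantizing and running a dense matmul, the quantized weights are repacked once into the Marlin layout and multiplied directly by Marlin GEMM kernels.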

vllm/model_executor/layers/quantization/rtn.py

Lines changed: 188 additions & 56 deletions
@@ -6,21 +6,16 @@
 from collections.abc import Callable
 from typing import Any, Optional
 
+import numpy as np
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (
-    FusedMoE,
-    FusedMoEConfig,
-    FusedMoEMethodBase,
-)
 from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
     FusedMoEQuantConfig,
-    int4_w4a16_moe_quant_config,
-    int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -31,6 +26,12 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    apply_rtn_marlin_linear,
+    marlin_make_workspace_new,
+)
+from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
 """By default, use 8 bit as target precision, but it can be
@@ -41,6 +42,9 @@
 overridden by setting the RTN_GROUP_SIZE envvar
 """
 GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")
+"""Global Marlin workspace shared by all modules
+"""
+workspace = None
 
 
 class RTNConfig(QuantizationConfig):
@@ -60,6 +64,10 @@ def __init__(
                 f"supported for RTN, but got {self.weight_bits} bits."
             )
 
+        self.quant_type = (
+            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
+        )
+
     def __repr__(self) -> str:
         return (
             f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
@@ -221,24 +229,32 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        fix_weights(layer, "weight")
+        """Repack weights and scales for Marlin kernels."""
+        weight_bits = self.quant_config.weight_bits
+
+        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)
+
+        replace_parameter(layer, "weight", weight)
+        replace_parameter(layer, "scale", scale)
+
+        init_workspace(layer.weight.device)
 
     def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
-        qweight = layer.weight
-        scale = layer.scale
-
-        weight = rtn_dequantize(qweight, scale)
-        out = F.linear(x, weight)
-        del weight
-        if bias is not None:
-            out.add_(bias)
-
-        return out
+        return apply_rtn_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.scale,
+            workspace=workspace,
+            quant_type=self.quant_config.quant_type,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            bias=bias,
+        )
 
 
 class RTNMoEMethod(FusedMoEMethodBase):
@@ -315,28 +331,27 @@ def create_weights(
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Repack weights and scales for Marlin kernels."""
         weight_bits = self.quant_config.weight_bits
-        fix_weights(layer, "w13_weight", weight_bits == 4)
-        fix_weights(layer, "w2_weight", weight_bits == 4)
+
+        w13_weight, w13_scale = repack_weights(
+            layer.w13_weight, layer.w13_scale, weight_bits
+        )
+        replace_parameter(layer, "w13_weight", w13_weight)
+        replace_parameter(layer, "w13_scale", w13_scale)
+
+        w2_weight, w2_scale = repack_weights(
+            layer.w2_weight, layer.w2_scale, weight_bits
+        )
+        replace_parameter(layer, "w2_weight", w2_weight)
+        replace_parameter(layer, "w2_scale", w2_scale)
+
+        init_workspace(layer.w13_weight.device)
 
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        weight_bits = self.quant_config.weight_bits
-        group_size = self.quant_config.group_size
-        assert weight_bits == 4 or weight_bits == 8
-        config_builder = (
-            int4_w4a16_moe_quant_config
-            if weight_bits == 4
-            else int8_w8a16_moe_quant_config
-        )
-        return config_builder(
-            w1_scale=layer.w13_scale,
-            w2_scale=layer.w2_scale,
-            w1_zp=None,
-            w2_zp=None,
-            block_shape=[0, group_size],
-        )
+        return None
 
     def apply(
         self,
@@ -366,8 +381,6 @@ def apply(
         if enable_eplb:
             raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -383,18 +396,22 @@ def apply(
             indices_type=self.topk_indices_dtype,
         )
 
-        return fused_experts(
+        return torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
+            getattr(layer, "w13_bias", None),
+            getattr(layer, "w2_bias", None),
+            layer.w13_scale,
+            layer.w2_scale,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            quant_type_id=self.quant_config.quant_type.id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            quant_config=self.moe_quant_config,
+            workspace=workspace,
         )
 
 
@@ -504,18 +521,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     return input_deq
 
 
-def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False):
-    """torch.compile does not know how to deal with a Parameter subclass
-    (aka RTNParameter). As we don't really need RTNParameters for the
-    forward pass, we replace them with equivalent instances of Parameters.
+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm_arr = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel()
+    perm_tensor = torch.from_numpy(perm_arr)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm_tensor, scale_perm, scale_perm_single
+
+
+_perm, _scale_perm, _scale_perm_single = _get_perms()
+
+
+def pack_for_marlin(weight, scale, qbits):
+    batch = weight.shape[0]
+
+    n = weight.size(1)
+    k = weight.size(2)
+    groupsize = k // scale.size(2)
+
+    tile = 16
+    s = scale.permute(0, 2, 1)  # transpose
+    w = weight.permute(0, 2, 1)  # transpose
+    if groupsize != k:
+        w = w.reshape((batch, -1, groupsize, n))
+        w = w.permute(0, 2, 1, 3)
+        w = w.reshape((batch, groupsize, -1))
+        s = s.reshape((batch, 1, -1))
+
+    if groupsize != k:
+        w = w.reshape((batch, groupsize, -1, n))
+        w = w.permute(0, 2, 1, 3)
+        w = w.reshape((batch, k, n)).contiguous()
+        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
+    else:
+        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
+    s = s.reshape((batch, -1, n)).contiguous()
+    w = w.reshape((batch, k // tile, tile, n // tile, tile))
+    w = w.permute((0, 1, 3, 2, 4))
+    w = w.reshape((batch, k // tile, n * tile))
+    res = w
+    res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
+    if qbits == 4:
+        q = torch.zeros(
+            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
+        )
+        for i in range(2):
+            q |= res[:, :, i::2] << 4 * i
+        q = q.reshape(batch, -1, n).contiguous()
+    else:
+        q = res.clone()
+        q[:, :, 2::8] = res[:, :, 4::8]
+        q[:, :, 3::8] = res[:, :, 5::8]
+        q[:, :, 4::8] = res[:, :, 2::8]
+        q[:, :, 5::8] = res[:, :, 3::8]
+        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()
+
+    return q, s
+
+
+def repack_8bit_into_32bit(input):
+    output = torch.zeros(
+        (input.shape[0], input.shape[1], input.shape[2] // 4),
+        dtype=torch.int32,
+        device=input.device,
+    )
+    for i in range(4):
+        output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i
+
+    return output
+
+
+def repack_weights(qweight, scale, weight_bits):
+    batch_present = len(qweight.shape) == 3
+    if not batch_present:
+        qweight = qweight.unsqueeze(0)
+        scale = scale.unsqueeze(0)
+
+    if weight_bits == 4:
+        """Unpack two 4-bit values from each byte.
+        """
+        qweight_unpacked = torch.empty(
+            (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
+            dtype=torch.uint8,
+            device=qweight.device,
+        )
+        for i in range(2):
+            qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
+                qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
+            )
+    else:
+        qweight_unpacked = qweight
+
+    qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)
+    """Marlin kernels expect tensors in int32 format in a certain shape
     """
-    old_weight = getattr(layer, param_name)
-    assert isinstance(old_weight, RTNParameter)
-    data = old_weight.data.data
+    qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
+    qweight_reshaped = qweight_repacked.reshape(
+        qweight.shape[0], qweight.shape[2] // 16, -1
+    )
+    if not batch_present:
+        qweight_reshaped = qweight_reshaped.squeeze(0)
+        scale_packed = scale_packed.squeeze(0)
+
+    return qweight_reshaped, scale_packed
 
-    delattr(layer, param_name)
 
-    if reshape:
-        data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
-    new_weight = Parameter(data=data, requires_grad=False)
-    layer.register_parameter(param_name, new_weight)
+def init_workspace(device):
+    global workspace
+    if workspace is None:
+        workspace = marlin_make_workspace_new(device, 4)
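The repack_8bit_into_32bit helper above packs four consecutive bytes along the last dimension into one int32 word, least-significant byte first, which is the layout the Marlin kernels consume. A tiny self-contained check of that bit pattern; the values and shapes here are made up for illustration and are not part of the commit:

import torch

# Eight bytes in one row; repacking should yield two int32 words,
# little-endian within each word.
b = torch.tensor([[[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]]], dtype=torch.uint8)

packed = torch.zeros((1, 1, 2), dtype=torch.int32)
for i in range(4):
    # Same pattern as repack_8bit_into_32bit: byte i of each word lands at bit offset 8 * i.
    packed |= (b[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i

assert packed[0, 0, 0].item() == 0x04030201
assert packed[0, 0, 1].item() == 0x08070605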

vllm/model_executor/layers/quantization/utils/marlin_utils.py

Lines changed: 45 additions & 0 deletions
@@ -528,3 +528,48 @@ def apply_awq_marlin_linear(
     )
 
     return output.reshape(out_shape)
+
+
+def apply_rtn_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    quant_type: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    bias: torch.Tensor | None = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
+
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        bias,
+        weight_scale,
+        None,
+        None,
+        None,
+        None,
+        workspace,
+        quant_type,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
+
+    return output.reshape(out_shape)
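The shape bookkeeping in apply_rtn_marlin_linear flattens the activation to 2-D for the GEMM and restores the leading dimensions afterwards. A small stand-alone illustration of just that bookkeeping (stand-in tensors only; the real kernel call needs a CUDA device and Marlin-packed weights):

import torch

x = torch.randn(2, 5, 64)                      # [batch, seq, input_size_per_partition]
output_size_per_partition = 96

reshaped_x = x.reshape(-1, x.shape[-1])        # size_m = 10, size_k = 64 for the GEMM
out_shape = x.shape[:-1] + (output_size_per_partition,)

fake_gemm_out = torch.empty(reshaped_x.shape[0], output_size_per_partition)
assert fake_gemm_out.reshape(out_shape).shape == (2, 5, 96)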
