Commit 890d8d9

[Kernel] compressed-tensors marlin 24 support (#5435)
1 parent 9e74d9d commit 890d8d9

File tree: 5 files changed, +196 -19 lines

tests/quantization/test_compressed_tensors.py

Lines changed: 20 additions & 3 deletions
@@ -9,7 +9,8 @@
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)


 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -51,8 +52,7 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):

 def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
     model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
-    with vllm_runner(model_path, enforce_eager=True,
-                     dtype=torch.float16) as llm:
+    with vllm_runner(model_path, dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]

@@ -83,3 +83,20 @@ def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
         assert qkv_proj.weight_packed.dtype is torch.int32
         assert qkv_proj.weight_scale.dtype is torch.float16
         assert qkv_proj.weight_packed.pack_factor == 8
+
+
+def test_compressed_tensors_w4a16_marlin24(vllm_runner):
+    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
+    with vllm_runner(model_path) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+
+        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
+        assert qkv_proj.weight_packed.dtype is torch.int32
+
+        sampling_params = SamplingParams()
+        output = llm.generate("Hello world!", sampling_params=sampling_params)
+        assert output
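
For reference, the new test mirrors what an end user would do to run this checkpoint through vLLM's offline API. A minimal sketch, assuming the nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t checkpoint used by the test and a CUDA GPU with enough memory for a 7B model:

from vllm import LLM, SamplingParams

# Sketch only: the compressed-tensors quantization config stored in the
# checkpoint selects CompressedTensorsW4A16Sparse24 automatically, so no
# extra quantization flags are needed here.
llm = LLM(model="nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t")
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Hello world!"], sampling_params)
print(outputs[0].outputs[0].text)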

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 32 additions & 16 deletions
@@ -8,16 +8,20 @@
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme, CompressedTensorsW4A16,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
+    CompressionFormat, QuantizationArgs, QuantizationStrategy,
+    find_first_name_or_class_match)


 class CompressedTensorsConfig(QuantizationConfig):

-    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]):
+    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
+                 quant_format: str):
         self.ignore = ignore
         self.layer_quant_details = layer_quant_details
+        self.quant_format = quant_format

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -46,6 +50,7 @@ def get_quant_method(
     def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         layer_quant_details: Dict[str, Any] = dict()
         ignore: List[str] = config.get("ignore", None)
+        quant_format: str = config.get("format", None)

         # The quant_config has multiple config_groups, each containing
         # an input_activations key with details about how the activations are
@@ -69,7 +74,9 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         except Exception:
             layer_quant_details[target]["input_activations"] = None

-        return cls(layer_quant_details=layer_quant_details, ignore=ignore)
+        return cls(layer_quant_details=layer_quant_details,
+                   ignore=ignore,
+                   quant_format=quant_format)

     @classmethod
     def get_config_filenames(cls) -> List[str]:
@@ -110,17 +117,26 @@ def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":

         if self._is_w4a16(weight_quant, input_quant):
-            return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
-                                          strategy=weight_quant.strategy,
-                                          group_size=weight_quant.group_size)
-
-        if self._is_static_tensor_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8StaticTensor()
-
-        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8DynamicToken()
-
-        raise NotImplementedError("Scheme not supported.")
+            if self.quant_format == CompressionFormat.marlin_24.value:
+                return CompressedTensorsW4A16Sparse24(
+                    strategy=weight_quant.strategy,
+                    num_bits=weight_quant.num_bits,
+                    group_size=weight_quant.group_size)
+            if self.quant_format == CompressionFormat.pack_quantized.value:
+                return CompressedTensorsW4A16(
+                    num_bits=weight_quant.num_bits,
+                    strategy=weight_quant.strategy,
+                    group_size=weight_quant.group_size)
+
+        if self.quant_format == CompressionFormat.int_quantized.value:
+            if self._is_static_tensor_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8StaticTensor()
+
+            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8DynamicToken()
+
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")

     def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":

@@ -165,9 +181,9 @@ def create_weights(self, layer: torch.nn.Module,
         scheme = self.quantization_config.get_scheme(layer=layer)
         scheme.create_weights(
             layer=layer,
+            input_size=input_size,
             input_size_per_partition=input_size_per_partition,
             output_partition_sizes=output_partition_sizes,
-            input_size=input_size,
             output_size=output_size,
             params_dtype=params_dtype,
             weight_loader=weight_loader)
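
The scheme selection in _get_schema now keys on the checkpoint's compression format before looking at the weight/activation quantization arguments. The standalone sketch below (simplified, with assumed boolean flags standing in for vLLM's _is_* predicates) restates that dispatch order:

from enum import Enum


class Format(Enum):
    # Mirrors the CompressionFormat values added in utils.py.
    int_quantized = "int-quantized"
    pack_quantized = "pack-quantized"
    marlin_24 = "marlin-24"


def pick_scheme(quant_format: str, is_w4a16: bool, is_static_w8a8: bool,
                is_dynamic_w8a8: bool) -> str:
    # W4A16 checkpoints split by format: 2:4-sparse Marlin vs. dense pack-quantized.
    if is_w4a16:
        if quant_format == Format.marlin_24.value:
            return "CompressedTensorsW4A16Sparse24"
        if quant_format == Format.pack_quantized.value:
            return "CompressedTensorsW4A16"
    # W8A8 schemes are only reachable for int-quantized checkpoints.
    if quant_format == Format.int_quantized.value:
        if is_static_w8a8:
            return "CompressedTensorsW8A8StaticTensor"
        if is_dynamic_w8a8:
            return "CompressedTensorsW8A8DynamicToken"
    raise NotImplementedError("No compressed-tensors compatible scheme was found.")


print(pick_scheme("marlin-24", is_w4a16=True,
                  is_static_w8a8=False, is_dynamic_w8a8=False))
# -> CompressedTensorsW4A16Sparse24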

vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
 from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
+from .compressed_tensors_w4a16_24 import (  # noqa: F401
+    CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
     CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+from typing import Callable, List, Optional
+
+import torch
+from torch.nn import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
+from vllm.model_executor.utils import set_weight_attrs
+
+__all__ = ["CompressedTensorsW4A16Sparse24"]
+
+
+class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None):
+        self.strategy = strategy
+        self.group_size = group_size
+        self.num_bits = num_bits
+        self.tile_size = 16
+
+        if self.strategy == "group" and self.group_size is None:
+            raise ValueError(
+                "group_size must be given when using strategy group")
+
+    def create_weights(self, layer: torch.nn.Module, input_size: int,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        pack_factor = 32 // self.num_bits
+        output_size_per_partition = sum(output_partition_sizes)
+
+        qweight = Parameter(
+            torch.empty(
+                input_size_per_partition // self.tile_size // 2,
+                output_size_per_partition * self.tile_size // pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight,
+            {
+                "input_dim": 0,
+                "output_dim": 1,
+                "packed_dim": 1,
+                "pack_factor": pack_factor,
+                "marlin_tile_size": self.tile_size,
+                "weight_loader": weight_loader
+            },
+        )
+
+        layer.register_parameter("weight_packed", qweight)
+
+        input_groups = (1 if self.group_size is None else
+                        input_size_per_partition // self.group_size)
+
+        scales = Parameter(
+            torch.empty(
+                input_groups,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            scales,
+            {
+                "output_dim": 1,
+                "input_dim": None if input_groups == 1 else 0,
+                "weight_loader": weight_loader
+            },
+        )
+        layer.register_parameter("scale_packed", scales)
+
+        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_shape", weight_shape)
+        set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
+
+        meta = Parameter(
+            torch.empty(
+                input_size_per_partition // 8 // 2 // 2,
+                output_size_per_partition * 2,
+                dtype=torch.int16,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            meta,
+            {
+                "input_dim": 0,
+                "packed_dim": 1,
+                "pack_factor": 1,
+                "output_dim": 1,
+                "marlin_tile_size": 2,
+                "weight_loader": weight_loader
+            },
+        )
+        layer.register_parameter("meta", meta)
+
+        max_workspace_size = (
+            output_size_per_partition //
+            GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
+        workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
+                              requires_grad=False)
+        layer.workspace = workspace
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        qweight = layer.weight_packed
+        meta = layer.meta
+        scales = layer.scale_packed
+        workspace = layer.workspace
+
+        x_2d = x.view(-1, x.shape[-1])
+
+        size_m = x_2d.shape[0]
+        size_k = x_2d.shape[1]
+        size_n = scales.shape[1]
+
+        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
+                                            workspace, self.num_bits, size_m,
+                                            size_n, size_k)
+
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+        return output
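
To make the packed layouts above concrete, the arithmetic below (illustrative sizes of my choosing, not part of the commit) reproduces the shapes create_weights allocates for a 4096x4096 linear layer with num_bits=4 and the scheme's tile_size=16. The first qweight dimension is halved because 2:4 sparsity stores only half of the K dimension, each int32 packs eight 4-bit values, and the int16 meta tensor carries the 2:4 sparsity metadata consumed by gptq_marlin_24_gemm.

# Illustrative shape check (assumed layer sizes, no GPU required).
size_k, size_n, num_bits, tile_size = 4096, 4096, 4, 16

pack_factor = 32 // num_bits                        # 8 four-bit values per int32

qweight_shape = (size_k // tile_size // 2,          # // 2: only half of K survives 2:4 sparsity
                 size_n * tile_size // pack_factor)
meta_shape = (size_k // 8 // 2 // 2,                # 2:4 sparsity metadata layout
              size_n * 2)

print(qweight_shape)  # (128, 8192) int32 -> 4 MiB of packed weights
print(meta_shape)     # (128, 8192) int16 -> 2 MiB of metadata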

vllm/model_executor/layers/quantization/compressed_tensors/utils.py

Lines changed: 8 additions & 0 deletions
@@ -6,6 +6,14 @@
 from torch.nn import Module


+class CompressionFormat(Enum):
+    dense = "dense"
+    sparse_bitmask = "sparse-bitmask"
+    int_quantized = "int-quantized"
+    pack_quantized = "pack-quantized"
+    marlin_24 = "marlin-24"
+
+
 class QuantizationType(str, Enum):
     """
     Enum storing quantization type options
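
These string values correspond to the "format" field that CompressedTensorsConfig.from_config now reads from the checkpoint's quantization config. A tiny sanity check, assuming a hypothetical config dict:

# Hypothetical quantization config fragment; only the "format" key matters here.
quantization_config = {"format": "marlin-24", "ignore": ["lm_head"]}

quant_format = quantization_config.get("format", None)
assert quant_format == "marlin-24"  # equals CompressionFormat.marlin_24.value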
