[Kernel] w4a16 support for compressed-tensors (vllm-project#5385)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
dsikka and robertgshaw2-neuralmagic authored Jun 13, 2024
1 parent 8840753 commit c2637a6
Showing 4 changed files with 230 additions and 10 deletions.
27 changes: 25 additions & 2 deletions tests/quantization/test_compressed_tensors.py
@@ -3,12 +3,13 @@
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""

import pytest
import torch

from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW4A16,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -60,3 +61,25 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
assert qkv_proj.weight.dtype is torch.int8


@pytest.mark.parametrize("w4a16_args", [
("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None),
("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128),
])
def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
model, strategy, group = w4a16_args
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16)

assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.group_size == group

assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.weight_packed.pack_factor == 8
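Note: the two checkpoints parametrized above cover the channel-wise and group-wise (group_size=128) w4a16 variants. As a rough sketch (not part of the commit), the same path can be exercised through the public vllm.LLM API, assuming the checkpoint's compressed-tensors quantization config is detected automatically, as it is by the vllm_runner fixture:

from vllm import LLM, SamplingParams

# Sketch only: load one of the w4a16 test checkpoints referenced above and
# run a short greedy generation.
llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-group128-v2")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)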
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -7,8 +7,8 @@
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsScheme, CompressedTensorsW4A16,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)

@@ -47,16 +47,27 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
layer_quant_details: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)

# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# quantized, a weights key indicating how the weights are quantized,
# and a list of targets under the `targets` key, dictating which
# layers are impacted by the quantization details. The quantization
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for key, quant_config in config["config_groups"].items():
targets = quant_config.get("targets")
for target in targets:
layer_quant_details[target] = {}
layer_quant_details[target][
-    "weight"] = QuantizationArgs.parse_obj(
+    "weights"] = QuantizationArgs.parse_obj(
        quant_config.get("weights"))
-layer_quant_details[target][
-    "input"] = QuantizationArgs.parse_obj(
-        quant_config.get("input_activations"))
+try:
+    layer_quant_details[target][
+        "input_activations"] = QuantizationArgs.parse_obj(
+            quant_config.get("input_activations"))
+except Exception:
+    layer_quant_details[target]["input_activations"] = None

return cls(layer_quant_details=layer_quant_details, ignore=ignore)

@@ -86,8 +97,23 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,

return is_8_bits and is_token_tensor and is_symmetric and is_dynamic

def _is_w4a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None
is_4_bits = weight_quant.num_bits == 4
is_symmetric = weight_quant.symmetric
is_static = not weight_quant.dynamic

return is_4_bits and input_quant_none and is_symmetric and is_static

def _get_schema(self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":

if self._is_w4a16(weight_quant, input_quant):
return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)

if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8StaticTensor()

@@ -113,8 +139,9 @@ def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
raise ValueError(
f"Could not find quantization details for {layer}.")

-return self._get_schema(weight_quant=layer_quant_details["weight"],
-                        input_quant=layer_quant_details["input"])
+return self._get_schema(
+    weight_quant=layer_quant_details["weights"],
+    input_quant=layer_quant_details["input_activations"])


class CompressedTensorsLinearMethod(LinearMethodBase):
@@ -140,6 +167,7 @@ def create_weights(self, layer: torch.nn.Module,
layer=layer,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
input_size=input_size,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader)
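For reference, a rough sketch (not taken from the commit) of the config_groups layout that the updated from_config parser above expects for a w4a16 checkpoint; the exact field names are defined by the QuantizationArgs pydantic model and may differ:

# Hypothetical quantization config for a w4a16 model (illustrative only).
quant_config = {
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "num_bits": 4,        # checked by _is_w4a16
                "symmetric": True,
                "strategy": "group",
                "group_size": 128,
                "dynamic": False,
            },
            # No activation quantization for w4a16; the try/except in
            # from_config then stores None under "input_activations".
            "input_activations": None,
        },
    },
    "ignore": ["lm_head"],
}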
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -1,6 +1,7 @@
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w4a16 import CompressedTensorsW4A16 # noqa: F401
from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501
CompressedTensorsW8A8DynamicToken)
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py
@@ -0,0 +1,168 @@
from typing import Callable, List, Optional

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
marlin_permute_scales)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsW4A16"]


class CompressedTensorsW4A16(CompressedTensorsScheme):

def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None):
self.num_bits = num_bits
self.strategy = strategy
self.group_size = group_size

if self.strategy == "group" and self.group_size is None:
raise ValueError(
"group_size must be given when using strategy group")

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

pack_factor = 32 // self.num_bits
output_size_per_partition = sum(output_partition_sizes)

if self.group_size is not None:
group_size = self.group_size
else:
group_size = input_size

weight_scale_dim = None
scales_and_zp_size = input_size // group_size

if (input_size != input_size_per_partition
and self.group_size is not None):
weight_scale_dim = 1
scales_and_zp_size = input_size_per_partition // group_size

weight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition // pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)

set_weight_attrs(
weight, {
"input_dim": 1,
"output_dim": 0,
"packed_dim": 1,
"pack_factor": pack_factor
})
set_weight_attrs(weight, {"weight_loader": weight_loader})

layer.register_parameter("weight_packed", weight)

weight_scale = Parameter(
torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
),
requires_grad=False,
)

set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
set_weight_attrs(weight_scale, {
"input_dim": weight_scale_dim,
"output_dim": 0
})
layer.register_parameter("weight_scale", weight_scale)

# A 2D array defining the original shape of the weights
# before packing
weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
requires_grad=False)

layer.register_parameter("weight_shape", weight_shape)
set_weight_attrs(weight_shape, {"weight_loader": weight_loader})

layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition

layer.input_size = input_size
layer.marlin_state = GPTQMarlinState.REPACK
layer.is_k_full = True
layer.group_size = group_size

max_workspace_size = (
output_size_per_partition //
GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL

workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
requires_grad=False)
layer.workspace = workspace

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
reshaped_x = x.reshape(-1, x.shape[-1])

size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition

out_shape = x.shape[:-1] + (part_size_n, )

if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY

# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t

cur_device = layer.weight_packed.device

# Reset g_idx related tensors
layer.g_idx = Parameter(torch.empty(0,
dtype=torch.int,
device=cur_device),
requires_grad=False)
layer.g_idx_sort_indices = Parameter(torch.empty(
0, dtype=torch.int, device=cur_device),
requires_grad=False)

# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices,
part_size_k, part_size_n, self.num_bits)

replace_tensor("weight_packed", marlin_qweight)

# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n

marlin_scales = marlin_permute_scales(
layer.weight_scale.squeeze().t().contiguous(), scales_size_k,
scales_size_n, layer.group_size, self.num_bits)
replace_tensor("weight_scale", marlin_scales)

output = ops.gptq_marlin_gemm(reshaped_x, layer.weight_packed,
layer.weight_scale, layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace, self.num_bits, size_m,
part_size_n, part_size_k,
layer.is_k_full)
return output.reshape(out_shape)
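A note on the packing arithmetic in create_weights above: with num_bits=4 the pack factor is 32 // 4 = 8, so eight 4-bit weights share one int32 word, which is what the new test asserts via weight_packed.pack_factor == 8. On the first forward pass the scheme repacks these tensors into the Marlin layout (GPTQMarlinState.REPACK -> READY) before calling gptq_marlin_gemm. A small sketch of the resulting parameter shapes, with assumed layer sizes (illustrative numbers, not from the commit):

# Illustrative shape check for the w4a16 parameters (hypothetical sizes).
num_bits = 4
pack_factor = 32 // num_bits              # 8 int4 values per int32 word
k, n, group_size = 4096, 12288, 128       # assumed input size, output size, group

weight_packed_shape = (n, k // pack_factor)   # (12288, 512), stored as torch.int32
weight_scale_shape = (n, k // group_size)     # (12288, 32), one scale per group
print(pack_factor, weight_packed_shape, weight_scale_shape)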
