[Kernel] Add Exllama as a backend for compressed-tensors #9395

Merged
Changes from 1 commit: initial commit
LucasWilkinson committed Oct 16, 2024
commit 695e85e70e9c12b173c57a789f7b6dd03d022206
9 changes: 9 additions & 0 deletions vllm/envs.py
@@ -66,6 +66,7 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_DISABLED_KERNELS: List[str] = []


def get_default_cache_root():
@@ -430,6 +431,14 @@ def get_default_config_root():
"VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0"
) == "1",

# List of quantization kernels that should be disabled, used for testing
# and performance comparisons. Currently only affects MPLinearKernel
# selection
# (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),
}

# end-env-vars-definition
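For testing and performance comparisons (the stated purpose above), the variable can be used to force the Exllama path by disabling the higher-priority kernels. A minimal sketch, assuming the kernel class names listed in the comment; the model id is a placeholder:

import os

# Kernel class names are matched against this variable by __name__ in
# choose_mp_linear_kernel (see the __init__.py hunk below), so disabling
# Machete and Marlin leaves Exllama as the remaining candidate.
os.environ["VLLM_DISABLED_KERNELS"] = "MacheteLinearKernel,MarlinLinearKernel"

from vllm import LLM  # set the variable before the engine is constructed

llm = LLM(model="path/to/a-gptq-or-compressed-tensors-model")  # placeholder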
@@ -42,6 +42,10 @@ def __init__(self,
        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        if c.zero_points:
            assert w_zp_param_name is not None
        if c.has_g_idx:
            assert w_gidx_param_name is not None
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

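The new assertions key off fields of MPLinearLayerConfig that the Exllama kernel below also consults. As a reading aid, a paraphrased sketch of that config, inferred from how its fields are used in this PR (the real definition lives in MPLinearKernel.py and may differ in detail):

from dataclasses import dataclass
from typing import Tuple

import torch

from vllm.scalar_type import ScalarType


# Paraphrased reading aid, not the actual vLLM class.
@dataclass
class MPLinearLayerConfigSketch:
    full_weight_shape: Tuple[int, int]       # (in_features, out_features)
    partition_weight_shape: Tuple[int, int]  # this rank's shard of the above
    weight_type: ScalarType                  # e.g. scalar_types.uint4b8
    act_type: torch.dtype                    # activation dtype (float16 here)
    group_size: int                          # quantization group size
    zero_points: bool                        # checkpoint stores explicit zero points
    has_g_idx: bool                          # checkpoint stores act-order g_idx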
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/quantization/kernels/__init__.py
@@ -1,10 +1,13 @@
import os
from typing import List, Optional, Type

import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.machete import (
    MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
    MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.exllama import (
    ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
    MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform
@@ -13,6 +16,7 @@
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
    MacheteLinearKernel,
    MarlinLinearKernel,
    ExllamaLinearKernel,
]


@@ -45,8 +49,7 @@ def choose_mp_linear_kernel(

    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
                .split(","):
        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f' {kernel.__name__} disabled by environment variable')
            continue
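For context, this loop is the kernel-selection path the new env var hooks into: candidates are tried in preference order and the first one that is not disabled, meets the minimum compute capability, and whose can_implement(config) returns (True, None) wins. A condensed, paraphrased sketch (it reuses _POSSIBLE_KERNELS and envs from this module; the real choose_mp_linear_kernel differs in its signature and error text):

def choose_kernel_sketch(config, compute_capability):
    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f"{kernel.__name__} disabled by environment variable")
            continue
        if kernel.get_min_capability() > compute_capability:
            failure_reasons.append(
                f"{kernel.__name__} needs compute capability "
                f">= {kernel.get_min_capability()}")
            continue
        ok, reason = kernel.can_implement(config)
        if not ok:
            failure_reasons.append(f"{kernel.__name__}: {reason}")
            continue
        return kernel  # first viable kernel in preference order
    raise ValueError("no kernel can implement this layer:\n" +
                     "\n".join(failure_reasons))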
146 changes: 146 additions & 0 deletions vllm/model_executor/layers/quantization/kernels/exllama.py
@@ -0,0 +1,146 @@
from typing import Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.scalar_type import scalar_types
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class ExllamaLinearKernel(MPLinearKernel):
    SUPPORTED_QUANT_TYPES = [
        scalar_types.uint2b2,
        scalar_types.uint3b4,
        scalar_types.uint4b8,
        scalar_types.uint8b128
    ]

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.has_g_idx and\
                c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return False, "Act reordering is currently not supported by "\
                          "Exllama when the input features are partitioned "\
                          "across devices"

        if c.act_type != torch.float16:
            return False, "Exllama only supports float16 activations"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return False, f"Quant type ({c.weight_type}) not supported by "\
                          "Exllama, supported types are: "\
                          f"{cls.SUPPORTED_QUANT_TYPES}"

        if c.full_weight_shape[0] % c.group_size != 0:
            return False, f"Group size ({c.group_size}) does not evenly "\
                          "divide the number of input features "\
                          f"({c.full_weight_shape[0]})"

        return True, None

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config
        device = getattr(layer, self.w_q_name).device

        # For Exllama, we need to set a zero-point tensor if there is not one
        if not c.zero_points:
            self.w_zp_name = "qzeros"
            groups = c.full_weight_shape[0] // c.group_size
            out_features = c.partition_weight_shape[1]

            if c.weight_type.has_bias():
                # If the type has a bias we have to create a zeros tensor that
                # contains the bias values repeated for each group (-1 due to
                # a bug in the original GPTQ checkpoint format leading to the
                # exllama kernel adding 1 to the zero points during inference)
                # Documentation of the bug can be found here:
                #  https://garden.danieldk.eu/GPTQ-Checkpoint-Format
                zeros = torch.full((groups, out_features),
                                   c.weight_type.bias - 1,
                                   dtype=torch.int32,
                                   device=device)
            else:
                raise NotImplementedError(
                    "A 0 zero-point is not supported by Exllama due to "
                    "a bug in the original GPTQ checkpoint format leading to "
                    "the exllama kernel adding 1 to the zero points during "
                    "inference")
            zeros = pack_quantized_values_into_int32(zeros,
                                                     c.weight_type,
                                                     packed_dim=1)
            setattr(layer, self.w_zp_name,
                    torch.nn.Parameter(zeros, requires_grad=False))

        if c.has_g_idx:

            def transform_w_g_idx(x):
                # Exllama wants the permutation array instead of the group
                # indices
                return torch.argsort(x).to(torch.int)

            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
        else:
            self.w_gidx_name = "g_idx"
            empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
                                                         dtype=torch.int,
                                                         device=device),
                                             requires_grad=False)
            setattr(layer, self.w_gidx_name, empty_g_idx)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            assert self.w_gidx_name is not None
            g_idx = getattr(layer, self.w_gidx_name)

            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x_cont = x.data.contiguous()
            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
            return x_cont

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x.to(dtype=c.act_type)

        # Shuffle the packed weights for the exllama kernel and convert the
        # scales to the activation dtype
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

        assert w_zp is not None, "Zero points are required by Exllama"
        assert w_g_idx is not None, "Group index is required by Exllama"
        output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
                               c.weight_type.size_bits)

        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
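Two of the transformations in exllama.py are easy to misread, so here is a small, self-contained illustration of what they compute. It relies only on the semantics described in the comments above (bias-offset zero points, the GPTQ off-by-one, and g_idx holding a group index per input row); the shapes are made up, and the packing loop mirrors pack_quantized_values_into_int32 only conceptually.

import torch

# 1) Zero points for a symmetric type like uint4b8: the logical zero point is
#    the bias (8), and because of the GPTQ checkpoint off-by-one the kernel
#    adds 1 at inference time, so bias - 1 = 7 is stored for every
#    group/column.
groups, out_features = 2, 8  # toy sizes, not real layer shapes
bias = 8                     # scalar_types.uint4b8 -> 4 bits, bias 8
zeros = torch.full((groups, out_features), bias - 1, dtype=torch.int32)

# Pack eight 4-bit values per int32 along the output dimension (packed_dim=1).
packed = torch.zeros((groups, out_features // 8), dtype=torch.int32)
for i in range(8):
    packed |= zeros[:, i::8] << (4 * i)
print(hex(packed[0, 0].item()))  # 0x77777777

# 2) g_idx -> permutation: GPTQ checkpoints store, for each input row, the
#    group it belongs to; the exllama kernel instead wants the row order that
#    gathers rows of the same group together, which is what
#    torch.argsort(g_idx) produces.
g_idx = torch.tensor([1, 0, 1, 0], dtype=torch.int32)
perm = torch.argsort(g_idx).to(torch.int)
print(perm)  # rows of group 0 first, then group 1, e.g. tensor([1, 3, 0, 2])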
14 changes: 7 additions & 7 deletions vllm/model_executor/layers/quantization/kernels/machete.py
@@ -8,7 +8,7 @@
    MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
    query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_weights_into_int32, unpack_weights_into_int32)
    pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)

@@ -71,13 +71,13 @@ def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                x_unpacked = unpack_weights_into_int32(x.data,
                                                       c.weight_type,
                                                       packed_dim=0)
                x_unpacked = unpack_quantized_values_into_int32(x.data,
                                                                c.weight_type,
                                                                packed_dim=0)
                x_perm = x_unpacked[perm, :]
                x.data = pack_weights_into_int32(x_perm,
                                                 c.weight_type,
                                                 packed_dim=0)
                x.data = pack_quantized_values_into_int32(x_perm,
                                                          c.weight_type,
                                                          packed_dim=0)
            x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
                                           self.config.weight_type)
            return x
12 changes: 6 additions & 6 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -20,9 +20,9 @@
}


def pack_weights_into_int32(w_q: torch.Tensor,
                            wtype: ScalarType,
                            packed_dim: int = 0):
def pack_quantized_values_into_int32(w_q: torch.Tensor,
                                     wtype: ScalarType,
                                     packed_dim: int = 0):
    # move dim to pack to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
@@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor,
    return res.permute(inv_perm)


def unpack_weights_into_int32(w_q: torch.Tensor,
                              wtype: ScalarType,
                              packed_dim: int = 0):
def unpack_quantized_values_into_int32(w_q: torch.Tensor,
                                       wtype: ScalarType,
                                       packed_dim: int = 0):
    # move dim to pack to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
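The renames above are mechanical (weights -> quantized values); for reference, the perm/inv_perm pattern both helpers share simply moves packed_dim to the last axis and back. A tiny standalone illustration, not vLLM code:

import torch

w_q = torch.arange(24).reshape(2, 3, 4)
packed_dim = 1

# Move the dim to pack to the end ...
perm = (*[i for i in range(w_q.ndim) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
w_last = w_q.permute(perm)  # shape (2, 4, 3)

# ... packing/unpacking would operate along the last axis here ...

# ... then permute back so the result's dims are in the original order.
assert w_last.permute(inv_perm).shape == w_q.shape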
2 changes: 2 additions & 0 deletions vllm/scalar_type.py
@@ -27,6 +27,8 @@ class scalar_types:
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

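For readers unfamiliar with the naming: uint<bits>b<bias> is an unsigned <bits>-bit storage type with a zero-point bias of <bias>, so uint4b8 stores 0..15 and represents -8..7 once the bias is subtracted; the Exllama kernel above leans on this via c.weight_type.bias and c.weight_type.has_bias(). A tiny sketch with a hypothetical helper (not part of ScalarType):

from typing import Tuple


def biased_uint_range(bits: int, bias: int) -> Tuple[int, int]:
    # Logical range once the bias is subtracted: stored 0..(2**bits - 1)
    # maps to -bias..(2**bits - 1 - bias).
    return -bias, (1 << bits) - 1 - bias


assert biased_uint_range(2, 2) == (-2, 1)        # uint2b2
assert biased_uint_range(3, 4) == (-4, 3)        # uint3b4
assert biased_uint_range(4, 8) == (-8, 7)        # uint4b8
assert biased_uint_range(8, 128) == (-128, 127)  # uint8b128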