[FIX] Padding infeatures/outfeatures for exllama, exllama v2, and marlin (ModelCloud#98)

* fix padding

* fix padding

* store original in/out features

* fix bad var reference

* shorter var name

* limit bitblas convert to use 1 thread

* ruff

* fix qlinear_exllama pack

* revert qlinear_marlin change

* cleanup code

* plan b: init with original shape, then model load, then do padding/resize in post_init

* fix g_idx post_init

* const var reformat to all caps

* fix ( -> [

* padding the x that passes in forward

* comments/todo

* comments

---------

Co-authored-by: LRL-ModelCloud <lrl@modelcloud.ai>
Qubitium and LRL-ModelCloud authored Jun 29, 2024
1 parent 5bf289a commit eef560e
Showing 4 changed files with 125 additions and 78 deletions.
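
For readers skimming the diffs below, a minimal illustrative sketch (not code from this commit) of the padding arithmetic it introduces: in Python, `-n % k` is the smallest non-negative amount that rounds `n` up to a multiple of `k`, which is how `outfeatures` is padded to a multiple of 32 and `infeatures` to a multiple of `group_size`.

```python
def pad_to_multiple(n: int, k: int) -> int:
    # -n % k is 0 when n is already a multiple of k, otherwise the gap to the next multiple
    return n + (-n % k)

# Illustrative shapes only; real layers get theirs from the model config.
outfeatures, infeatures, group_size = 1000, 992, 128
print(pad_to_multiple(outfeatures, 32))         # 1024
print(pad_to_multiple(infeatures, group_size))  # 1024
```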
56 changes: 39 additions & 17 deletions gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -6,6 +6,7 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel_exllama_kernels import make_q4, q4_matmul
@@ -14,12 +15,12 @@


# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
NON_TENSOR = torch.empty((1, 1), device="meta")


def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device)
return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else NON_TENSOR, device)


def ext_q4_matmul(x, q4, q4_width):
@@ -44,54 +45,70 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat
super().__init__()
self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)

self.padding = -outfeatures % 32
self.outfeatures = outfeatures + self.padding
outfeatures = self.outfeatures

self.infeatures = infeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures

# auto pad
self.outfeatures = outfeatures + (-outfeatures % 32)
self.infeatures = infeatures + (-infeatures % self.group_size)

# backup original values
self.original_outfeatures = outfeatures
self.original_infeatures = infeatures

self.maxq = 2**self.bits - 1

assert infeatures % 32 == 0
assert infeatures % self.group_size == 0
assert outfeatures % 32 == 0
assert self.infeatures % 32 == 0
assert self.outfeatures % 32 == 0

self.register_buffer(
"qweight",
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
torch.zeros((self.original_infeatures // 32 * self.bits, self.original_outfeatures), dtype=torch.int32),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
math.ceil(infeatures / self.group_size),
outfeatures // 32 * self.bits,
math.ceil(self.original_infeatures / self.group_size),
self.original_outfeatures // 32 * self.bits,
),
dtype=torch.int32,
),
)
self.register_buffer(
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
(math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures),
dtype=torch.float16,
),
)
self.register_buffer(
"g_idx",
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32),
)

if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer("bias", torch.zeros(self.original_outfeatures, dtype=torch.float16))
else:
self.bias = None

def post_init(self):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None

# resize due to padding after model weights have been loaded
if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures:
self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures)
self.qzeros.resize_(
math.ceil(self.infeatures / self.group_size),
self.outfeatures // 32 * self.bits
)
self.scales.resize_((math.ceil(self.infeatures / self.group_size), self.outfeatures),)
self.g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device)
if self.bias is not None:
self.bias.resize_(self.outfeatures)


self.width = self.qweight.shape[1]

# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
@@ -120,7 +137,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
self.bias = linear.bias.clone().half()

intweight = []
for idx in range(self.infeatures):
for idx in range(self.original_infeatures):
intweight.append(
torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[
:, None
@@ -164,6 +181,11 @@ def forward(self, x):

x = x.half()

# TODO: need to run checks to make sure there is no performance regression padding with F.pad
# if infeatures is padded, we need to pad the input as well
if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures:
x = F.pad(x, (0, self.infeatures - self.original_infeatures))

out = ext_q4_matmul(x, self.q4, self.width)

if self.bias is not None:
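
Both exllama kernels add the same input-padding step to `forward`; here is a small standalone sketch of that step (hypothetical shapes, not code from the diff), showing how `F.pad` zero-extends only the last dimension of the activation to match the padded weight width.

```python
import torch
import torch.nn.functional as F

original_infeatures = 992   # width the checkpoint was saved with
infeatures = 1024           # width after padding to a multiple of group_size

x = torch.randn(2, 8, original_infeatures).half()

# F.pad's (left, right) pair applies to the last dimension; default fill is zeros.
if x.size(-1) != infeatures and infeatures > original_infeatures:
    x = F.pad(x, (0, infeatures - original_infeatures))

print(x.shape)  # torch.Size([2, 8, 1024])
```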
75 changes: 48 additions & 27 deletions gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
@@ -4,6 +4,7 @@
from logging import getLogger

import torch
import torch.nn.functional as F
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel_exllamav2_kernels import gemm_half_q_half, make_q_matrix

@@ -12,7 +13,7 @@


# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
NONE_TENSOR = torch.empty((1, 1), device="meta")


def _torch_device(idx):
@@ -47,9 +48,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale"],
w["q_scale_max"],
w["q_groups"],
none_tensor,
none_tensor,
none_tensor,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
temp_dq,
)
# GPTQ
@@ -70,9 +71,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["qweight"],
w["q_perm"],
w["q_invperm"],
none_tensor,
none_tensor,
none_tensor,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
w["qzeros"],
w["scales"],
w["g_idx"].cpu(),
@@ -82,14 +83,14 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
else:
return make_q_matrix(
w["qweight"],
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
NONE_TENSOR,
w["qzeros"],
w["scales"],
none_tensor,
NONE_TENSOR,
temp_dq,
)

@@ -108,54 +109,69 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat
self.q_handle = None
self.q_tensors = None

self.padding = -outfeatures % 32
self.outfeatures = outfeatures + self.padding
outfeatures = self.outfeatures

self.infeatures = infeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures

# auto pad
self.outfeatures = outfeatures + (-outfeatures % 32)
self.infeatures = infeatures + (-infeatures % self.group_size)

# backup original values
self.original_outfeatures = outfeatures
self.original_infeatures = infeatures
self.maxq = 2**self.bits - 1

assert infeatures % 32 == 0
assert infeatures % self.group_size == 0
assert outfeatures % 32 == 0
assert self.infeatures % 32 == 0
assert self.outfeatures % 32 == 0

# I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
self.register_buffer(
"qweight",
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
torch.zeros((self.original_infeatures // 32 * self.bits, self.original_outfeatures), dtype=torch.int32),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
math.ceil(infeatures / self.group_size),
outfeatures // 32 * self.bits,
math.ceil(self.original_infeatures / self.group_size),
self.original_outfeatures // 32 * self.bits,
),
dtype=torch.int32,
),
)
self.register_buffer(
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
(math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures),
dtype=torch.float16,
),
)
self.register_buffer(
"g_idx",
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32),
)

if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer("bias", torch.zeros((self.original_outfeatures), dtype=torch.float16))
else:
self.bias = None

def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None

# resize due to padding after model weights have been loaded
if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures:
self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures)
self.qzeros.resize_(
math.ceil(self.infeatures / self.group_size),
self.outfeatures // 32 * self.bits
)
self.scales.resize_(math.ceil(self.infeatures / self.group_size), self.outfeatures)
self.g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device)
if self.bias is not None:
self.bias.resize_(self.outfeatures)

self.q_tensors = {
"qweight": self.qweight,
"qzeros": self.qzeros,
Expand All @@ -173,6 +189,11 @@ def forward(self, x, force_cuda=False):

x = x.half()

# TODO: need to run checks to make sure there is no performance regression padding with F.pad
# if infeatures is padded, we need to pad the input as well
if x.size(-1) != self.infeatures and self.infeatures > self.original_infeatures:
x = F.pad(x, (0, self.infeatures - self.original_infeatures))

output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)

if self.bias is not None:
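
The shared "plan b" in both exllama layers is: register buffers at the checkpoint's original shape so weight loading succeeds, then resize them to the padded shape in `post_init`. A compact hypothetical stub of that pattern (the class name and shapes are invented for illustration):

```python
import torch
import torch.nn as nn


class TinyQuantStub(nn.Module):
    """Hypothetical stub: register at original shape, resize after weights load."""

    def __init__(self, infeatures: int, outfeatures: int, group_size: int, bits: int = 4):
        super().__init__()
        self.bits = bits
        self.group_size = group_size
        # padded shapes the kernel wants
        self.infeatures = infeatures + (-infeatures % group_size)
        self.outfeatures = outfeatures + (-outfeatures % 32)
        # original shapes the checkpoint was saved with
        self.original_infeatures = infeatures
        self.original_outfeatures = outfeatures
        # registered at the ORIGINAL shape so load_state_dict matches the checkpoint
        self.register_buffer(
            "qweight",
            torch.zeros(
                (self.original_infeatures // 32 * self.bits, self.original_outfeatures),
                dtype=torch.int32,
            ),
        )

    def post_init(self):
        # grow the buffer in place once weights are loaded; note resize_ does not
        # zero the newly added region, so real kernels must treat it as padding
        if (self.infeatures, self.outfeatures) != (self.original_infeatures, self.original_outfeatures):
            self.qweight.resize_(self.infeatures // 32 * self.bits, self.outfeatures)


stub = TinyQuantStub(infeatures=992, outfeatures=1000, group_size=128)
print(stub.qweight.shape)  # torch.Size([124, 1000])
stub.post_init()
print(stub.qweight.shape)  # torch.Size([128, 1024])
```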
1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -82,6 +82,7 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat
raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
if group_size not in [-1, 128] and group_size != infeatures:
raise ValueError("Only group_size -1 and 128 are supported.")
# Marlin groups infeatures according to group_size, so infeatures must be an integer multiple of group_size.
if infeatures % group_size != 0:
raise ValueError("`infeatures` must be divisible by `group_size`.")

71 changes: 37 additions & 34 deletions gptqmodel/utils/bitblas.py
@@ -2,6 +2,7 @@
from logging import getLogger

import accelerate
import threadpoolctl as tctl
import torch
from accelerate.utils import find_tied_parameters
from tqdm import tqdm
@@ -75,7 +76,7 @@


@torch.no_grad()
def convert_to_bitblas(model, model_quantlinear, quantization_config: QuantizeConfig, sym: bool, desc_act: bool, repack: bool,
def convert_to_bitblas(model, model_quantlinear, quant_config: QuantizeConfig, sym: bool, desc_act: bool, repack: bool,
strict: bool = False):
"""
Converts GPTQ-packed weights to the Bitblas format.
@@ -90,40 +91,42 @@ def convert_to_bitblas(model, model_quantlinear, quantization_config: QuantizeCo
# TODO: load directly BitBLAS QuantLinear.
message = "Overriding QuantLinear layers to use BitBLAS's QuantLinear..."

for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
if not isinstance(module, model_quantlinear):
continue

parent_name = ".".join(name.split(".")[:-1])
layer_name = name[len(parent_name) + 1:]

# We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights
# from checkpoints holding zero bias.
with torch.device("meta"):
bitblas_module = BitBLASQuantLinear(
bits=quantization_config.bits,
group_size=quantization_config.group_size,
sym=sym,
desc_act=desc_act,
infeatures=module.infeatures,
outfeatures=module.outfeatures,
bias=module.bias is not None,
enable_tuning=True
)

# Dequantize the weight.
if repack:
bitblas_module.repack_from_gptq(module)

# Save to parent.
parent_module = model.get_submodule(parent_name)
setattr(parent_module, layer_name, bitblas_module)

# Free cuda memory.
del module
gc.collect()
# TODO: need to benchmark to see multiple threads help with bitblas/tvm compilation and runtime
with tctl.threadpool_limits(limits=1):
for name, module in tqdm(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
if not isinstance(module, model_quantlinear):
continue

parent_name = ".".join(name.split(".")[:-1])
layer_name = name[len(parent_name) + 1:]

# We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights
# from checkpoints holding zero bias.
with torch.device("meta"):
bitblas_module = BitBLASQuantLinear(
bits=quant_config.bits,
group_size=quant_config.group_size,
sym=sym,
desc_act=desc_act,
infeatures=module.infeatures,
outfeatures=module.outfeatures,
bias=module.bias is not None,
enable_tuning=True
)

# Dequantize the weight.
if repack:
bitblas_module.repack_from_gptq(module)

# Save to parent.
parent_module = model.get_submodule(parent_name)
setattr(parent_module, layer_name, bitblas_module)

# Free cuda memory.
del module
gc.collect()

# Set quantization config to be BitBLAS.
quantization_config.format = FORMAT.BITBLAS
quant_config.format = FORMAT.BITBLAS

return model

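
Finally, a small sketch of the `threadpoolctl` pattern used around the BitBLAS conversion loop above; the NumPy matmul is only a stand-in for the actual conversion work and is there to show the context manager.

```python
import numpy as np
import threadpoolctl as tctl

# Limit native thread pools (BLAS/OpenMP) to a single thread inside the block,
# mirroring the limits=1 wrapper added around the conversion loop.
with tctl.threadpool_limits(limits=1):
    a = np.random.rand(256, 256)
    _ = a @ a  # stand-in workload, runs with the restricted thread count

# On exit the previous thread-pool limits are restored automatically.
```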