[FIX] Marlin padding (ModelCloud#183)
* marlin padding

* fix padded workspace

* use self.infeatures and self.outfeatures

* use self.original_infeatures/self.original_outfeatures

* crop C

* fix marlin padding

* cleanup

---------

Co-authored-by: LRL-ModelCloud <lrl@modelcloud.ai>
LRL-ModelCloud and LRL-ModelCloud authored Jul 9, 2024
1 parent 815672f commit b95d490
Showing 2 changed files with 45 additions and 16 deletions.
55 changes: 42 additions & 13 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -6,6 +6,7 @@
import gptqmodel_marlin_cuda
import numpy as np
import torch
import torch.nn.functional as F
from gptqmodel.nn_modules.qlinear import BaseQuantLinear

logger = getLogger(__name__)
@@ -27,8 +28,6 @@ def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):


# Precompute permutations for Marlin weight and scale shuffling


def _get_perms():
perm = []
for i in range(32):
@@ -70,30 +69,39 @@ class MarlinQuantLinear(BaseQuantLinear):
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
bias: bool, **kwargs):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

if not torch.cuda.get_device_capability()[0] >= 8:
raise ValueError(
f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel. Please do not use `backend=Backend.MARLIN`, or please upgrade your GPU ("The more you buy, the more you save." - Taiwanese proverb).'
)

if infeatures % 128 != 0 or outfeatures % 256 != 0:
raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
# if infeatures % 128 != 0 or outfeatures % 256 != 0:
# raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
if group_size not in [-1, 128] and group_size != infeatures:
raise ValueError("Only group_size -1 and 128 are supported.")
# Marlin groups infeatures according to group_size, so infeatures must be an integer multiple of group_size.
if infeatures % group_size != 0:
raise ValueError("`infeatures` must be divisible by `group_size`.")
# # Marlin groups infeatures according to group_size, so infeatures must be an integer multiple of group_size.
# if infeatures % group_size != 0:
# raise ValueError("`infeatures` must be divisible by `group_size`.")


self.original_infeatures = infeatures
self.original_outfeatures = outfeatures

self.infeatures = infeatures + (-infeatures % 128)
self.outfeatures = outfeatures + (-outfeatures % 256)

self.infeatures = infeatures
self.outfeatures = outfeatures
self.group_size = group_size if group_size != -1 else infeatures

del infeatures
del outfeatures
del group_size

self.register_buffer(
"B",
torch.empty((self.infeatures // 16, self.outfeatures * 16 // 8), dtype=torch.int),
)
self.register_buffer(
"s",
torch.empty((self.infeatures // group_size, self.outfeatures), dtype=torch.half),
torch.empty((self.infeatures // self.group_size, self.outfeatures), dtype=torch.half),
)
# 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par`
self.register_buffer(
@@ -102,7 +110,7 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
persistent=False,
)
if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.half))
self.register_buffer("bias", torch.zeros((self.outfeatures), dtype=torch.half))
else:
self.bias = None

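Note on the padding arithmetic introduced in __init__ above: x + (-x % m) rounds x up to the next multiple of m and leaves it unchanged when x is already aligned. A minimal illustrative sketch (pad_to_multiple is a hypothetical helper, not part of this commit):

def pad_to_multiple(x: int, m: int) -> int:
    # -x % m is the gap from x up to the next multiple of m (0 if already aligned)
    return x + (-x % m)

assert pad_to_multiple(4096, 128) == 4096    # already a multiple of 128, unchanged
assert pad_to_multiple(250, 128) == 256      # rounded up to the next multiple of 128
assert pad_to_multiple(300, 256) == 512      # rounded up to the next multiple of 256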
@@ -122,6 +130,16 @@ def pack(self, linear, scales):
maxq = 2**4 - 1
s = scales.t()
w = linear.weight.data.t()

if self.infeatures != self.original_infeatures or self.outfeatures != self.original_outfeatures:
padded_w = torch.zeros((self.infeatures, self.outfeatures))
padded_w[:w.size(0), :w.size(1)] = w
w = padded_w

padded_s = torch.zeros((s.size(0), self.outfeatures))
padded_s[:s.size(0), :s.size(1)] = s
s = padded_s

if self.group_size != self.infeatures:
w = w.reshape((-1, self.group_size, self.outfeatures))
w = w.permute(1, 0, 2)
Expand All @@ -148,8 +166,10 @@ def pack(self, linear, scales):
for i in range(8):
q |= res[:, i::8] << 4 * i
q = torch.from_numpy(q.astype(np.int32)).to(w.device)

self.B[:, :] = q.to(self.B.device)
self.s[:, :] = s.to(self.s.device)

if linear.bias is not None:
if self.bias is not None:
self.bias[:] = linear.bias.data.to(self.bias.device)
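A hedged sketch of the zero-padding that pack() now performs before quantization: the weight (and likewise the scales) is copied into a zero tensor of the padded shape, so the extra rows and columns contribute nothing. The sizes below are illustrative, not taken from the commit:

import torch

w = torch.randn(250, 300)                    # hypothetical original (infeatures x outfeatures) weight
padded_w = torch.zeros(256, 512)             # padded up to multiples of 128 and 256 respectively
padded_w[:w.size(0), :w.size(1)] = w         # original values land in the top-left block
assert torch.equal(padded_w[:250, :300], w)  # the remainder of padded_w stays zero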
@@ -158,7 +178,13 @@ def pack(self, linear, scales):

def forward(self, A):
A = A.half()
if A.size(-1) != self.infeatures:
A = F.pad(A, (0, self.infeatures - self.original_infeatures))

C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
if C.size(-1) != self.outfeatures:
C = F.pad(C, (0, self.outfeatures - self.original_outfeatures))

mul(
A.view((-1, A.shape[-1])),
self.B,
@@ -167,12 +193,15 @@ def forward(self, A):
self.workspace,
)
C = C + self.bias if self.bias is not None else C

if self.outfeatures != self.original_outfeatures:
C = C[:, :, :self.original_outfeatures]

return C

def post_init(self):
self.validate_device(self.B.device.type)


# Copied from https://github.com/IST-DASLab/marlin/pull/1
@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
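A hedged sketch of the forward-path behaviour added above: the activation is zero-padded on its last dimension up to the padded in-features, the matmul runs at the padded sizes, and the result is cropped back to the original out-features. The standalone function below is illustrative only; a plain matmul stands in for the Marlin mul(...) kernel:

import torch
import torch.nn.functional as F

def padded_matmul(A, W_padded, original_outfeatures):
    # W_padded has shape (padded_infeatures, padded_outfeatures)
    padded_infeatures, padded_outfeatures = W_padded.shape
    if A.size(-1) != padded_infeatures:
        # zero-pad the last dim of A from the original to the padded in-features
        A = F.pad(A, (0, padded_infeatures - A.size(-1)))
    C = A @ W_padded                          # stand-in for the Marlin kernel call
    if padded_outfeatures != original_outfeatures:
        C = C[..., :original_outfeatures]     # crop the columns added by padding
    return C

A = torch.randn(2, 250)
out = padded_matmul(A, torch.zeros(256, 512), original_outfeatures=300)
assert out.shape == (2, 300)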
6 changes: 3 additions & 3 deletions gptqmodel/utils/marlin.py
@@ -137,13 +137,13 @@ def convert_to_marlin(
group_size=module.group_size,
sym=sym,
desc_act=desc_act,
infeatures=module.infeatures,
outfeatures=module.outfeatures,
infeatures=module.original_infeatures,
outfeatures=module.original_outfeatures,
bias=module.bias is not None,
)

# workspace is never in the state_dict, thus we need to allocate it manually.
new_module.workspace = torch.zeros(module.outfeatures // 128 * 16, dtype=torch.int, device=module.device)
new_module.workspace = torch.zeros(new_module.outfeatures // 128 * 16, dtype=torch.int, device=module.device)

# Dequantize the weight.
if repack:
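The change above in convert_to_marlin sizes the manually allocated workspace from the padded outfeatures of the new module rather than the unpadded value on the source module; otherwise outfeatures // 128 * 16 under-allocates whenever the original width is not 128-aligned. A rough sketch of the arithmetic (the concrete sizes are assumptions for illustration):

original_outfeatures = 1000
padded_outfeatures = original_outfeatures + (-original_outfeatures % 256)   # -> 1024
assert original_outfeatures // 128 * 16 == 112    # too small for the padded layer
assert padded_outfeatures // 128 * 16 == 128      # matches the padded buffer shapes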
