Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Marlin runtime conversion padding #192

Merged
merged 7 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat

self.group_size = group_size if group_size != -1 else infeatures

del infeatures
del outfeatures
del group_size

self.register_buffer(
"B",
torch.empty((self.infeatures // 16, self.outfeatures * 16 // 8), dtype=torch.int),
Expand Down Expand Up @@ -132,11 +128,11 @@ def pack(self, linear, scales):
w = linear.weight.data.t()

if self.infeatures != self.original_infeatures or self.outfeatures != self.original_outfeatures:
padded_w = torch.zeros((self.infeatures, self.outfeatures))
padded_w = torch.zeros((self.infeatures, self.outfeatures), dtype=w.dtype, device=w.device)
padded_w[:w.size(0), :w.size(1)] = w
w = padded_w

padded_s = torch.zeros((s.size(0), self.outfeatures))
padded_s = torch.zeros((s.size(0), self.outfeatures), dtype=torch.half, device=s.device)
padded_s[:s.size(0), :s.size(1)] = s
s = padded_s

Expand Down Expand Up @@ -184,10 +180,6 @@ def forward(self, A):
A = F.pad(A, (0, self.infeatures - self.original_infeatures))

C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)

if C.size(-1) != self.outfeatures:
C = F.pad(C, (0, self.outfeatures - self.original_outfeatures))

mul(
A.view((-1, A.shape[-1])),
self.B,
Expand Down
16 changes: 14 additions & 2 deletions gptqmodel/utils/marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,13 @@ def convert_to_marlin(
if repack:
import gptqmodel_marlin_cuda

marlin_repacked_weight = gptqmodel_marlin_cuda.gptq_repack(module.qweight)
qweight = module.qweight
if new_module.infeatures != new_module.original_infeatures or new_module.outfeatures != new_module.original_outfeatures:
padded_qweight = torch.zeros((new_module.infeatures, new_module.outfeatures), dtype=torch.int, device=module.qweight.device)
padded_qweight[:module.qweight.size(0), :module.qweight.size(1)] = qweight
qweight = padded_qweight

marlin_repacked_weight = gptqmodel_marlin_cuda.gptq_repack(qweight)

if strict:
dequantized_qzeros = unpack_qzeros(module.qzeros)
Expand All @@ -163,12 +169,18 @@ def convert_to_marlin(
_, _scale_perm, _scale_perm_single = _get_perms()

s = module.scales.data.clone()

if new_module.infeatures != new_module.original_infeatures or new_module.outfeatures != new_module.original_outfeatures:
padded_s = torch.zeros((s.size(0), new_module.outfeatures), dtype=torch.half, device=s.device)
padded_s[:s.size(0), :s.size(1)] = s
s = padded_s

if module.group_size != module.infeatures:
s = s.reshape((1, -1))
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, module.outfeatures)).contiguous()
s = s.reshape((-1, new_module.outfeatures)).contiguous()

new_module.B = marlin_repacked_weight
new_module.s = s
Expand Down