fix the issue of qlinear packing being too slow. (AutoGPTQ#770)
The `for` loop in the `pack` function is too slow; replace it with a single tensor operation.
Marxlp authored Jan 21, 2025
1 parent c039255 commit 191bd77
Showing 1 changed file with 1 addition and 8 deletions.
9 changes: 1 addition & 8 deletions auto_gptq/nn_modules/qlinear/qlinear_exllama.py
@@ -134,14 +134,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
         if linear.bias is not None:
             self.bias = linear.bias.clone().half()

-        intweight = []
-        for idx in range(self.infeatures):
-            intweight.append(
-                torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[
-                    :, None
-                ]
-            )
-        intweight = torch.cat(intweight, dim=1)
+        intweight = torch.round((W + scale_zeros[g_idx].T) / scales[g_idx].T).to(torch.int)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
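
The diff replaces a per-column Python loop with one advanced-indexing expression: `scale_zeros[g_idx]` and `scales[g_idx]` gather the per-group parameters for every input feature at once, and a transpose lines them up for broadcasting against `W`. A minimal NumPy sketch (with hypothetical small shapes; the real code operates on PyTorch tensors) shows the two versions produce identical results:

```python
import numpy as np

# Hypothetical shapes for illustration only.
out_features, in_features, n_groups = 4, 8, 2
rng = np.random.default_rng(0)

W = rng.normal(size=(out_features, in_features))
# Per-group quantization params, shape (n_groups, out_features),
# mirroring how scale_zeros[g_idx[idx]] is indexed in the old loop.
scales = np.abs(rng.normal(size=(n_groups, out_features))) + 0.1
scale_zeros = rng.normal(size=(n_groups, out_features))
g_idx = rng.integers(0, n_groups, size=in_features)  # group of each input feature

# Old approach: build intweight one column at a time.
cols = []
for idx in range(in_features):
    col = np.round((W[:, idx] + scale_zeros[g_idx[idx]]) / scales[g_idx[idx]])
    cols.append(col[:, None])
intweight_loop = np.concatenate(cols, axis=1).astype(np.int64)

# New approach: gather with fancy indexing, transpose, broadcast.
# scale_zeros[g_idx] has shape (in_features, out_features); .T aligns it with W.
intweight_vec = np.round((W + scale_zeros[g_idx].T) / scales[g_idx].T).astype(np.int64)

assert np.array_equal(intweight_loop, intweight_vec)
```

The vectorized form removes `in_features` iterations of Python-level indexing and a final `cat`, which is where the speedup in this commit comes from.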

