fix transpose 4bit (#1301)
jiqing-feng authored Aug 2, 2024
1 parent a142f1e commit 1775035
Showing 1 changed file with 11 additions and 1 deletion.
bitsandbytes/backends/cpu_xpu_common.py (12 changes: 11 additions & 1 deletion)
@@ -389,7 +389,7 @@ def quantize_4bit_impl(
         state.absmax = torch.Tensor()
         return torch.Tensor(), state

-    return out, state
+    return out.unsqueeze(0), state


 @_maybe_torch_compile
@@ -428,6 +428,13 @@ def dequantize_4bit_impl(
         Dequantized tensor.
     """

+    if A.shape[0] == 1:
+        transpose = False
+        A = A.squeeze(0)
+    elif A.shape[1] == 1:
+        transpose = True
+        A = A.squeeze(1)
+
     if quant_state is None:
         assert absmax is not None and out is not None

@@ -484,6 +491,9 @@ def dequantize_4bit_impl(
         out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]

+    # take transpose here because weight is transposed (again) for computation
+    if transpose:
+        out = out.t()

     return out


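For context (not part of the commit): the fix encodes the storage orientation in the packed tensor's shape. quantize_4bit_impl now returns the 1-D packed buffer with a leading singleton dimension ([1, n]); if a caller transposes the stored weight it becomes [n, 1], so dequantize_4bit_impl can both recover the flat buffer and know the dequantized weight must be transposed back. Below is a minimal standalone sketch of that round trip in plain PyTorch; pack and unpack are hypothetical stand-ins for illustration, not bitsandbytes APIs.

import torch

def pack(flat: torch.Tensor) -> torch.Tensor:
    # Mimics quantize_4bit_impl's new return: give the 1-D packed buffer a
    # leading singleton dimension so a later transpose is detectable.
    return flat.unsqueeze(0)                      # shape [1, n]

def unpack(A: torch.Tensor) -> tuple[torch.Tensor, bool]:
    # Mimics the shape check added to dequantize_4bit_impl.
    if A.shape[0] == 1:                           # stored as returned by pack()
        return A.squeeze(0), False
    elif A.shape[1] == 1:                         # stored weight was transposed upstream
        return A.squeeze(1), True
    raise ValueError("expected a [1, n] or [n, 1] packed tensor")

flat = torch.arange(6)                            # stand-in for packed 4-bit data
stored = pack(flat)                               # [1, 6]
buf, transpose = unpack(stored.t())               # caller transposed the weight -> [6, 1]
assert transpose and torch.equal(buf, flat)
# After dequantizing buf into the 2-D weight `out`, the commit applies
# out = out.t() when `transpose` is True, matching the comment in the diff.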
