Fix int4pack_mm error #517
@@ -17,6 +17,7 @@
    dequantize_affine,
    int_scaled_matmul,
)
from torchao.utils import TORCH_VERSION_AFTER_2_5

__all__ = [
    "compute_error",
@@ -349,6 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
    quant_max = 2 ** n_bit - 1

    int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
    if TORCH_VERSION_AFTER_2_5:
        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
Review comments on this line:

This should break on MPS backend, since

@malfet landed pytorch/pytorch#131813, so this won't be a problem anymore.

In any case, I learned from @malfet today (see his suggestion on line 203) that if we use torch.bitwise_left_shift(x, 4) here instead of <<, it falls back to CPU. So things would work even before his PR landed, if torch.bitwise_left_shift is used instead of <<.

Thanks for the clarification. With pytorch/pytorch#131813,

Suggested change
    return int_data


def groupwise_affine_dequantize_tensor_from_qparams(
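For context on the discussion above: the packed layout stores two 4-bit values per uint8, with the even columns in the high nibble and the odd columns in the low nibble. Below is a minimal standalone sketch of the same packing written with torch.bitwise_left_shift, the functional form of << that, per the thread, falls back to CPU on MPS. The shapes and tensor names are illustrative only, not part of the torchao API.

```python
import torch

# Illustrative int4 values in [0, 15], stored one per int32 element.
int_data = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Pack two 4-bit values per byte: even columns go to the high nibble,
# odd columns to the low nibble (same layout as the patched line).
packed = torch.bitwise_or(
    torch.bitwise_left_shift(int_data[:, ::2], 4),
    int_data[:, 1::2],
).to(torch.uint8)

assert packed.shape == (4, 4)
assert packed.dtype == torch.uint8
```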
@@ -359,18 +362,26 @@ def groupwise_affine_dequantize_tensor_from_qparams(
    groupsize=128,
):
    assert groupsize > 1
    # needed for GPTQ single column dequantize
    if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int4x8.shape[-1]
    assert w_int4x8.shape[-1] % groupsize == 0
    assert w_int4x8.dim() == 2
    if TORCH_VERSION_AFTER_2_5:
        data = w_int4x8.to(torch.int32)
        high_bits = data >> 4
        low_bits = data & 0x0F
        w_int32 = torch.zeros((w_int4x8.shape[0], w_int4x8.shape[1] * 2), dtype=torch.int32, device=w_int4x8.device)
        w_int32[::, ::2] = high_bits
        w_int32[::, 1::2] = low_bits
    else:
        w_int32 = w_int4x8

    # needed for GPTQ single column dequantize
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    block_size = (1, groupsize)
    input_dtype = torch.int32
    quant_min = 0
    quant_max = 2**n_bit - 1
    return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)

    return dequantize_affine(w_int32, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)


def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
    scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
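As a sanity check on the new unpacking path, here is a small standalone round trip mirroring the patch: pack pairs of int4 values into uint8 as the quantize path now does on torch >= 2.5, then recover them with the same high-nibble and low-nibble logic the dequantize path uses before calling dequantize_affine. The tensor names are borrowed from the diff for readability; this sketch is illustrative and not part of torchao.

```python
import torch

# Original 4-bit values in [0, 15], one per int32 element.
w_int4 = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Packed form produced by the quantize path on torch >= 2.5 (two values per uint8).
w_int4x8 = (w_int4[:, ::2] << 4 | w_int4[:, 1::2]).to(torch.uint8)

# Unpacking, mirroring the patch: high nibble -> even columns, low nibble -> odd columns.
data = w_int4x8.to(torch.int32)
high_bits = data >> 4
low_bits = data & 0x0F
w_int32 = torch.zeros((w_int4x8.shape[0], w_int4x8.shape[1] * 2), dtype=torch.int32)
w_int32[:, ::2] = high_bits
w_int32[:, 1::2] = low_bits

# The round trip recovers the original values.
assert torch.equal(w_int32, w_int4)
```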