Fix int4pack_mm error #517


Merged: 6 commits, Jul 29, 2024

2 changes: 1 addition & 1 deletion test/dtypes/test_affine_quantized.py
@@ -12,7 +12,7 @@

class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_tensor_core_layout_transpose(self):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = l.weight
16 changes: 8 additions & 8 deletions test/integration/test_integration.py
@@ -631,7 +631,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
@@ -642,7 +642,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
@@ -737,7 +737,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -748,7 +748,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -823,7 +823,7 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -838,7 +838,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -1028,7 +1028,7 @@ def test_save_load_int8woqtensors(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch 2.3+.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
@torch.no_grad()
def test_save_load_int4woqtensors(self, device, dtype):
if dtype != torch.bfloat16:
@@ -1488,7 +1488,7 @@ def test_get_model_size_autoquant(self, device, dtype):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_get_model_size_aqt(self, api, test_device, test_dtype):
if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not supported yet")
2 changes: 1 addition & 1 deletion test/quantization/test_quant_api.py
@@ -525,7 +525,7 @@ def test_quantized_tensor_subclass_8da4w(self):
self.assertTrue(torch.equal(res, ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int4(self):
# use 1024 so that we don't need padding
9 changes: 8 additions & 1 deletion test/quantization/test_quant_primitives.py
@@ -29,6 +29,7 @@
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AFTER_2_5,
is_fbcode,
)

@@ -99,6 +100,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
.to(torch.int32)
.reshape_as(w)
)
if TORCH_VERSION_AFTER_2_5:
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)

return w_int4x8

@@ -500,7 +503,11 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
n_bit = 4
groupsize = 128

w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
if TORCH_VERSION_AFTER_2_5:
input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
else:
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)

self.assertTrue(torch.equal(w_bf16, w_bf16_ref))
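Note: a minimal sketch (not part of the PR; shapes are made up) of the nibble packing that the reference helper and the 2.5+ test branch above now use, plus the unpacking that inverts it, packing two int4 values per uint8 byte:

import torch

# Pack: even columns become the high nibble, odd columns the low nibble.
w_int4 = torch.randint(0, 16, (4, 8), dtype=torch.int32)            # hypothetical int4 values
packed = (w_int4[::, ::2] << 4 | w_int4[::, 1::2]).to(torch.uint8)  # shape (4, 4)

# Unpack: interleave the high and low nibbles back into the original layout.
unpacked = torch.stack(
    ((packed >> 4).to(torch.int32), (packed & 0x0F).to(torch.int32)), dim=-1
).reshape(w_int4.shape)
assert torch.equal(unpacked, w_int4)
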
21 changes: 16 additions & 5 deletions test/test_ops.py
@@ -95,20 +95,24 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
N, K = shape
assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0

t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
if TORCH_VERSION_AFTER_2_5:
t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
if TORCH_VERSION_AFTER_2_5:
unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
assert torch.equal(t, unpacked)

# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
test_utils = [
@@ -122,6 +126,8 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
test_utils.append("test_aot_dispatch_dynamic")

t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
if TORCH_VERSION_AFTER_2_5:
t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)

opcheck(
@@ -151,7 +157,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
n, k = shape
@@ -210,7 +216,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):

# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
n, k = shape
@@ -229,6 +235,9 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):

# Unpack and dequantize
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
if TORCH_VERSION_AFTER_2_5:
unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)

dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
unpacked, scales, zeros, n_bit=4, groupsize=group_size
)
@@ -264,13 +273,15 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
assert diff_op_ao < 1e-1

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
n, k = shape
device = "cuda"

q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
if TORCH_VERSION_AFTER_2_5:
q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles)
q_groups = k // group_size
scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
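Note: taken together, the 2.5+ branches added above boil down to the following round trip (a hedged sketch rather than the tests themselves; it assumes a CUDA device, a torch nightly whose _convert_weight_to_int4pack accepts uint8, a torchao build with the custom CUDA ops, and illustrative shape and tile values):

import torch
from torchao.ops import unpack_tensor_core_tiled_layout

inner_k_tiles = 2
t = torch.randint(0, 16, (256, 512), dtype=torch.int, device="cuda")  # illustrative N=256, K=512
t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)                   # pack as on 2.5+
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
unpacked = unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
# The unpack op returns one int4 value per element, so repack before comparing.
unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
assert torch.equal(t, unpacked)
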
11 changes: 7 additions & 4 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -26,6 +26,7 @@
)
from typing import ClassVar
from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AFTER_2_5

aten = torch.ops.aten

@@ -245,7 +246,6 @@ def from_float(

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

int_data = layout_type.post_process(int_data)

layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -570,9 +570,12 @@ def from_plain(
layout_type: LayoutType
):
assert isinstance(layout_type, TensorCoreTiledLayoutType)
# assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
# packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
if TORCH_VERSION_AFTER_2_5:
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
else:
assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
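Note: a minimal sketch of the version gate that from_plain now applies (the helper name is hypothetical; pre-2.5 callers are assumed to already pass int32, which the new assert enforces):

import torch
from torchao.utils import TORCH_VERSION_AFTER_2_5

def _pack_for_int4pack(int_data: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
    # torch 2.5+ expects nibble-packed uint8 weights; 2.3/2.4 expect unpacked int32.
    if TORCH_VERSION_AFTER_2_5:
        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
    else:
        assert int_data.dtype == torch.int32
    return torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
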
3 changes: 3 additions & 0 deletions torchao/prototype/hqq/hqq_tinygemm_linear.py
@@ -12,6 +12,7 @@
from hqq.core.utils import *

import torch.nn.functional as F
from torchao.utils import TORCH_VERSION_AFTER_2_5


class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -198,6 +199,8 @@ def hqq_quants_to_torch_quants(
.reshape(shape)
.contiguous()
)
if TORCH_VERSION_AFTER_2_5:
W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
Review comment (Contributor), suggested change:
W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
→
W_q = (torch.bitwise_left_shift(W_q[::, ::2], 4) | W_q[::, 1::2]).to(torch.uint8)

# group_dequantize_tensor_from_qparams
# W_r = W_q*scales + min_val
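Note: the suggested spelling and the PR's spelling compute the same packing; per the review thread further down, torch.bitwise_left_shift was attractive only because it fell back to CPU on MPS before pytorch/pytorch#131813 gave __lshift__.Scalar an MPS dispatch. A quick CPU check with arbitrary values:

import torch

x = torch.randint(0, 16, (2, 8), dtype=torch.int32)
a = (x[::, ::2] << 4 | x[::, 1::2]).to(torch.uint8)
b = (torch.bitwise_left_shift(x[::, ::2], 4) | x[::, 1::2]).to(torch.uint8)
assert torch.equal(a, b)
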
23 changes: 17 additions & 6 deletions torchao/quantization/utils.py
@@ -17,6 +17,7 @@
dequantize_affine,
int_scaled_matmul,
)
from torchao.utils import TORCH_VERSION_AFTER_2_5

__all__ = [
"compute_error",
@@ -349,6 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
quant_max = 2 ** n_bit - 1

int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
if TORCH_VERSION_AFTER_2_5:
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
Review thread on this line:

manuelcandales (Contributor, Jul 25, 2024): This should break on the MPS backend, since __lshift__.Scalar is not currently implemented for MPS.

Author (Contributor): Is int_data on an MPS device in this function? If so, we can move int_data to the CPU device, then convert it back to the MPS device.

Contributor: @malfet landed pytorch/pytorch#131813, so this won't be a problem anymore.

manuelcandales (Contributor, Jul 26, 2024): In any case, I learned from @malfet today (see his suggestion on line 203) that using torch.bitwise_left_shift(x, 4) here instead of << falls back to CPU, so things would work even before his PR landed.

Author (Contributor): Thanks for the clarification. With pytorch/pytorch#131813, __lshift__.Scalar has an MPS dispatch now.

Contributor, suggested change:
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
→
int_data = (torch.bitwise_left_shift(int_data[::, ::2], 4) | int_data[::, 1::2]).to(torch.uint8)

return int_data

def groupwise_affine_dequantize_tensor_from_qparams(
@@ -359,18 +362,26 @@ def groupwise_affine_dequantize_tensor_from_qparams(
groupsize=128,
):
assert groupsize > 1
# needed for GPTQ single column dequantize
if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
groupsize = w_int4x8.shape[-1]
assert w_int4x8.shape[-1] % groupsize == 0
assert w_int4x8.dim() == 2
if TORCH_VERSION_AFTER_2_5:
data = w_int4x8.to(torch.int32)
high_bits = data >> 4
low_bits = data & 0x0F
w_int32 = torch.zeros((w_int4x8.shape[0], w_int4x8.shape[1] * 2), dtype=torch.int32, device=w_int4x8.device)
w_int32[::, ::2] = high_bits
w_int32[::, 1::2] = low_bits
else:
w_int32 = w_int4x8

# needed for GPTQ single column dequantize
if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
groupsize = w_int32.shape[-1]
assert w_int32.shape[-1] % groupsize == 0
block_size = (1, groupsize)
input_dtype = torch.int32
quant_min = 0
quant_max = 2**n_bit - 1
return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)

return dequantize_affine(w_int32, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)

def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
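Note: a hedged usage sketch of the updated pair above (it assumes these helpers are importable from torchao.quantization.utils and uses illustrative shapes and group size): on 2.5+ the quantize step now returns nibble-packed uint8 and the dequantize step unpacks it, so the two still compose into a round trip.

import torch
from torchao.quantization.utils import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
)

w = torch.randn(128, 256, dtype=torch.bfloat16)
scales, zeros = get_groupwise_affine_qparams(w, 4, 128, torch.bfloat16)
q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, 4, 128)  # uint8 on 2.5+, int32 before
w_dq = groupwise_affine_dequantize_tensor_from_qparams(q, scales, zeros, 4, 128)
assert w_dq.shape == w.shape and w_dq.dtype == torch.bfloat16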