Skip to content

Enable Int4WeightOnlyGPTQQuantizer on Intel GPU. #2200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
if TORCH_VERSION_AT_LEAST_2_5:
if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
if (check_xpu_version(w.device)):
w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)

return w_int4x8

Expand Down Expand Up @@ -752,6 +754,8 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
not (check_xpu_version(input.device))
):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
if (check_xpu_version(input.device)):
input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
)
Expand Down
9 changes: 5 additions & 4 deletions torchao/dtypes/uintx/int4_xpu_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,15 @@ def from_plain(
):
assert isinstance(_layout, Int4XPULayout)

from torchao.quantization.utils import convert_weight_to_int4pack_xpu

if TORCH_VERSION_AT_LEAST_2_8:
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
)
packed_weight = convert_weight_to_int4pack_xpu(
int_data, zero_point.dtype != scale.dtype
packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(
torch.uint8
)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
packed_weight.contiguous(), 8
)
else:
assert False, "INT4 not supported on XPU until 2.8"
Expand Down
Loading