Commit 967e35a

[resubmit] Gemlite fix
Summary: Resubmitting #1432 since it had some rebase issues and we want to merge the fix ASAP.

Test Plan: see #1432
Parent: ec64182

2 files changed: 23 additions & 1 deletion

test/integration/test_integration.py

Lines changed: 10 additions & 1 deletion
@@ -958,11 +958,20 @@ def test_gemlite_layout(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
             api,
             device,
-            15,
+            15,
             test_shape=test_shape,
             test_dtype=dtype,
         )

+        # test that shapes not divisible by 128 aren't causing errors
+        self._test_lin_weight_subclass_api_impl(
+            lambda mod: quantize_(mod, gemlite_uintx_weight_only(None, 4, 32)),
+            device,
+            15,
+            test_shape=[1, 1025, 513],
+            test_dtype=dtype,
+        )
+

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
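
For context, here is a minimal sketch of what the new test case exercises outside the test harness. The import path and the reading of gemlite_uintx_weight_only(None, 4, 32) as (group_size, bit_width, packing_bitwidth) are assumptions based on the diff, not verified against this torchao revision; a CUDA device with gemlite installed is assumed.

    # Minimal sketch (assumptions noted above): quantizing a layer whose
    # dimensions are not divisible by 128 should warn and skip, not error.
    import torch
    from torchao.quantization import quantize_, gemlite_uintx_weight_only

    # Mirrors test_shape=[1, 1025, 513]: x is (1, 1025), Linear(1025, 513);
    # neither 1025 nor 513 is divisible by 128.
    lin = torch.nn.Linear(1025, 513, bias=False, dtype=torch.float16, device="cuda")
    quantize_(lin, gemlite_uintx_weight_only(None, 4, 32))

    # With this fix the layer is skipped (with a one-time warning) and still
    # runs as a plain fp16 linear instead of crashing inside gemlite.
    x = torch.randn(1, 1025, dtype=torch.float16, device="cuda")
    y = lin(x)  # shape (1, 513)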

torchao/dtypes/uintx/gemlite_layout.py

Lines changed: 13 additions & 0 deletions
@@ -15,6 +15,7 @@
 from torchao.dtypes.utils import Layout, is_device
 from torchao.quantization.quant_primitives import quantize_affine
 from torchao.utils import fill_defaults
+import warnings

 aten = torch.ops.aten

@@ -76,6 +77,14 @@ def apply_gemlite_quant(
     out_features, in_features = weight.shape
     group_size = in_features if group_size is None else group_size

+    if in_features % 128 != 0 and out_features % 128 != 0:
+        warnings.simplefilter("once", UserWarning)
+        warnings.warn(
+            "Gemlite only works for layers with in_features or out_features divisible by 128, "
+            + "some layers have been skipped", UserWarning
+        )
+        return weight
+
     quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)

     layout = GemlitePackedLayout(
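
The guard above makes apply_gemlite_quant degrade gracefully: instead of failing on unsupported shapes, it returns the weight unchanged so the layer keeps running in its original dtype. A standalone sketch of the same pattern (illustrative only, not torchao code; the shape check and message mirror the hunk above):

    # Standalone sketch of the skip-with-warning pattern used above.
    import warnings

    import torch

    def apply_quant_or_skip(weight: torch.Tensor) -> torch.Tensor:
        out_features, in_features = weight.shape
        # Skip only when *neither* dimension is divisible by 128.
        if in_features % 128 != 0 and out_features % 128 != 0:
            # "once" deduplicates the message when many layers are skipped.
            warnings.simplefilter("once", UserWarning)
            warnings.warn(
                "Gemlite only works for layers with in_features or out_features "
                "divisible by 128, some layers have been skipped",
                UserWarning,
            )
            return weight  # unchanged -> the layer is left unquantized
        return weight  # placeholder: the real path builds the gemlite layout here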
@@ -173,6 +182,10 @@ def from_plain(
            exhaustive=False,
            use_cuda_graph=False,
        )
+        if _layout.group_size == None and _layout.bit_width == 4:
+            from gemlite.core import GEMLITE_ACC_DTYPE
+            from gemlite.dtypes import DType
+            GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32

         out_features, in_features = int_data.shape
         input_dtype, output_dtype = DType.FP16, DType.FP16
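
The last hunk switches gemlite's FP16 accumulator to FP32 whenever a 4-bit layout is channel-wise (group_size is None). The commit doesn't state the motivation, but a plausible reading is numeric range: with group_size=None the dot product reduces over all of in_features at once, and the result can exceed fp16's maximum finite value (65504). A plain-PyTorch illustration of that failure mode (our example, not gemlite code):

    # Illustration (plain PyTorch, not gemlite) of why an fp16 accumulator is
    # risky for channel-wise matmuls: the reduction spans all of in_features,
    # so the sum can overflow fp16 even when every partial product is small.
    import torch

    k = 4096  # a typical in_features
    a = torch.full((k,), 8.0, dtype=torch.float16)
    b = torch.full((k,), 8.0, dtype=torch.float16)

    acc_fp16 = (a * b).sum(dtype=torch.float16)  # 4096 * 64 = 262144 -> inf in fp16
    acc_fp32 = (a * b).sum(dtype=torch.float32)  # 262144.0, exact in fp32

    print(acc_fp16, acc_fp32)  # tensor(inf, dtype=torch.float16) tensor(262144.)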
