Commit f52d3ab
[resubmit] Gemlite fix (#1435)

* [resubmit] Gemlite fix

  Summary: Resubmitting #1432, since it has some rebase issues and we want to merge the fix ASAP.

  Test Plan: see #1432

* ruff
1 parent ec64182 commit f52d3ab

File tree

2 files changed (+25, -1 lines)
test/integration/test_integration.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -958,11 +958,20 @@ def test_gemlite_layout(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
             api,
             device,
-            15,
+            15,
             test_shape=test_shape,
             test_dtype=dtype,
         )
 
+        # test that shapes not divisible by 128 aren't causing errors
+        self._test_lin_weight_subclass_api_impl(
+            lambda mod: quantize_(mod, gemlite_uintx_weight_only(None, 4, 32)),
+            device,
+            15,
+            test_shape=[1, 1025, 513],
+            test_dtype=dtype,
+        )
+
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
```
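For orientation, here is a minimal, hypothetical sketch of what the new test case exercises end to end. The `torchao.quantization` import path and the reading of the three positional arguments to `gemlite_uintx_weight_only` as (group size, bit width, packing bit width) are assumptions inferred from the call in the diff, not confirmed by this commit:

```python
# Hypothetical standalone sketch of the regression the new test covers:
# quantizing a linear layer whose dimensions are NOT divisible by 128.
import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only  # import path assumed

# 1025 and 513 are both indivisible by 128, mirroring test_shape=[1, 1025, 513]
model = torch.nn.Sequential(torch.nn.Linear(1025, 513)).cuda().half()

# Same call as in the test; positional args read as (group_size, bit_width,
# packing_bitwidth) -- an inference, not documented in this diff
quantize_(model, gemlite_uintx_weight_only(None, 4, 32))

# With the fix, the incompatible layer is skipped with a one-time UserWarning
# instead of raising, so the forward pass still runs
y = model(torch.randn(1, 1025, dtype=torch.float16, device="cuda"))
```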

torchao/dtypes/uintx/gemlite_layout.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -1,3 +1,4 @@
+import warnings
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple
 
@@ -76,6 +77,15 @@ def apply_gemlite_quant(
     out_features, in_features = weight.shape
     group_size = in_features if group_size is None else group_size
 
+    if in_features % 128 != 0 and out_features % 128 != 0:
+        warnings.simplefilter("once", UserWarning)
+        warnings.warn(
+            "Gemlite only works for layers with in_features or out_features divisible by 128, "
+            + "some layers have been skipped",
+            UserWarning,
+        )
+        return weight
+
     quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)
 
     layout = GemlitePackedLayout(
```
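Two details are worth noting in the hunk above: the guard only fires when *both* dimensions are indivisible by 128 (one conforming dimension is enough for gemlite), and `warnings.simplefilter("once", UserWarning)` keeps the message from repeating for every skipped layer. Below is a self-contained sketch of the fallback behavior, with the guard inlined into a hypothetical helper so it runs standalone; only the skip path is reproduced:

```python
# Hypothetical stand-in for the guard added to apply_gemlite_quant above.
import warnings

import torch


def maybe_gemlite_quant(weight: torch.Tensor) -> torch.Tensor:
    out_features, in_features = weight.shape
    if in_features % 128 != 0 and out_features % 128 != 0:
        warnings.simplefilter("once", UserWarning)  # warn once, not per layer
        warnings.warn(
            "Gemlite only works for layers with in_features or out_features "
            "divisible by 128, some layers have been skipped",
            UserWarning,
        )
        return weight  # fall back to the original, unquantized weight
    raise NotImplementedError("the real function builds a GemlitePackedLayout here")


w = torch.randn(513, 1025, dtype=torch.float16)  # both dims % 128 != 0
assert maybe_gemlite_quant(w) is w  # layer skipped without an error
```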
```diff
@@ -173,6 +183,11 @@ def from_plain(
             exhaustive=False,
             use_cuda_graph=False,
         )
+        if _layout.group_size is None and _layout.bit_width == 4:
+            from gemlite.core import GEMLITE_ACC_DTYPE
+            from gemlite.dtypes import DType
+
+            GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32
 
         out_features, in_features = int_data.shape
         input_dtype, output_dtype = DType.FP16, DType.FP16
```
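The last hunk is about numerics rather than shapes: with `group_size=None`, the 4-bit kernel reduces over the entire `in_features` axis, and summing that many FP16 products in an FP16 accumulator loses precision, so gemlite's accumulator dtype for FP16 is remapped to FP32. A standalone illustration (plain NumPy, not gemlite code) of why long FP16 reductions want a wider accumulator:

```python
# Emulate a long reduction with a pure-FP16 accumulator vs. FP32 accumulation.
import numpy as np

x = np.full(10_000, 0.1, dtype=np.float16)

acc16 = np.float16(0.0)
for v in x:  # round the running sum back to FP16 after every add
    acc16 = np.float16(acc16 + v)

acc32 = x.astype(np.float32).sum()  # accumulate in FP32 instead

print(acc16)  # stalls far below the true value: the addend falls under FP16 resolution
print(acc32)  # ~999.76, close to the exact 10_000 * float16(0.1)
```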
