Skip to content

Commit 6028093

Browse files
committed
chore: update tests
1 parent cd883d8 commit 6028093

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

test/integration/test_integration.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torchao.quantization.dynamic_quant import (
2020
DynamicallyPerAxisQuantizedLinear,
2121
)
22+
from torchao.dtypes import TensorCoreTiledLayoutType
2223
from torchao.quantization.quant_api import (
2324
int4_weight_only,
2425
int8_weight_only,
@@ -852,18 +853,20 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
852853
for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
853854
for groupsize in [64, 32]:
854855
for inner_k_tiles in [4, 2]:
855-
kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
856+
kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}
856857

857858
def api(mod):
859+
kwargs_copy = kwargs.copy()
858860
if TORCH_VERSION_AFTER_2_4:
859-
kwargs_copy = kwargs.copy()
860861
kwargs_copy["group_size"] = groupsize
861862
del kwargs_copy["groupsize"]
862863
quantize_(mod, int4_weight_only(**kwargs_copy))
863864
if not TORCH_VERSION_AFTER_2_5:
864865
unwrap_tensor_subclass(mod)
865866
else:
866-
change_linear_weights_to_int4_woqtensors(mod, **kwargs)
867+
kwargs_copy["inner_k_tiles"] = inner_k_tiles
868+
del kwargs_copy["layout_type"]
869+
change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)
867870

868871
self._test_lin_weight_subclass_api_impl(
869872
api,

0 commit comments

Comments
 (0)