@@ -19,6 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
+from torchao.dtypes import TensorCoreTiledLayoutType
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -852,18 +853,20 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+                    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}

                     def api(mod):
+                        kwargs_copy = kwargs.copy()
                         if TORCH_VERSION_AFTER_2_4:
-                            kwargs_copy = kwargs.copy()
                             kwargs_copy["group_size"] = groupsize
                             del kwargs_copy["groupsize"]
                             quantize_(mod, int4_weight_only(**kwargs_copy))
                             if not TORCH_VERSION_AFTER_2_5:
                                 unwrap_tensor_subclass(mod)
                         else:
-                            change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+                            kwargs_copy["inner_k_tiles"] = inner_k_tiles
+                            del kwargs_copy["layout_type"]
+                            change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)

                     self._test_lin_weight_subclass_api_impl(
                         api,