|
| 1 | +import unittest |
| 2 | +import functools |
| 3 | +import copy |
| 4 | +import torch |
| 5 | +import torchao |
| 6 | + |
| 7 | +from torch.testing._internal import common_utils |
| 8 | +from torchao.dtypes import AffineQuantizedTensor |
| 9 | +from torchao.dtypes import to_affine_quantized_intx |
| 10 | +from torchao.quantization.quant_primitives import MappingType |
| 11 | + |
| 12 | +""" |
| 13 | +How to use: |
| 14 | +
|
| 15 | +import unittest |
| 16 | +from torchao.testing.utils import TorchAOBasicTestCase, copy_tests |
| 17 | +from torch.testing._internal import common_utils |
| 18 | +
|
| 19 | +# TODO: currently there is no way to set COMMON_DEVICES/COMMON_DTYPES; |
| 20 | +# we can figure this out a bit later |
| 21 | +
|
| 22 | +# change arguments |
| 23 | +class MyTestCase(TorchAOBasicTestCase): |
| 24 | + TENSOR_SUBCLASS = MyDTypeTensor |
| 25 | + FACTORY_FN = to_my_dtype |
| 26 | + kwargs = {"target_dtype": torch.uint8} |
| 27 | + LINEAR_MIN_SQNR = 30 |
| 28 | +
|
| 29 | +# copy the instantiated tests |
| 30 | +copy_tests(TorchAOBasicTestCase, MyTestCase, "my_test_case") |
| 31 | +
|
| 32 | +if __name__ == "__main__": |
| 33 | + unittest.main() |
| 34 | +""" |
| 35 | + |
| 36 | +# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 |
| 37 | +def copy_tests( |
| 38 | + my_cls, other_cls, suffix, test_failures=None, xfail_prop=None |
| 39 | +): # noqa: B902 |
| 40 | + for name, value in my_cls.__dict__.items(): |
| 41 | + if name.startswith("test_"): |
| 42 | + # You cannot copy functions in Python, so we use closures here to |
| 43 | + # create objects with different ids. Otherwise, unittest.skip |
| 44 | + # would modify all methods sharing the same object id. Also, by |
| 45 | + # using a default argument, we create a copy instead of a |
| 46 | + # reference. Otherwise, we would lose access to the value. |
| 47 | + |
| 48 | + @functools.wraps(value) |
| 49 | + def new_test(self, value=value): |
| 50 | + return value(self) |
| 51 | + |
| 52 | + # Copy __dict__ which may contain test metadata |
| 53 | + new_test.__dict__ = copy.deepcopy(value.__dict__) |
| 54 | + |
| 55 | + if xfail_prop is not None and hasattr(value, xfail_prop): |
| 56 | + new_test = unittest.expectedFailure(new_test) |
| 57 | + |
| 58 | + tf = test_failures and test_failures.get(name) |
| 59 | + if tf is not None and suffix in tf.suffixes: |
| 60 | + skip_func = ( |
| 61 | + unittest.skip("Skipped!") |
| 62 | + if tf.is_skip |
| 63 | + else unittest.expectedFailure |
| 64 | + ) |
| 65 | + new_test = skip_func(new_test) |
| 66 | + |
| 67 | + setattr(other_cls, f"{name}_{suffix}", new_test) |
| 68 | + |
| 69 | + |
| 70 | + |
class TorchAOBasicTestCase(common_utils.TestCase):
    """Basic test case for tensor subclasses.

    Subclasses customize the tensor type under test by overriding
    ``TENSOR_SUBCLASS``, ``FACTORY_FN``, ``kwargs`` and ``LINEAR_MIN_SQNR``.
    Every test builds a high precision tensor with ``torch.randn`` and converts
    it with ``self.FACTORY_FN(hp_tensor, **self.kwargs)``.
    """

    # Devices/dtypes used by the @parametrize decorators below. The decorators
    # read these names while the class body executes, so overriding them in a
    # subclass does not change the parametrization of the tests defined here.
    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

    # Tensor subclass under test; its __tensor_unflatten__ is called directly.
    TENSOR_SUBCLASS = AffineQuantizedTensor
    # Callable converting a high precision tensor to the low precision one.
    # NOTE(review): it is looked up through `self`, so a plain Python function
    # assigned here would be bound and receive the test instance as its first
    # argument; wrap such a function in staticmethod(...). Presumably
    # to_affine_quantized_intx is not affected -- TODO confirm.
    FACTORY_FN = to_affine_quantized_intx
    # Extra keyword arguments forwarded to FACTORY_FN.
    kwargs = {
        "mapping_type": MappingType.ASYMMETRIC,
        "block_size": (1, 32),
        "target_dtype": torch.uint8,
    }
    # minimum sqnr for linear operation when the weight is quantized to low precision
    # with the above setting
    LINEAR_MIN_SQNR = 40

    def test_flatten_unflatten(self):
        """Round-trip through __tensor_flatten__/__tensor_unflatten__.

        The reconstructed tensor must dequantize to the same values as the
        original low precision tensor.
        """
        hp_tensor = torch.randn(4, 128)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
        # Resolve each flattened attribute name to the inner tensor it names.
        tensor_data_dict = {
            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
        }
        outer_size = lp_tensor.size()
        outer_stride = lp_tensor.stride()
        reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(
            tensor_data_dict, tensor_attributes, outer_size, outer_stride
        )
        self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize())

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_hp_tensor_device_dtype(self, device, dtype):
        """Smoke test: the factory accepts inputs on each device/dtype."""
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        # Only checks that conversion does not raise; nothing is asserted.
        self.FACTORY_FN(hp_tensor, **self.kwargs)

    @common_utils.parametrize("device1", COMMON_DEVICES)
    @common_utils.parametrize("device2", COMMON_DEVICES)
    def test_device1_to_device2(self, device1, device2):
        """Smoke test: a tensor created on device1 can be moved elsewhere.

        Exercises ``.to(device=...)``, ``.to(...)``, ``.cuda()`` and ``.cpu()``
        on a freshly created low precision tensor each time. Only the absence
        of exceptions is checked.
        """

        def make_lp_tensor():
            hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
            return self.FACTORY_FN(hp_tensor, **self.kwargs)

        make_lp_tensor().to(device=device2)
        make_lp_tensor().to(device2)

        # This test is also instantiated on CPU-only hosts (COMMON_DEVICES is
        # then ["cpu"]), where .cuda() would raise, so only exercise it when a
        # CUDA device is actually available.
        if torch.cuda.is_available():
            make_lp_tensor().cuda()

        make_lp_tensor().cpu()

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_transpose(self, device, dtype):
        """``.t()`` on a (4, 128) low precision tensor yields shape (128, 4)."""
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor = lp_tensor.t()
        self.assertEqual(lp_tensor.shape, (128, 4))

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_linear(self, device, dtype):
        """F.linear with the low precision weight stays close to the reference.

        The value returned by ``compute_error`` for the two results must exceed
        ``LINEAR_MIN_SQNR``.
        """
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

        hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
        hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
        lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
        self.assertGreater(
            torchao.quantization.utils.compute_error(hp_res, lp_res),
            self.LINEAR_MIN_SQNR,
        )

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_linear_compile(self, device, dtype):
        """Same check as test_linear, through a torch.compile'd nn.Linear."""
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

        hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
        hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
        linear = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
        # Swap the module's weight for the low precision tensor under test.
        linear.weight = torch.nn.Parameter(lp_tensor)
        lp_res = torch.compile(linear)(hp_act_tensor)
        self.assertGreater(
            torchao.quantization.utils.compute_error(hp_res, lp_res),
            self.LINEAR_MIN_SQNR,
        )
| 157 | + |
# Instantiate the parametrized tests declared on TorchAOBasicTestCase so that
# each @common_utils.parametrize combination becomes its own test method.
common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)

# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments