@@ -4,7 +4,6 @@
 import torch
 import torchao

-from torch._inductor.test_case import TestCase
 from torch.testing._internal import common_utils
 from torchao.dtypes import AffineQuantizedTensor
 from torchao.dtypes import to_affine_quantized_intx
@@ -69,7 +68,7 @@ def new_test(self, value=value):



-class TorchAOBasicTestCase(TestCase):
+class TorchAOBasicTestCase(common_utils.TestCase):
     """Basic test case for tensor subclasses
     """
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
@@ -143,6 +142,19 @@ def test_linear(self, device, dtype):
         lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
         self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)

+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_linear_compile(self, device, dtype):
+        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+
+        hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
+        hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
+        l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
+        l.weight = torch.nn.Parameter(lp_tensor)
+        lp_res = torch.compile(l)(hp_act_tensor)
+        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
+
 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)

 if __name__ == "__main__":