Skip to content

Commit 75d9c90

Browse files
committed
don't use inductor TestCase
1 parent a9907bb commit 75d9c90

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

torchao/testing/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch
55
import torchao
66

7-
from torch._inductor.test_case import TestCase
87
from torch.testing._internal import common_utils
98
from torchao.dtypes import AffineQuantizedTensor
109
from torchao.dtypes import to_affine_quantized_intx
@@ -69,7 +68,7 @@ def new_test(self, value=value):
6968

7069

7170

72-
class TorchAOBasicTestCase(TestCase):
71+
class TorchAOBasicTestCase(common_utils.TestCase):
7372
"""Basic test case for tensor subclasses
7473
"""
7574
COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
@@ -143,6 +142,19 @@ def test_linear(self, device, dtype):
143142
lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
144143
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
145144

145+
@common_utils.parametrize("device", COMMON_DEVICES)
146+
@common_utils.parametrize("dtype", COMMON_DTYPES)
147+
def test_linear_compile(self, device, dtype):
148+
hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
149+
lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
150+
151+
hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
152+
hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
153+
l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
154+
l.weight = torch.nn.Parameter(lp_tensor)
155+
lp_res = torch.compile(l)(hp_act_tensor)
156+
self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
157+
146158
common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
147159

148160
if __name__ == "__main__":

0 commit comments

Comments
 (0)