Skip to content

Commit 00d18ec

Browse files
authored
Add test util for basic tensor subclass functionalities (#839)
* Add test util for basic tensor subclass functionalities Summary: This is a small starting point for testing low precision tensor subclass functionalities we can add more test cases for training, tensor parallel, FSDP in the future right now it tests: - tensor flatten/unflatten - constructing low precision tensor with different device/dtype - move tensor subclass from device1 to device2 - transpose works - linear works (weight only quantization with the low precision tensor) It can be extended with new tensor subclasses or test cases by overriding the class variables: e.g. ``` class MyTensorSubclassTest(TorchAOBasicTestCase): COMMON_DEVICES = ["cpu", "cuda"] COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = LUTQuantizedTensor FACTORY_FN = to_lut_quantized_intx kwargs = { "target_dtype": torch.uint8, } # minimum sqnr for linear operation when the weight is quantized to low precision # with the above setting LINEAR_MIN_SQNR = 40 ``` Test Plan: python test/utils.py Reviewers: Subscribers: Tasks: Tags: * minor fix * don't use inductor TestCase
1 parent c842d50 commit 00d18ec

File tree

1 file changed

+161
-0
lines changed

1 file changed

+161
-0
lines changed

torchao/testing/utils.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import unittest
2+
import functools
3+
import copy
4+
import torch
5+
import torchao
6+
7+
from torch.testing._internal import common_utils
8+
from torchao.dtypes import AffineQuantizedTensor
9+
from torchao.dtypes import to_affine_quantized_intx
10+
from torchao.quantization.quant_primitives import MappingType
11+
12+
"""
13+
How to use:
14+
15+
import unittest
16+
from torchao.testing.utils import TorchAOBasicTestCase, copy_tests
17+
from torch.testing._internal import common_utils
18+
19+
# TODO: currently there is no way to set COMMON_DEVICES/COMMON_DTYPES
20+
# we can figure this out a bit later
21+
22+
# change arguments
23+
class MyTestCase(TorchAOBasicTestCase):
24+
TENSOR_SUBCLASS = MyDTypeTensor
25+
FACTORY_FN = to_my_dtype
26+
kwargs = {"target_dtype": torch.uint8}
27+
LINEAR_MIN_SQNR = 30
28+
29+
# copy the instantiated tests
30+
copy_tests(TorchAOBasicTestCase, MyTestCase, "my_test_case")
31+
32+
if __name__ == "__main__":
33+
unittest.main()
34+
"""
35+
36+
# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
37+
def copy_tests(
    my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
):  # noqa: B902
    """Copy every ``test_*`` attribute of ``my_cls`` onto ``other_cls``.

    Only names found in ``my_cls.__dict__`` are considered, so inherited
    tests are not copied. Each copy is installed on ``other_cls`` under the
    name ``f"{name}_{suffix}"``.

    Args:
        my_cls: Class whose own ``test_*`` attributes are copied.
        other_cls: Class that receives the copies.
        suffix: String appended to each copied test name; also matched
            against ``tf.suffixes`` when ``test_failures`` is given.
        test_failures: Optional mapping from an original test name to an
            object exposing ``suffixes`` (a container of suffixes) and
            ``is_skip`` (truthy -> skip, falsy -> expected failure).
            ``None`` and an empty mapping both mean "no known failures".
        xfail_prop: Optional attribute name. A test that carries this
            attribute is wrapped with ``unittest.expectedFailure``.
    """
    # Iterate over a snapshot: when ``my_cls is other_cls`` the ``setattr``
    # below adds keys to the very dict being walked, which would otherwise
    # raise "dictionary changed size during iteration".
    for name, value in list(my_cls.__dict__.items()):
        if not name.startswith("test_"):
            continue

        # You cannot copy functions in Python, so we use closures here to
        # create objects with different ids. Otherwise, unittest.skip
        # would modify all methods sharing the same object id. Also, by
        # using a default argument, we create a copy instead of a
        # reference. Otherwise, we would lose access to the value.
        @functools.wraps(value)
        def new_test(self, value=value):
            return value(self)

        # Copy __dict__ which may contain test metadata
        new_test.__dict__ = copy.deepcopy(value.__dict__)

        if xfail_prop is not None and hasattr(value, xfail_prop):
            new_test = unittest.expectedFailure(new_test)

        # ``test_failures and test_failures.get(name)`` evaluates to the
        # mapping itself when it is empty, which then fails on
        # ``tf.suffixes``; treat any falsy mapping as "no entry".
        tf = test_failures.get(name) if test_failures else None
        if tf is not None and suffix in tf.suffixes:
            skip_func = (
                unittest.skip("Skipped!")
                if tf.is_skip
                else unittest.expectedFailure
            )
            new_test = skip_func(new_test)

        setattr(other_cls, f"{name}_{suffix}", new_test)
68+
69+
70+
71+
class TorchAOBasicTestCase(common_utils.TestCase):
    """Basic test case for tensor subclasses.

    The suite builds a low precision tensor with ``FACTORY_FN(hp_tensor,
    **kwargs)`` and checks flatten/unflatten, construction across
    devices/dtypes, device moves, transpose, and linear (eager and compiled).
    Subclasses retarget it by overriding ``TENSOR_SUBCLASS``, ``FACTORY_FN``,
    ``kwargs`` and ``LINEAR_MIN_SQNR``.

    ``COMMON_DEVICES`` / ``COMMON_DTYPES`` are read by the ``parametrize``
    decorators while this class body executes, so overriding them in a
    subclass does not change the parametrization of the tests defined here.
    """

    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

    TENSOR_SUBCLASS = AffineQuantizedTensor
    # NOTE(review): the tests call this as ``self.FACTORY_FN(...)``. A plain
    # Python function assigned here would be bound as a method and receive
    # the test instance as its first argument; wrapping such a factory in
    # ``staticmethod(...)`` avoids that -- confirm for each factory in use.
    FACTORY_FN = to_affine_quantized_intx
    kwargs = {
        "mapping_type": MappingType.ASYMMETRIC,
        "block_size": (1, 32),
        "target_dtype": torch.uint8,
    }
    # minimum sqnr for linear operation when the weight is quantized to low precision
    # with the above setting
    LINEAR_MIN_SQNR = 40

    def test_flatten_unflatten(self):
        """Round-trip through ``__tensor_flatten__`` / ``__tensor_unflatten__``."""
        hp_tensor = torch.randn(4, 128)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
        # Resolve each flattened name back to the inner tensor it refers to.
        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
        outer_size = lp_tensor.size()
        outer_stride = lp_tensor.stride()
        reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
        self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize())

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_hp_tensor_device_dtype(self, device, dtype):
        """Smoke test: construction must not raise for any device/dtype pair."""
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        self.FACTORY_FN(hp_tensor, **self.kwargs)

    @common_utils.parametrize("device1", COMMON_DEVICES)
    @common_utils.parametrize("device2", COMMON_DEVICES)
    def test_device1_to_device2(self, device1, device2):
        """Note: this should be parametrized with device1 and device2
        e.g. device1 = ["cpu", "cuda"], device2 = ["cpu", "cuda"]

        Smoke test: each way of moving the tensor must not raise.
        """
        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor.to(device=device2)

        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor.to(device2)

        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        # COMMON_DEVICES is ["cpu"] on hosts without CUDA, so this test still
        # runs there; an unconditional .cuda() would make it error out.
        if torch.cuda.is_available():
            lp_tensor.cuda()

        hp_tensor = torch.randn(4, 128, device=device1, dtype=torch.bfloat16)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor.cpu()

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_transpose(self, device, dtype):
        """``t()`` on a (4, 128) low precision tensor yields shape (128, 4)."""
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
        lp_tensor = lp_tensor.t()
        self.assertEqual(lp_tensor.shape, (128, 4))

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_linear(self, device, dtype):
        """Linear with the low precision weight stays close to the hp result."""
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

        hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
        hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
        lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
        # The error metric is compared against LINEAR_MIN_SQNR as a lower bound.
        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)

    @common_utils.parametrize("device", COMMON_DEVICES)
    @common_utils.parametrize("dtype", COMMON_DTYPES)
    def test_linear_compile(self, device, dtype):
        """Same check as ``test_linear`` but through a compiled ``nn.Linear``."""
        hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)

        hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
        hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
        linear_mod = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
        # Swap the module's weight for the low precision tensor before compiling.
        linear_mod.weight = torch.nn.Parameter(lp_tensor)
        lp_res = torch.compile(linear_mod)(hp_act_tensor)
        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
157+
158+
# Instantiate the parametrized tests declared on TorchAOBasicTestCase; the
# usage notes at the top of this file copy these instantiated tests onto
# user-defined subclasses with copy_tests.
common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
159+
160+
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
161+
unittest.main()

0 commit comments

Comments
 (0)