import unittest

import torch
-from parameterized import param, parameterized
from torch.testing import FileCheck

from torchao.dtypes import PlainLayout
-from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.q_dq_layout import QDQLayout
-from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
-from torchao.quantization.granularity import PerGroup, PerRow
-from torchao.quantization.linear_activation_quantized_tensor import (
-    LinearActivationQuantizedTensor,
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
)
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass


-# def _truncate_weight_to_bf16(weight):
-#     assert isinstance(weight, AffineQuantizedTensor)
-#     assert isinstance(weight.tensor_impl.get_layout(), PlainLayout)
-#     assert weight.tensor_impl.scale.dtype == torch.float32
-#     print("BEFORE", weight.tensor_impl.scale[0][0])
-#     weight.tensor_impl.scale = weight.tensor_impl.scale.to(torch.bfloat16).to(
-#         torch.float32
-#     )
-#     print("AFTER", weight.tensor_impl.scale[0][0])
-#     print("CHANGE", weight.tensor_impl.scale.to(torch.bfloat16).to(torch.float32)[0][0])
-
-
-# def _truncate_model_to_bf16(model):
-#     for name, param in model.named_parameters():
-#         print(name, param.dtype)
-#         # print(param)
-#         if isinstance(param, LinearActivationQuantizedTensor):
-#             print("FOUND ONE")
-#             _truncate_weight_to_bf16(param.original_weight_tensor)
-
-
class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
-    TEST_ACCURACY_CASES = [
-        param(
-            layout=layout,
-            weight_dtype=weight_dtype,
-            has_weight_zeros=has_weight_zeros,
-            granularity=granularity,
-        )
-        for layout in [
+    def test_accuracy(self):
+        """
+        Checks the accuracy of different layouts by comparing the results to PlainLayout()
+        """
+        m = 1
+        n = 1071
+        k = 4096
+        activations = torch.randn(m, k)
+        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+
+        reference_layout = PlainLayout()
+        test_layouts = [
            PackedLinearInt8DynamicActivationIntxWeightLayout(),
            QDQLayout(),
        ]
-        for weight_dtype in [
+        test_weight_dtypes = [
            torch.int1,
            torch.int2,
            torch.int3,
@@ -71,68 +54,37 @@ class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
            torch.int7,
            torch.int8,
        ]
-        for has_weight_zeros in [
-            True,
-            False,
-        ]
-        for granularity in [
-            PerGroup(128),
-            PerRow(),
-        ]
-    ]
-
-    @parameterized.expand(
-        TEST_ACCURACY_CASES,
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
-    )
-    def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity):
-        """
-        Checks the accuracy of different layouts by comparing the results to PlainLayout()
-        """
-        m = 1
-        n = 1071
-        k = 4096
-        activations = torch.randn(m, k)
-        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
-
-        reference_layout = PlainLayout()
-        quantized_model = copy.deepcopy(model)
-        quantize_(
-            quantized_model,
-            int8_dynamic_activation_intx_weight(
-                weight_dtype=weight_dtype,
-                granularity=granularity,
-                has_weight_zeros=has_weight_zeros,
-                layout=layout,
-            ),
-        )
-
-        quantized_model_reference = copy.deepcopy(model)
-        quantize_(
-            quantized_model_reference,
-            int8_dynamic_activation_intx_weight(
-                weight_dtype=weight_dtype,
-                granularity=granularity,
-                has_weight_zeros=has_weight_zeros,
-                layout=reference_layout,
-            ),
-        )
-
-        with torch.no_grad():
-            result = quantized_model(activations)
-            expected_result = quantized_model_reference(activations)
-
-        if (
-            isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
-            and weight_dtype == torch.int4
-            and not has_weight_zeros
+        test_has_weight_zeros = [True, False]
+        test_granularities = [PerGroup(128), PerRow()]
+        for layout, weight_dtype, has_weight_zeros, granularity in itertools.product(
+            test_layouts, test_weight_dtypes, test_has_weight_zeros, test_granularities
        ):
-            # Use relaxed MSE accuracy criteria for KleidiAI kernels
-            self.assertTrue(
-                torch.nn.functional.mse_loss(result, expected_result) <= 1e-5
+            quantized_model = copy.deepcopy(model)
+            quantize_(
+                quantized_model,
+                int8_dynamic_activation_intx_weight(
+                    weight_dtype=weight_dtype,
+                    granularity=granularity,
+                    has_weight_zeros=has_weight_zeros,
+                    layout=layout,
+                ),
+            )
+
+            quantized_model_reference = copy.deepcopy(model)
+            quantize_(
+                quantized_model_reference,
+                int8_dynamic_activation_intx_weight(
+                    weight_dtype=weight_dtype,
+                    granularity=granularity,
+                    has_weight_zeros=has_weight_zeros,
+                    layout=reference_layout,
+                ),
            )
-        else:
-            self.assertTrue(torch.allclose(result, expected_result, atol=1e-5))
+
+            with torch.no_grad():
+                result = quantized_model(activations)
+                expected_result = quantized_model_reference(activations)
+            self.assertTrue(torch.allclose(result, expected_result, atol=1e-6))

    def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(
        self,
@@ -232,7 +184,3 @@ def test_export_QDQLayout(self):
            FileCheck().check_count(line, 1, exactly=True).run(
                exported.graph_module.code
            )
-
-
-if __name__ == "__main__":
-    unittest.main()
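
For readers less familiar with `itertools.product`: a minimal standalone sketch (not part of the diff, using stand-in strings and booleans rather than the real layout objects, dtypes, and granularities) showing that it enumerates the same Cartesian sweep the removed `parameterized` list comprehension produced, so the refactored loop covers an identical set of cases:

import itertools

# Stand-ins for two of the swept dimensions; the real test sweeps layouts,
# weight dtypes, has_weight_zeros, and granularities.
layouts = ["packed", "qdq"]
has_weight_zeros = [True, False]

# The removed code enumerated cases with a nested list comprehension ...
via_comprehension = [(l, z) for l in layouts for z in has_weight_zeros]
# ... while the new code iterates itertools.product inside the test body.
via_product = list(itertools.product(layouts, has_weight_zeros))

assert via_comprehension == via_product  # same 4 (layout, zeros) pairs, same order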