
Commit 8856c1e

up

1 parent 92a3da4

File tree (1 file changed: +48 −100 lines)

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 48 additions & 100 deletions
@@ -10,58 +10,41 @@
 import unittest
 
 import torch
-from parameterized import param, parameterized
 from torch.testing import FileCheck
 
 from torchao.dtypes import PlainLayout
-from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
 from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
     PackedLinearInt8DynamicActivationIntxWeightLayout,
 )
 from torchao.experimental.q_dq_layout import QDQLayout
-from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
-from torchao.quantization.granularity import PerGroup, PerRow
-from torchao.quantization.linear_activation_quantized_tensor import (
-    LinearActivationQuantizedTensor,
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
 )
 from torchao.quantization.quant_api import quantize_
 from torchao.utils import unwrap_tensor_subclass
 
 
-# def _truncate_weight_to_bf16(weight):
-#     assert isinstance(weight, AffineQuantizedTensor)
-#     assert isinstance(weight.tensor_impl.get_layout(), PlainLayout)
-#     assert weight.tensor_impl.scale.dtype == torch.float32
-#     print("BEFORE", weight.tensor_impl.scale[0][0])
-#     weight.tensor_impl.scale = weight.tensor_impl.scale.to(torch.bfloat16).to(
-#         torch.float32
-#     )
-#     print("AFTER", weight.tensor_impl.scale[0][0])
-#     print("CHANGE", weight.tensor_impl.scale.to(torch.bfloat16).to(torch.float32)[0][0])
-
-
-# def _truncate_model_to_bf16(model):
-#     for name, param in model.named_parameters():
-#         print(name, param.dtype)
-#         # print(param)
-#         if isinstance(param, LinearActivationQuantizedTensor):
-#             print("FOUND ONE")
-#             _truncate_weight_to_bf16(param.original_weight_tensor)
-
-
 class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
-    TEST_ACCURACY_CASES = [
-        param(
-            layout=layout,
-            weight_dtype=weight_dtype,
-            has_weight_zeros=has_weight_zeros,
-            granularity=granularity,
-        )
-        for layout in [
+    def test_accuracy(self):
+        """
+        Checks the accuracy of different layouts by comparing the results to PlainLayout()
+        """
+        m = 1
+        n = 1071
+        k = 4096
+        activations = torch.randn(m, k)
+        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+
+        reference_layout = PlainLayout()
+        test_layouts = [
             PackedLinearInt8DynamicActivationIntxWeightLayout(),
             QDQLayout(),
         ]
-        for weight_dtype in [
+        test_weight_dtypes = [
             torch.int1,
             torch.int2,
             torch.int3,
@@ -71,68 +54,37 @@ class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
             torch.int7,
             torch.int8,
         ]
-        for has_weight_zeros in [
-            True,
-            False,
-        ]
-        for granularity in [
-            PerGroup(128),
-            PerRow(),
-        ]
-    ]
-
-    @parameterized.expand(
-        TEST_ACCURACY_CASES,
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
-    )
-    def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity):
-        """
-        Checks the accuracy of different layouts by comparing the results to PlainLayout()
-        """
-        m = 1
-        n = 1071
-        k = 4096
-        activations = torch.randn(m, k)
-        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
-
-        reference_layout = PlainLayout()
-        quantized_model = copy.deepcopy(model)
-        quantize_(
-            quantized_model,
-            int8_dynamic_activation_intx_weight(
-                weight_dtype=weight_dtype,
-                granularity=granularity,
-                has_weight_zeros=has_weight_zeros,
-                layout=layout,
-            ),
-        )
-
-        quantized_model_reference = copy.deepcopy(model)
-        quantize_(
-            quantized_model_reference,
-            int8_dynamic_activation_intx_weight(
-                weight_dtype=weight_dtype,
-                granularity=granularity,
-                has_weight_zeros=has_weight_zeros,
-                layout=reference_layout,
-            ),
-        )
-
-        with torch.no_grad():
-            result = quantized_model(activations)
-            expected_result = quantized_model_reference(activations)
-
-        if (
-            isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
-            and weight_dtype == torch.int4
-            and not has_weight_zeros
+        test_has_weight_zeros = [True, False]
+        test_granularities = [PerGroup(128), PerRow()]
+        for layout, weight_dtype, has_weight_zeros, granularity in itertools.product(
+            test_layouts, test_weight_dtypes, test_has_weight_zeros, test_granularities
         ):
-            # Use relaxed MSE accuracy criteria for KleidiAI kernels
-            self.assertTrue(
-                torch.nn.functional.mse_loss(result, expected_result) <= 1e-5
+            quantized_model = copy.deepcopy(model)
+            quantize_(
+                quantized_model,
+                int8_dynamic_activation_intx_weight(
+                    weight_dtype=weight_dtype,
+                    granularity=granularity,
+                    has_weight_zeros=has_weight_zeros,
+                    layout=layout,
+                ),
+            )
+
+            quantized_model_reference = copy.deepcopy(model)
+            quantize_(
+                quantized_model_reference,
+                int8_dynamic_activation_intx_weight(
+                    weight_dtype=weight_dtype,
+                    granularity=granularity,
+                    has_weight_zeros=has_weight_zeros,
+                    layout=reference_layout,
+                ),
             )
-        else:
-            self.assertTrue(torch.allclose(result, expected_result, atol=1e-5))
+
+            with torch.no_grad():
+                result = quantized_model(activations)
+                expected_result = quantized_model_reference(activations)
+            self.assertTrue(torch.allclose(result, expected_result, atol=1e-6))
 
     def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(
         self,
@@ -232,7 +184,3 @@ def test_export_QDQLayout(self):
             FileCheck().check_count(line, 1, exactly=True).run(
                 exported.graph_module.code
             )
-
-
-if __name__ == "__main__":
-    unittest.main()
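
The net effect of this diff is that the parameterized.expand sweep is folded into a single test_accuracy method that loops over itertools.product of layouts, weight dtypes, weight-zero flags, and granularities, and the special-cased relaxed MSE check for KleidiAI kernels is replaced by one uniform torch.allclose comparison (atol=1e-6) against the PlainLayout reference. The new loop assumes itertools is imported elsewhere in the file, since no such import appears in these hunks. Below is a minimal, self-contained sketch of the same refactor pattern; the module and test names are illustrative only and are not part of this commit.

import itertools
import unittest

import torch


class TestLinearMatchesReference(unittest.TestCase):
    def test_accuracy(self):
        # Sweep every parameter combination inside one test method,
        # mirroring the itertools.product loop introduced by this commit.
        test_dtypes = [torch.float32, torch.float64]
        test_biases = [True, False]
        for dtype, has_bias in itertools.product(test_dtypes, test_biases):
            linear = torch.nn.Linear(8, 4, bias=has_bias).to(dtype)
            x = torch.randn(2, 8, dtype=dtype)

            # Reference result computed directly from the layer's parameters.
            expected = x @ linear.weight.T
            if has_bias:
                expected = expected + linear.bias

            with torch.no_grad():
                result = linear(x)
            self.assertTrue(torch.allclose(result, expected, atol=1e-6))


if __name__ == "__main__":
    unittest.main()

One trade-off of this pattern: the whole sweep runs as a single unittest case, so the first failing combination stops the test and the report no longer names each case the way parameterized.expand did; unittest's subTest context manager is a common middle ground when per-combination reporting is needed.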
