Skip to content

Commit 0763e90

Browse files
committed
up
1 parent 7446048 commit 0763e90

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

torchao/experimental/tests/test_quant_passes.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import itertools
78
import unittest
89

910
import torch
@@ -32,27 +33,31 @@ def test_replace_q_dq_patterns_with_quantized_linear_ops_pass(self):
3233
layer_to_weight_mapping_type = {}
3334
layer_to_weight_zero_point_domain = {}
3435
layer_to_weight_granularity = {}
35-
for weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)]:
36-
for weight_mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]:
37-
for weight_zero_point_domain in [
38-
ZeroPointDomain.INT,
39-
ZeroPointDomain.NONE,
40-
]:
41-
if (
42-
weight_mapping_type == MappingType.SYMMETRIC
43-
and weight_zero_point_domain == ZeroPointDomain.INT
44-
):
45-
continue
46-
for weight_granularity in [PerAxis(0), PerGroup(32)]:
47-
for has_bias in [True, False]:
48-
idx = len(layers)
49-
layer_to_weight_dtype[idx] = weight_dtype
50-
layer_to_weight_mapping_type[idx] = weight_mapping_type
51-
layer_to_weight_zero_point_domain[idx] = (
52-
weight_zero_point_domain
53-
)
54-
layer_to_weight_granularity[idx] = weight_granularity
55-
layers.append(torch.nn.Linear(64, 64, bias=has_bias))
36+
for (
37+
weight_dtype,
38+
weight_mapping_type,
39+
weight_zero_point_domain,
40+
weight_granularity,
41+
has_bias,
42+
) in itertools.product(
43+
[getattr(torch, f"int{i}") for i in range(1, 9)],
44+
[MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
45+
[ZeroPointDomain.INT, ZeroPointDomain.NONE],
46+
[PerAxis(0), PerGroup(32)],
47+
[True, False],
48+
):
49+
if (
50+
weight_mapping_type == MappingType.SYMMETRIC
51+
and weight_zero_point_domain == ZeroPointDomain.INT
52+
):
53+
continue
54+
55+
idx = len(layers)
56+
layer_to_weight_dtype[idx] = weight_dtype
57+
layer_to_weight_mapping_type[idx] = weight_mapping_type
58+
layer_to_weight_zero_point_domain[idx] = weight_zero_point_domain
59+
layer_to_weight_granularity[idx] = weight_granularity
60+
layers.append(torch.nn.Linear(64, 64, bias=has_bias))
5661

5762
activations = torch.randn(2, 1, 64, dtype=torch.float32)
5863
model = torch.nn.Sequential(*layers)

0 commit comments

Comments
 (0)