Skip to content

Commit 3a96640

Browse files
gmagogsfm authored and facebook-github-bot committed
pytorch/ao/test/quantization
Reviewed By: avikchaudhuri Differential Revision: D67388025
1 parent 38c79d4 commit 3a96640

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

test/quantization/test_quant_api.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,38 +13,29 @@
1313
from pathlib import Path
1414

1515
import torch
16-
from torch.ao.quantization.quantize_pt2e import (
17-
convert_pt2e,
18-
prepare_pt2e,
19-
)
16+
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
2017
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
21-
XNNPACKQuantizer,
2218
get_symmetric_quantization_config,
19+
XNNPACKQuantizer,
2320
)
2421
from torch.testing._internal import common_utils
2522
from torch.testing._internal.common_utils import TestCase
2623

2724
from torchao import quantize_
28-
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
25+
from torchao._models.llama.model import prepare_inputs_for_model, Transformer
2926
from torchao._models.llama.tokenizer import get_tokenizer
30-
from torchao.dtypes import (
31-
AffineQuantizedTensor,
32-
)
33-
from torchao.quantization import (
34-
LinearActivationQuantizedTensor,
35-
)
27+
from torchao.dtypes import AffineQuantizedTensor
28+
from torchao.quantization import LinearActivationQuantizedTensor
3629
from torchao.quantization.quant_api import (
37-
Quantizer,
38-
TwoStepQuantizer,
3930
_replace_with_custom_fn_if_matches_filter,
4031
int4_weight_only,
4132
int8_dynamic_activation_int4_weight,
4233
int8_dynamic_activation_int8_weight,
4334
int8_weight_only,
35+
Quantizer,
36+
TwoStepQuantizer,
4437
)
45-
from torchao.quantization.quant_primitives import (
46-
MappingType,
47-
)
38+
from torchao.quantization.quant_primitives import MappingType
4839
from torchao.quantization.subclass import (
4940
Int4WeightOnlyQuantizedLinearWeight,
5041
Int8WeightOnlyQuantizedLinearWeight,
@@ -59,7 +50,7 @@
5950

6051

6152
def dynamic_quant(model, example_inputs):
62-
m = torch.export.export(model, example_inputs).module()
53+
m = torch.export.export(model, example_inputs, strict=True).module()
6354
quantizer = XNNPACKQuantizer().set_global(
6455
get_symmetric_quantization_config(is_dynamic=True)
6556
)
@@ -69,7 +60,7 @@ def dynamic_quant(model, example_inputs):
6960

7061

7162
def capture_and_prepare(model, example_inputs):
72-
m = torch.export.export(model, example_inputs)
63+
m = torch.export.export(model, example_inputs, strict=True)
7364
quantizer = XNNPACKQuantizer().set_global(
7465
get_symmetric_quantization_config(is_dynamic=True)
7566
)
@@ -666,7 +657,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
666657

667658
m_unwrapped = unwrap_tensor_subclass(m)
668659

669-
m = torch.export.export(m_unwrapped, example_inputs).module()
660+
m = torch.export.export(m_unwrapped, example_inputs, strict=True).module()
670661
exported_model_res = m(*example_inputs)
671662

672663
self.assertTrue(torch.equal(exported_model_res, ref))

0 commit comments

Comments (0)