@@ -13,38 +13,29 @@
 from pathlib import Path
 
 import torch
-from torch.ao.quantization.quantize_pt2e import (
-    convert_pt2e,
-    prepare_pt2e,
-)
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    XNNPACKQuantizer,
     get_symmetric_quantization_config,
+    XNNPACKQuantizer,
 )
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import TestCase
 
 from torchao import quantize_
-from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.model import prepare_inputs_for_model, Transformer
 from torchao._models.llama.tokenizer import get_tokenizer
-from torchao.dtypes import (
-    AffineQuantizedTensor,
-)
-from torchao.quantization import (
-    LinearActivationQuantizedTensor,
-)
+from torchao.dtypes import AffineQuantizedTensor
+from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.quant_api import (
-    Quantizer,
-    TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    Quantizer,
+    TwoStepQuantizer,
 )
-from torchao.quantization.quant_primitives import (
-    MappingType,
-)
+from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -59,7 +50,7 @@
 
 
 def dynamic_quant(model, example_inputs):
-    m = torch.export.export(model, example_inputs).module()
+    m = torch.export.export(model, example_inputs, strict=True).module()
     quantizer = XNNPACKQuantizer().set_global(
         get_symmetric_quantization_config(is_dynamic=True)
     )
@@ -69,7 +60,7 @@ def dynamic_quant(model, example_inputs):
 
 
 def capture_and_prepare(model, example_inputs):
-    m = torch.export.export(model, example_inputs)
+    m = torch.export.export(model, example_inputs, strict=True)
     quantizer = XNNPACKQuantizer().set_global(
         get_symmetric_quantization_config(is_dynamic=True)
     )
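
For reference, a minimal sketch of the PT2E dynamic-quantization flow these helpers exercise, using the strict=True export this diff adds. The toy nn.Linear model and example input are assumptions for illustration only and are not part of the test file.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

# Toy model and input, assumed for illustration.
model = torch.nn.Linear(16, 8).eval()
example_inputs = (torch.randn(2, 16),)

# Capture the module with strict export, as the updated helpers now do.
m = torch.export.export(model, example_inputs, strict=True).module()

# Annotate the graph for dynamic (runtime) activation quantization.
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_dynamic=True)
)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # one pass through the prepared graph
m = convert_pt2e(m)
out = m(*example_inputs)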
@@ -666,7 +657,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
 
         m_unwrapped = unwrap_tensor_subclass(m)
 
-        m = torch.export.export(m_unwrapped, example_inputs).module()
+        m = torch.export.export(m_unwrapped, example_inputs, strict=True).module()
         exported_model_res = m(*example_inputs)
 
         self.assertTrue(torch.equal(exported_model_res, ref))
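
And a matching sketch of the tensor-subclass export path this test checks: quantize with torchao, unwrap the tensor subclasses, export with strict=True, and compare against the eager result. The toy model, input, and the torchao.utils import path for unwrap_tensor_subclass are assumptions here; device/dtype details from the real test are omitted.

import torch
from torchao import quantize_
from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass  # assumed import path

# Toy model and input, assumed for illustration.
m = torch.nn.Sequential(torch.nn.Linear(32, 32)).eval()
example_inputs = (torch.randn(2, 32),)

quantize_(m, int8_dynamic_activation_int8_weight())
ref = m(*example_inputs)

# Flatten the quantized tensor subclasses so torch.export can trace the module,
# then export with strict=True as in the updated test.
m_unwrapped = unwrap_tensor_subclass(m)
exported = torch.export.export(m_unwrapped, example_inputs, strict=True).module()
assert torch.equal(exported(*example_inputs), ref)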