27
27
28
28
from ..export .export_example import export_to_pte
29
29
30
- from ..models import MODEL_NAME_TO_MODEL
30
+ from ..models import MODEL_NAME_TO_MODEL , MODEL_NAME_TO_OPTIONS
31
31
32
- # Note: for mv3, the mul op is not supported in XNNPACKQuantizer, that could be supported soon
33
- QUANT_MODEL_NAME_TO_MODEL = {
34
- name : MODEL_NAME_TO_MODEL [name ] for name in ["linear" , "add" , "add_mul" , "mv2" ]
35
- }
36
-
37
-
38
- def quantize (model_name , model , example_inputs ):
39
- """This is the official recommended flow for quantization in pytorch 2.0 export"""
40
- m = model .eval ()
41
- m = export .capture_pre_autograd_graph (m , copy .deepcopy (example_inputs ))
42
- print ("original model:" , m )
43
- quantizer = XNNPACKQuantizer ()
44
- # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
45
- operator_config = get_symmetric_quantization_config (is_per_channel = False )
46
- quantizer .set_global (operator_config )
47
- m = prepare_pt2e (m , quantizer )
48
- # calibration
49
- m (* example_inputs )
50
- m = convert_pt2e (m )
51
- print ("quantized model:" , m )
52
- # make sure we can export to flat buffer
53
- export_to_pte (model_name , m , copy .deepcopy (example_inputs ))
32
+ from .utils import quantize
54
33
55
34
56
35
def verify_xnnpack_quantizer_matching_fx_quant_model (model_name , model , example_inputs ):
@@ -102,7 +81,7 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
102
81
"-m" ,
103
82
"--model_name" ,
104
83
required = True ,
105
- help = f"Provide model name. Valid ones: { list (QUANT_MODEL_NAME_TO_MODEL .keys ())} " ,
84
+ help = f"Provide model name. Valid ones: { list (MODEL_NAME_TO_OPTIONS .keys ())} " ,
106
85
)
107
86
parser .add_argument (
108
87
"-ve" ,
@@ -122,12 +101,12 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
122
101
args = parser .parse_args ()
123
102
if args .so_library :
124
103
torch .ops .load_library (args .so_library )
125
- if not args .verify and args .model_name not in QUANT_MODEL_NAME_TO_MODEL :
104
+ if not args .verify and args .model_name not in MODEL_NAME_TO_OPTIONS :
126
105
raise RuntimeError (
127
106
f"Model { args .model_name } is not a valid name. or not quantizable right now, "
128
107
"please contact executorch team if you want to learn why or how to support "
129
108
"quantization for the requested model"
130
- f"Available models are { list (QUANT_MODEL_NAME_TO_MODEL .keys ())} ."
109
+ f"Available models are { list (MODEL_NAME_TO_OPTIONS .keys ())} ."
131
110
)
132
111
133
112
model , example_inputs = MODEL_NAME_TO_MODEL [args .model_name ]()
@@ -137,5 +116,6 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
137
116
args .model_name , model , example_inputs
138
117
)
139
118
140
- quantize (args .model_name , model , example_inputs )
119
+ quantized_model = quantize (model , example_inputs )
120
+ export_to_pte (args .model_name , quantized_model , copy .deepcopy (example_inputs ))
141
121
print ("finished" )
0 commit comments