@@ -3,9 +3,10 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
-from typing import Callable, Dict, Optional, Sequence, Set, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
 
 import torch
 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
@@ -58,7 +59,7 @@ class QuantDtype(IntEnum):
     use_8a8w = 4
 
 
-quant_config_dict = {
+QUANT_CONFIG_DICT = {
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
@@ -123,21 +124,71 @@ class QuantDtype(IntEnum):
 }
 
 
+@dataclass
+class ModuleQConfig:
+    quant_dtype: QuantDtype = QuantDtype.use_8a8w
+    is_qat: bool = False
+    is_conv_per_channel: bool = False
+    is_linear_per_channel: bool = False
+    act_observer: Optional[
+        torch.ao.quantization.observer.UniformQuantizationObserverBase
+    ] = None
+
+    def __post_init__(self):
+        if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
+            raise RuntimeError(
+                f"the quant config (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not supported"
+            )
+        (
+            quant_config_func,
+            per_channel_quant_config_func,
+            per_block_quant_config_func,
+        ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
+        self.quant_config = (
+            quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else quant_config_func()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_channel_quant_config_func()
+        )
+        self.use_per_channel_weight_quant_ops = set()
+        if self.is_conv_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.conv1d.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv_transpose2d.input,
+                }
+            )
+        if self.is_linear_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.linear.default,
+                }
+            )
+        if per_block_quant_config_func:
+            self.per_block_quant_config = (
+                per_block_quant_config_func(act_observer=self.act_observer)
+                if self.act_observer
+                else per_block_quant_config_func()
+            )
+
+
 class QnnQuantizer(Quantizer):
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())
 
     def __init__(self):
         super().__init__()
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()
 
-        self.is_qat = False
-        self.quant_dtype = QuantDtype.use_8a8w
-        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
-        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
-        self.per_block_quant_config = get_ptq_per_block_quant_config()
+        self.default_quant_config = ModuleQConfig()
+        self.submodule_qconfig_list: List[
+            Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig]
+        ] = []
         self.block_size_map = {}
-        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
-        self.use_per_block_weight_quant_ops: Set[OpOverload] = set()
 
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
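
For context on the hunk above: `ModuleQConfig` resolves everything eagerly in `__post_init__`, so an invalid `(quant_dtype, is_qat)` pair fails at construction rather than during annotation. Below is a minimal sketch of constructing one directly, assuming the import path matches this file's location and that the `use_16a16w` PTQ entry in `QUANT_CONFIG_DICT` supplies a per-channel config function:

```python
import torch

# Assumed import path, mirroring this file's location in the repo.
from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QuantDtype,
)

# PTQ config: 16-bit activations/weights, per-channel conv weight quantization.
qconfig = ModuleQConfig(
    quant_dtype=QuantDtype.use_16a16w,
    is_qat=False,
    is_conv_per_channel=True,
)

# __post_init__ has already populated the derived fields from QUANT_CONFIG_DICT.
assert qconfig.quant_config is not None
assert torch.ops.aten.conv2d.default in qconfig.use_per_channel_weight_quant_ops
```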
@@ -155,41 +206,38 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)
 
-    def _get_quant_config(self, op: torch.fx.Node) -> Optional[QuantizationConfig]:
+    def _get_submodule_qconfig(self, node: torch.fx.Node) -> ModuleQConfig:
+        for func, qconfig in self.submodule_qconfig_list:
+            if func(node):
+                return qconfig
+        return self.default_quant_config
+
+    def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
-        Priority:
-        1. is one of use_per_block_weight_quant_ops
-        2. is one of use_per_channel_weight_quant_ops
-        3. quant config
+        Selection priority:
+        1. Per-block config, if a block size is registered for this node name.
+        2. The first matching submodule config, if any callback matches.
+        3. Per-channel config, if the op is in use_per_channel_weight_quant_ops.
+        4. Otherwise, the selected config's default quant config.
         """
-        target = op.target
-        if isinstance(target, str):
+        op = node.target
+        if isinstance(op, str):
             return
 
-        if target in self.use_per_block_weight_quant_ops:
-            if block_size := self.block_size_map.get(op.name):
-                self.per_block_quant_config.block_size = block_size
-                return self.per_block_quant_config
+        if block_size := self.block_size_map.get(node.name):
+            config = self.default_quant_config.per_block_quant_config
+            config.block_size = block_size
+            return config
 
-        if target in self.use_per_channel_weight_quant_ops:
-            return self.per_channel_quant_config
+        config = self._get_submodule_qconfig(node)
 
-        if target in self.quant_ops:
-            return self.quant_config
+        if op in config.use_per_channel_weight_quant_ops:
+            return config.per_channel_quant_config
 
-        print(f"No quant config is implemented for op, {op}")
-
-    def _update_per_block_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_block_weight_quant_ops.update(ops)
-        else:
-            self.use_per_block_weight_quant_ops.difference_update(ops)
+        if op in self.quant_ops:
+            return config.quant_config
 
-    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_channel_weight_quant_ops.update(ops)
-        else:
-            self.use_per_channel_weight_quant_ops.difference_update(ops)
+        print(f"No quant config is implemented for op: {op}")
 
     def add_custom_quant_annotations(
         self, custom_quant_annotations: Sequence[Callable]
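
A hedged sketch of priority 1 above: registering a block size for a node name routes that node to `per_block_quant_config` before any submodule or per-channel lookup. The `QuantDtype.use_16a4w_block` member, the node name `conv2d`, and the block-size tuple are assumptions; any dtype whose `QUANT_CONFIG_DICT` entry provides a per-block config function would work, and the name and tuple must match the real graph and weight layout.

```python
from executorch.backends.qualcomm.quantizer.quantizer import (  # assumed path
    QnnQuantizer,
    QuantDtype,
)

quantizer = QnnQuantizer()
# Assumed: a dtype whose QUANT_CONFIG_DICT entry defines a per-block config.
quantizer.set_default_quant_config(QuantDtype.use_16a4w_block)
# Hypothetical node name and block size; must match the actual weight shape.
quantizer.set_block_size_map({"conv2d": (1, 128, 1, 1)})
# During annotation, _get_quant_config for that node now returns
# default_quant_config.per_block_quant_config with block_size applied.
```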
@@ -212,55 +260,74 @@ def annotate(self, model: GraphModule) -> GraphModule:
     def get_supported_ops(self) -> Set[OpOverload]:
         return self.SUPPORTED_OPS
 
-    def set_quant_config(
-        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
+    def set_default_quant_config(
+        self,
+        quant_dtype: QuantDtype,
+        is_qat=False,
+        is_conv_per_channel=False,
+        is_linear_per_channel=False,
+        act_observer=None,
     ) -> None:
-        self.quant_dtype = quant_dtype
-        self.is_qat = is_qat
-        if (quant_dtype, is_qat) not in quant_config_dict:
-            raise RuntimeError(
-                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
-            )
-
-        quant_config_fuc, per_channel_quant_config_fuc, per_block_quant_config_fuc = (
-            quant_config_dict[(quant_dtype, is_qat)]
-        )
-        self.quant_config = (
-            quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else quant_config_fuc()
+        self.default_quant_config = ModuleQConfig(
+            quant_dtype,
+            is_qat,
+            is_conv_per_channel,
+            is_linear_per_channel,
+            act_observer,
         )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else per_channel_quant_config_fuc()
-        )
-        if per_block_quant_config_fuc is not None:
-            self.per_block_quant_config = (
-                per_block_quant_config_fuc(act_observer=act_observer)
-                if act_observer
-                else per_block_quant_config_fuc()
-            )
 
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map
 
-    def set_per_block_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv2d.default}
-        self._update_per_block_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
-        self._update_per_channel_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_linear_quant(self, enable: bool) -> None:
-        linear_ops = {
-            torch.ops.aten.linear.default,
-        }
-        self._update_per_channel_weight_quant_ops(linear_ops, enable)
+    def set_submodule_qconfig_list(
+        self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]]
+    ) -> None:
+        """
+        Set per-submodule quant configs via callback predicates.
+        If a node matches more than one callback, only the first match applies.
+        """
+        self.submodule_qconfig_list = submodule_qconfig_list
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         return QnnPassManager().transform_for_annotation_pipeline(model)
 
     def validate(self, model: GraphModule) -> None:
         pass
+
+
+def get_submodule_type_predicate(module_type_str):
+    """
+    Return a predicate that matches nodes whose nn_module_stack contains a
+    module type name including module_type_str.
+
+    An example of nn_module_stack:
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for _, type_name in nn_module_stack.values():
+                if module_type_str in type_name:
+                    return True
+        return False
+
+    return predicate
+
+
+def get_submodule_name_predicate(module_name_str):
+    """
+    Return a predicate that matches nodes whose nn_module_stack contains a
+    module name including module_name_str.
+
+    An example of nn_module_stack:
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for name in nn_module_stack.keys():
+                if module_name_str in name:
+                    return True
+        return False
+
+    return predicate
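
Putting the new API together, a hedged end-to-end sketch: the default config covers the whole model while a type predicate overrides every `torch.nn.Linear` submodule. The toy model, dtype choices, and matched type string are illustrative assumptions; the pt2e calls are the standard PyTorch quantization entry points.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Assumed import path, mirroring this file's location in the repo.
from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
    get_submodule_type_predicate,
)

quantizer = QnnQuantizer()
# Model-wide default: 8-bit activations/weights, per-channel conv weights.
quantizer.set_default_quant_config(QuantDtype.use_8a8w, is_conv_per_channel=True)
# Override: nodes from torch.nn.Linear submodules get 16a16w instead.
# The first matching predicate wins, so order entries from most to least specific.
quantizer.set_submodule_qconfig_list(
    [
        (
            get_submodule_type_predicate("torch.nn.modules.linear.Linear"),
            ModuleQConfig(QuantDtype.use_16a16w, is_linear_per_channel=True),
        ),
    ]
)

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 30 * 30, 10),
).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared)
```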