 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import importlib
 from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
-from typing import Callable, Dict, Optional, Sequence, Set, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
 
 import torch
 from executorch.backends.qualcomm._passes import (
@@ -153,9 +152,11 @@ def __post_init__(self):
             raise RuntimeError(
                 f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not support"
             )
-        quant_config_func, per_channel_quant_config_func, per_block_quant_config_func = QUANT_CONFIG_DICT[
-            (self.quant_dtype, self.is_qat)
-        ]
+        (
+            quant_config_func,
+            per_channel_quant_config_func,
+            per_block_quant_config_func,
+        ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
         self.quant_config = (
             quant_config_func(act_observer=self.act_observer)
             if self.act_observer
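Note: the unpacking above implies that each QUANT_CONFIG_DICT entry maps a (quant_dtype, is_qat) key to a 3-tuple of config factories. A sketch of that shape, purely for orientation (the dict's contents are not shown in this diff, only inferred from the destructuring):

    # Assumed shape of QUANT_CONFIG_DICT implied by the destructuring above:
    # {
    #     (quant_dtype, is_qat): (
    #         quant_config_func,              # per-tensor config factory
    #         per_channel_quant_config_func,  # per-channel config factory
    #         per_block_quant_config_func,    # per-block config factory
    #     ),
    #     ...
    # }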
@@ -197,7 +198,9 @@ def __init__(self):
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()
 
         self.default_quant_config = ModuleQConfig()
-        self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
+        self.submodule_qconfig_list: List[
+            Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig]
+        ] = []
         self.block_size_map = {}
 
         self.custom_quant_annotations: Sequence[Callable] = []
@@ -216,44 +219,30 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)
 
-    def _get_submodule(self, node: torch.fx.Node):
-        """
-        An example of nn_module_stack
-        {
-            'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
-            'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
-        }
-        """
-
-        nn_module_stack = node.meta.get("nn_module_stack")
-        if nn_module_stack:
-            module_source_str, module_str = list(nn_module_stack.values())[-1][
-                -1
-            ].rsplit(".", 1)
-            module_source = importlib.import_module(module_source_str)
-            return getattr(module_source, module_str)
-        return None
+    def _get_submodule_qconfig(self, node: torch.fx.Node):
+        for func, qconfig in self.submodule_qconfig_list:
+            if func(node):
+                return qconfig
+        return self.default_quant_config
 
     def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
         How to pick:
-        1. is one of use_per_block_weight_quant_ops
-        2. Choose specific submodule config if given.
+        1. is one of per_block_quant_config
+        2. Pick specific submodule config if given.
         3. Pick one if op belongs to use_per_channel_weight_quant_ops
-        4. If not 2, pick normal quant config
+        4. If not 3, pick normal quant config
         """
         op = node.target
         if isinstance(op, str):
             return
 
-        if block_size := self.block_size_map.get(op.name):
+        if block_size := self.block_size_map.get(node.name):
             config = self.default_quant_config.per_block_quant_config
             config.block_size = block_size
             return config
 
-        config = self.module_qconfig_dict.get(
-            self._get_submodule(node), self.default_quant_config
-        )
+        config = self._get_submodule_qconfig(node)
 
         if op in config.use_per_channel_weight_quant_ops:
             return config.per_channel_quant_config
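Note on the selection order documented in _get_quant_config above: a per-block entry (now keyed by node name rather than op name) wins over any submodule callback, which in turn wins over the per-channel and default configs. A minimal sketch of the per-block branch, assuming the quantizer class defined in this file; the class name, node name, and block shape below are illustrative only:

    quantizer = QnnQuantizer()  # class name assumed, not shown in this hunk
    # Nodes listed here receive default_quant_config.per_block_quant_config with this block_size:
    quantizer.set_block_size_map({"linear_default": (1, 64)})  # hypothetical node name / block size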
@@ -303,13 +292,14 @@ def set_default_quant_config(
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map
 
-    def set_submodule_quant_config(
-        self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+    def set_submodule_qconfig_list(
+        self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]]
     ) -> None:
         """
-        Set the quant config specific for a submodule
+        Set specific quant config from a callback function.
+        If a node fits more than one callback, only apply the first one.
         """
-        self.module_qconfig_dict[submodule] = module_qconfig
+        self.submodule_qconfig_list = submodule_qconfig_list
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = ReduceDynamicRange()(model).graph_module
@@ -326,3 +316,41 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
 
     def validate(self, model: GraphModule) -> None:
         pass
+
+
+def get_submodule_type_predicate(module_type_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for _, type_name in nn_module_stack.values():
+                if module_type_str in type_name:
+                    return True
+        return False
+
+    return predicate
+
+
+def get_submodule_name_predicate(module_name_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for name in nn_module_stack.keys():
+                if module_name_str in name:
+                    return True
+        return False
+
+    return predicate
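A short usage sketch tying the predicate helpers above to the new set_submodule_qconfig_list API; the import path, quantizer class name, and the example match strings/configs are assumptions for illustration, not part of this diff:

    from executorch.backends.qualcomm.quantizer.quantizer import (  # module path assumed
        ModuleQConfig,
        QnnQuantizer,  # assumed name of the quantizer class defined in this file
        get_submodule_name_predicate,
        get_submodule_type_predicate,
    )

    quantizer = QnnQuantizer()
    # The first predicate returning True for a node decides its ModuleQConfig;
    # nodes matching no predicate fall back to the default config.
    quantizer.set_submodule_qconfig_list(
        [
            # Match by module type recorded in nn_module_stack, e.g. ...tests.models.Add:
            (get_submodule_type_predicate("Add"), ModuleQConfig()),
            # Match by submodule name/path, e.g. a child registered as "add":
            (get_submodule_name_predicate("add"), ModuleQConfig()),
        ]
    )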