#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+ import importlib
+ from dataclasses import dataclass
from enum import IntEnum, unique
from functools import partial
- from typing import Callable, Optional, Sequence, Set
+ from typing import Callable, Dict, Optional, Sequence, Set

import torch
from executorch.backends.qualcomm._passes import (
@@ -66,7 +68,7 @@ class QuantDtype(IntEnum):
    use_8a8w = 3


- quant_config_dict = {
+ QUANT_CONFIG_DICT = {
    # PTQ
    (QuantDtype.use_16a16w, False): (
        get_16a16w_qnn_ptq_config,
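Note: the rename to `QUANT_CONFIG_DICT` marks this as a module-level constant. Each `(QuantDtype, is_qat)` key maps to a pair of factory callables, `(quant_config_func, per_channel_quant_config_func)`, both assumed to accept an optional `act_observer` keyword (which `ModuleQConfig.__post_init__` below relies on). A minimal lookup sketch, not part of the diff:

```python
# Hedged sketch: resolve the PTQ factories for 16-bit act / 16-bit weight.
quant_config_func, per_channel_quant_config_func = QUANT_CONFIG_DICT[
    (QuantDtype.use_16a16w, False)
]
quant_config = quant_config_func()  # built with the default activation observer
per_channel_quant_config = per_channel_quant_config_func()
```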
@@ -112,18 +114,60 @@ class QuantDtype(IntEnum):
}


+ @dataclass
+ class ModuleQConfig:
+     quant_dtype: QuantDtype = QuantDtype.use_8a8w
+     is_qat: bool = False
+     is_conv_per_channel: bool = False
+     is_linear_per_channel: bool = False
+     act_observer: Optional[
+         torch.ao.quantization.observer.UniformQuantizationObserverBase
+     ] = None
+
+     def __post_init__(self):
+         if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
+             raise RuntimeError(
+                 f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not supported"
+             )
+         quant_config_func, per_channel_quant_config_func = QUANT_CONFIG_DICT[
+             (self.quant_dtype, self.is_qat)
+         ]
+         self.quant_config = (
+             quant_config_func(act_observer=self.act_observer)
+             if self.act_observer
+             else quant_config_func()
+         )
+         self.per_channel_quant_config = (
+             per_channel_quant_config_func(act_observer=self.act_observer)
+             if self.act_observer
+             else per_channel_quant_config_func()
+         )
+         self.use_per_channel_weight_quant_ops = set()
+         if self.is_conv_per_channel:
+             self.use_per_channel_weight_quant_ops.update(
+                 {
+                     torch.ops.aten.conv1d.default,
+                     torch.ops.aten.conv2d.default,
+                     torch.ops.aten.conv_transpose2d.input,
+                 }
+             )
+         if self.is_linear_per_channel:
+             self.use_per_channel_weight_quant_ops.update(
+                 {
+                     torch.ops.aten.linear.default,
+                 }
+             )
+
+
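For reference, a minimal construction sketch (not part of the diff): `ModuleQConfig` validates the `(quant_dtype, is_qat)` pair against `QUANT_CONFIG_DICT` and derives the per-channel op set in `__post_init__`, so a misconfigured pair fails at construction time rather than during annotation.

```python
# Hedged example: a per-submodule config with per-channel conv weights.
qconfig = ModuleQConfig(
    quant_dtype=QuantDtype.use_8a8w,
    is_conv_per_channel=True,
)
# Derived fields populated by __post_init__:
assert torch.ops.aten.conv2d.default in qconfig.use_per_channel_weight_quant_ops
assert torch.ops.aten.linear.default not in qconfig.use_per_channel_weight_quant_ops
```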
class QnnQuantizer(Quantizer):
    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

    def __init__(self):
        super().__init__()
        self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

-         self.is_qat = False
-         self.quant_dtype = QuantDtype.use_8a8w
-         self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
-         self.per_channel_quant_config = get_ptq_per_channel_quant_config()
-         self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
+         self.default_quant_config = ModuleQConfig()
+         self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}

        self.custom_quant_annotations: Sequence[Callable] = []
        self.discard_nodes: Set[str] = set()
@@ -133,37 +177,55 @@ def _annotate(self, gm: GraphModule) -> None:
            if node.name in self.discard_nodes:
                continue

-             quant_config = self._get_quant_config(node.target)
+             quant_config = self._get_quant_config(node)
            if quant_config:
                OP_ANNOTATOR[node.target](node, quant_config)

    def _annotate_custom_annotation(self, gm: GraphModule) -> None:
        for annotation_func in self.custom_quant_annotations:
            annotation_func(gm)

-     def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
+     def _get_submodule(self, node: torch.fx.Node):
+         """
+         An example of nn_module_stack:
+         {
+             'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+             'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+         }
+         """
+
+         nn_module_stack = node.meta.get("nn_module_stack")
+         if nn_module_stack:
+             module_source_str, module_str = list(nn_module_stack.values())[-1][
+                 -1
+             ].rsplit(".", 1)
+             module_source = importlib.import_module(module_source_str)
+             return getattr(module_source, module_str)
+         return None
+
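The last `nn_module_stack` entry corresponds to the innermost submodule: its dotted path is split into module path and class name, then resolved back to the class object, which serves as the lookup key into `module_qconfig_dict`. A standalone sketch of that resolution (not part of the diff), using the docstring's example entry:

```python
# Hedged sketch of the resolution step performed by _get_submodule.
import importlib

entry = ("add", "executorch.backends.qualcomm.tests.models.Add")
module_source_str, module_str = entry[-1].rsplit(".", 1)
# -> ("executorch.backends.qualcomm.tests.models", "Add")
submodule_cls = getattr(importlib.import_module(module_source_str), module_str)
```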
+     def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
        """
-         Priority:
-         1. is one of use_per_channel_weight_quant_ops
-         2. quant config
+         How to pick:
+         1. Choose the submodule-specific config if one is given.
+         2. Use the per-channel config if the op belongs to use_per_channel_weight_quant_ops.
+         3. Otherwise, use the normal quant config.
        """
+         op = node.target
        if isinstance(op, str):
            return

-         if op in self.use_per_channel_weight_quant_ops:
-             return self.per_channel_quant_config
+         config = self.module_qconfig_dict.get(
+             self._get_submodule(node), self.default_quant_config
+         )
+
+         if op in config.use_per_channel_weight_quant_ops:
+             return config.per_channel_quant_config

        if op in self.quant_ops:
-             return self.quant_config
+             return config.quant_config

        print(f"No quant config is implemented for op, {op}")

-     def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-         if enable:
-             self.use_per_channel_weight_quant_ops.update(ops)
-         else:
-             self.use_per_channel_weight_quant_ops.difference_update(ops)
-
    def add_custom_quant_annotations(
        self, custom_quant_annotations: Sequence[Callable]
    ) -> None:
@@ -185,39 +247,29 @@ def annotate(self, model: GraphModule) -> GraphModule:
    def get_supported_ops(self) -> Set[OpOverload]:
        return self.SUPPORTED_OPS

-     def set_quant_config(
-         self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
+     def set_default_quant_config(
+         self,
+         quant_dtype: QuantDtype,
+         is_qat=False,
+         is_conv_per_channel=False,
+         is_linear_per_channel=False,
+         act_observer=None,
    ) -> None:
-         self.quant_dtype = quant_dtype
-         self.is_qat = is_qat
-         if (quant_dtype, is_qat) not in quant_config_dict:
-             raise RuntimeError(
-                 f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
-             )
-
-         quant_config_fuc, per_channel_quant_config_fuc = quant_config_dict[
-             (quant_dtype, is_qat)
-         ]
-         self.quant_config = (
-             quant_config_fuc(act_observer=act_observer)
-             if act_observer
-             else quant_config_fuc()
-         )
-         self.per_channel_quant_config = (
-             per_channel_quant_config_fuc(act_observer=act_observer)
-             if act_observer
-             else per_channel_quant_config_fuc()
+         self.default_quant_config = ModuleQConfig(
+             quant_dtype,
+             is_qat,
+             is_conv_per_channel,
+             is_linear_per_channel,
+             act_observer,
        )

-     def set_per_channel_conv_quant(self, enable: bool) -> None:
-         conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
-         self._update_per_channel_weight_quant_ops(conv_ops, enable)
-
-     def set_per_channel_linear_quant(self, enable: bool) -> None:
-         linear_ops = {
-             torch.ops.aten.linear.default,
-         }
-         self._update_per_channel_weight_quant_ops(linear_ops, enable)
+     def set_submodule_quant_config(
+         self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+     ) -> None:
+         """
+         Set the quant config specific to a submodule.
+         """
+         self.module_qconfig_dict[submodule] = module_qconfig
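Taken together, callers now configure a default plus per-submodule overrides through these two entry points. A hedged end-to-end usage sketch (not from the diff); `Add` reuses the test-model class named in the `_get_submodule` docstring above:

```python
# Hedged usage sketch of the new API surface introduced in this diff.
from executorch.backends.qualcomm.tests.models import Add

quantizer = QnnQuantizer()
# Default config applied to any node without a submodule-specific entry.
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)
# Override: nodes traced from `Add` submodules use 16a16w PTQ instead.
quantizer.set_submodule_quant_config(
    Add, ModuleQConfig(quant_dtype=QuantDtype.use_16a16w)
)
```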

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        model = ReduceDynamicRange()(model).graph_module