
Commit 30dc9f2

chunit-quic (Chun-I Tsai) authored and committed
Qualcomm AI Engine Direct - Add submodule quant config setting (#9355)
- Add API to qnn quantizer for setting submodule quant config
- Refine QnnQuantizer setting functions

Co-authored-by: Chun-I Tsai <chunit@qti.qualcomm.com>
1 parent c2143db commit 30dc9f2

File tree

10 files changed: +253 -111 lines changed

backends/qualcomm/_passes/decompose_einsum.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,8 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx.experimental.proxy_tensor import make_fx
 
+from .utils import copy_nn_module_stack
+
 
 class DecomposeEinsum(ExportPass):
     """
@@ -36,6 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                         remap[f"arg1_{i+1}"] = arg
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                         # This is the arg[0] equation string, which is not required anymore after decomposition
                         if "arg0" in decomposed_node.name:
                             continue
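
Context for the change above: nodes created by make_fx are fresh FX nodes and carry none of the source node's metadata, so without this copy the submodule predicates added in quantizer.py below could never match decomposed ops. A minimal standalone sketch of that gap, not part of the commit; einsum_fn is a hypothetical stand-in for the traced callable:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def einsum_fn(x, y):
    return torch.einsum("ij,jk->ik", x, y)

# Trace the function the same way DecomposeEinsum does.
decomposed = make_fx(einsum_fn, tracing_mode="fake")(
    torch.randn(2, 3), torch.randn(3, 4)
)

# Every freshly traced node lacks nn_module_stack, so a submodule-based
# predicate would never match decomposed nodes without copy_nn_module_stack.
print([n.meta.get("nn_module_stack") for n in decomposed.graph.nodes])  # all None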

backends/qualcomm/_passes/decompose_linalg_vector_norm.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,8 @@
 from executorch.exir import to_edge
 from executorch.exir.pass_base import ExportPass, PassResult
 
+from .utils import copy_nn_module_stack
+
 
 class LinalgVectorNorm(torch.nn.Module):
     def __init__(self, exp, dim, keepdim):
@@ -62,6 +64,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     remap = {"x": node.args[0]}
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                         # no need to copy existent 'output'
                         if decomposed_node.op == "output":
                             for user in node.users.copy():

backends/qualcomm/_passes/utils.py

Lines changed: 8 additions & 0 deletions
@@ -121,6 +121,14 @@ def get_passes_dependency_for_capture_program():
     }
 
 
+def copy_nn_module_stack(src, target):
+    """
+    Copy meta["nn_module_stack"] from src node to target node if existing.
+    """
+    if value := src.meta.get("nn_module_stack"):
+        target.meta["nn_module_stack"] = value
+
+
 def is_float_tensor(node: torch.fx.Node) -> bool:
     if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
         return False
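
The helper is a guarded single-key copy. A minimal sketch of its behavior, using SimpleNamespace objects as stand-ins for FX nodes (which expose a meta dict); the sample stack entry is taken from the predicate docstrings below:

from types import SimpleNamespace

src = SimpleNamespace(meta={"nn_module_stack": {
    "L__self___add": ("add", "executorch.backends.qualcomm.tests.models.Add"),
}})
target = SimpleNamespace(meta={})

# Same logic as copy_nn_module_stack: copy only when the key exists.
if value := src.meta.get("nn_module_stack"):
    target.meta["nn_module_stack"] = value

assert target.meta["nn_module_stack"] is src.meta["nn_module_stack"]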

backends/qualcomm/quantizer/quantizer.py

Lines changed: 143 additions & 76 deletions
@@ -3,9 +3,10 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
-from typing import Callable, Dict, Optional, Sequence, Set, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
 
 import torch
 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
@@ -58,7 +59,7 @@ class QuantDtype(IntEnum):
     use_8a8w = 4
 
 
-quant_config_dict = {
+QUANT_CONFIG_DICT = {
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
@@ -123,21 +124,71 @@ class QuantDtype(IntEnum):
 }
 
 
+@dataclass
+class ModuleQConfig:
+    quant_dtype: QuantDtype = QuantDtype.use_8a8w
+    is_qat: bool = False
+    is_conv_per_channel: bool = False
+    is_linear_per_channel: bool = False
+    act_observer: Optional[
+        torch.ao.quantization.observer.UniformQuantizationObserverBase
+    ] = None
+
+    def __post_init__(self):
+        if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
+            raise RuntimeError(
+                f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not support"
+            )
+        (
+            quant_config_func,
+            per_channel_quant_config_func,
+            per_block_quant_config_func,
+        ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
+        self.quant_config = (
+            quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else quant_config_func()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_channel_quant_config_func()
+        )
+        self.use_per_channel_weight_quant_ops = set()
+        if self.is_conv_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.conv1d.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv_transpose2d.input,
+                }
+            )
+        if self.is_linear_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.linear.default,
+                }
+            )
+        if per_block_quant_config_func:
+            self.per_block_quant_config = (
+                per_block_quant_config_func(act_observer=self.act_observer)
+                if self.act_observer
+                else per_block_quant_config_func()
+            )
+
+
 class QnnQuantizer(Quantizer):
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())
 
     def __init__(self):
         super().__init__()
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()
 
-        self.is_qat = False
-        self.quant_dtype = QuantDtype.use_8a8w
-        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
-        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
-        self.per_block_quant_config = get_ptq_per_block_quant_config()
+        self.default_quant_config = ModuleQConfig()
+        self.submodule_qconfig_list: List[
+            Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig]
+        ] = []
         self.block_size_map = {}
-        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
-        self.use_per_block_weight_quant_ops: Set[OpOverload] = set()
 
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
@@ -155,41 +206,38 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)
 
-    def _get_quant_config(self, op: torch.fx.Node) -> Optional[QuantizationConfig]:
+    def _get_submodule_qconfig(self, node: torch.fx.Node):
+        for func, qconfig in self.submodule_qconfig_list:
+            if func(node):
+                return qconfig
+        return self.default_quant_config
+
+    def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
-        Priority:
-        1. is one of use_per_block_weight_quant_ops
-        2. is one of use_per_channel_weight_quant_ops
-        3. quant config
+        How to pick:
+        1. is one of per_block_quant_config
+        2. Pick specific submodule config if given.
+        3. Pick one if op belongs to use_per_channel_weight_quant_ops
+        4. If not 3, pick normal quant config
         """
-        target = op.target
-        if isinstance(target, str):
+        op = node.target
+        if isinstance(op, str):
            return
 
-        if target in self.use_per_block_weight_quant_ops:
-            if block_size := self.block_size_map.get(op.name):
-                self.per_block_quant_config.block_size = block_size
-                return self.per_block_quant_config
+        if block_size := self.block_size_map.get(node.name):
+            config = self.default_quant_config.per_block_quant_config
+            config.block_size = block_size
+            return config
 
-        if target in self.use_per_channel_weight_quant_ops:
-            return self.per_channel_quant_config
+        config = self._get_submodule_qconfig(node)
 
-        if target in self.quant_ops:
-            return self.quant_config
+        if op in config.use_per_channel_weight_quant_ops:
+            return config.per_channel_quant_config
 
-        print(f"No quant config is implemented for op, {op}")
-
-    def _update_per_block_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_block_weight_quant_ops.update(ops)
-        else:
-            self.use_per_block_weight_quant_ops.difference_update(ops)
+        if op in self.quant_ops:
+            return config.quant_config
 
-    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_channel_weight_quant_ops.update(ops)
-        else:
-            self.use_per_channel_weight_quant_ops.difference_update(ops)
+        print(f"No quant config is implemented for op, {op}")
 
     def add_custom_quant_annotations(
         self, custom_quant_annotations: Sequence[Callable]
@@ -212,55 +260,74 @@ def annotate(self, model: GraphModule) -> GraphModule:
     def get_supported_ops(self) -> Set[OpOverload]:
         return self.SUPPORTED_OPS
 
-    def set_quant_config(
-        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
+    def set_default_quant_config(
+        self,
+        quant_dtype: QuantDtype,
+        is_qat=False,
+        is_conv_per_channel=False,
+        is_linear_per_channel=False,
+        act_observer=None,
     ) -> None:
-        self.quant_dtype = quant_dtype
-        self.is_qat = is_qat
-        if (quant_dtype, is_qat) not in quant_config_dict:
-            raise RuntimeError(
-                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
-            )
-
-        quant_config_fuc, per_channel_quant_config_fuc, per_block_quant_config_fuc = (
-            quant_config_dict[(quant_dtype, is_qat)]
-        )
-        self.quant_config = (
-            quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else quant_config_fuc()
+        self.default_quant_config = ModuleQConfig(
+            quant_dtype,
+            is_qat,
+            is_conv_per_channel,
+            is_linear_per_channel,
+            act_observer,
         )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else per_channel_quant_config_fuc()
-        )
-        if per_block_quant_config_fuc is not None:
-            self.per_block_quant_config = (
-                per_block_quant_config_fuc(act_observer=act_observer)
-                if act_observer
-                else per_block_quant_config_fuc()
-            )
 
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map
 
-    def set_per_block_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv2d.default}
-        self._update_per_block_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
-        self._update_per_channel_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_linear_quant(self, enable: bool) -> None:
-        linear_ops = {
-            torch.ops.aten.linear.default,
-        }
-        self._update_per_channel_weight_quant_ops(linear_ops, enable)
+    def set_submodule_qconfig_list(
+        self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]]
+    ) -> None:
+        """
+        Set specific quant config from a callback function.
+        If a node fits more than one callback, only apply the first one.
+        """
+        self.submodule_qconfig_list = submodule_qconfig_list
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         return QnnPassManager().transform_for_annotation_pipeline(model)
 
     def validate(self, model: GraphModule) -> None:
         pass
+
+
+def get_submodule_type_predicate(module_type_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for _, type_name in nn_module_stack.values():
+                if module_type_str in type_name:
+                    return True
+        return False
+
+    return predicate
+
+
+def get_submodule_name_predicate(module_name_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for name in nn_module_stack.keys():
+                if module_name_str in name:
+                    return True
+        return False
+
+    return predicate
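
Putting the new surface together, a hedged usage sketch: every name below comes from this diff, while the calibration model and the downstream prepare_pt2e/convert_pt2e flow are the usual pt2e path and only referenced in a comment.

import torch
from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
    get_submodule_type_predicate,
)

quantizer = QnnQuantizer()
# Graph-wide default: 8-bit activations / 8-bit weights (PTQ),
# with per-channel weight quant for conv and linear ops.
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)
# Submodule override: any node whose nn_module_stack mentions an "Add"
# module type gets a 16a16w config instead. Predicates are checked in
# order and the first match wins.
quantizer.set_submodule_qconfig_list(
    [
        (get_submodule_type_predicate("Add"), ModuleQConfig(QuantDtype.use_16a16w)),
    ]
)
# The quantizer is then handed to the usual prepare_pt2e/convert_pt2e flow.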

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 0 deletions
@@ -1460,6 +1460,18 @@ def forward(self, x):
         return 10 - x
 
 
+class SimpleSubModules(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.add = Add()
+        self.sub = Sub()
+
+    def forward(self, a, b, c, d):
+        lhs = self.add(a, b)
+        rhs = self.sub(c, d)
+        return torch.mul(lhs, rhs)
+
+
 class SumIntList(torch.nn.Module):
     def __init__(self):
         super().__init__()
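
A hedged sketch of how this model surfaces the nn_module_stack entries that the predicates above match on; the shapes are illustrative, and Add/Sub are the existing elementwise test modules in this file.

import torch
from executorch.backends.qualcomm.tests.models import SimpleSubModules

module = SimpleSubModules()
inputs = tuple(torch.randn(1, 8) for _ in range(4))
ep = torch.export.export(module, inputs)

for node in ep.graph.nodes:
    if stack := node.meta.get("nn_module_stack"):
        # e.g. keys like 'L__self___add' / 'L__self___sub', with values
        # naming the owning module type, as in the predicate docstrings.
        print(node.name, list(stack.keys()))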
