Commit 4101be2
Author: Chun-I Tsai (committed)

Qualcomm AI Engine Direct - Add submodule quant config setting
- Add an API to the QNN quantizer for setting submodule-specific quant configs

1 parent f789df2

File tree: 11 files changed, +227 -85 lines changed (six files expanded below)

backends/qualcomm/_passes/decompose_einsum.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,8 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx.experimental.proxy_tensor import make_fx
 
+from .utils import copy_nn_module_stack
+
 
 class DecomposeEinsum(ExportPass):
     """
@@ -36,6 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                         remap[f"arg1_{i+1}"] = arg
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                         # This is the arg[0] equation string, which is not required anymore after decomposition
                         if "arg0" in decomposed_node.name:
                             continue
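
Note: nodes created by make_fx during decomposition do not inherit the original node's meta["nn_module_stack"], so without this copy the quantizer's new per-submodule lookup (see quantizer.py below) could not attribute the decomposed ops to their source module. The decompose_linalg_vector_norm pass below gets the same fix.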

backends/qualcomm/_passes/decompose_linalg_vector_norm.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,8 @@
 from executorch.exir import to_edge
 from executorch.exir.pass_base import ExportPass, PassResult
 
+from .utils import copy_nn_module_stack
+
 
 class LinalgVectorNorm(torch.nn.Module):
     def __init__(self, exp, dim, keepdim):
@@ -62,6 +64,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     remap = {"x": node.args[0]}
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                         # no need to copy existent 'output'
                         if decomposed_node.op == "output":
                             for user in node.users.copy():

backends/qualcomm/_passes/utils.py

Lines changed: 12 additions & 1 deletion
@@ -6,7 +6,10 @@
 
 import torch
 from executorch.backends.qualcomm.builders.utils import get_parameter
-from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_ENCODING,
+    QCOM_NN_MODULE_STACK,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch._subclasses import FakeTensor
 
@@ -107,6 +110,14 @@ def get_passes_dependency_for_capture_program():
     }
 
 
+def copy_nn_module_stack(src, target):
+    """
+    Copy meta["nn_module_stack"] from the src node to the target node if it exists.
+    """
+    if value := src.meta.get(QCOM_NN_MODULE_STACK):
+        target.meta[QCOM_NN_MODULE_STACK] = value
+
+
 def is_float_tensor(node: torch.fx.Node) -> bool:
     if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
         return False
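
For illustration, a minimal sketch of the helper's copy semantics on plain torch.fx nodes (a hypothetical graph; assumes QCOM_NN_MODULE_STACK is the literal "nn_module_stack" key written by torch.export):

import torch
from torch.fx import Graph

g = Graph()
src = g.placeholder("x")
src.meta["nn_module_stack"] = {
    "L__self___add": ("add", "executorch.backends.qualcomm.tests.models.Add")
}
target = g.call_function(torch.ops.aten.add.Tensor, (src, src))

# Same pattern as copy_nn_module_stack: copy only when src carries a stack.
if value := src.meta.get("nn_module_stack"):
    target.meta["nn_module_stack"] = value

assert target.meta["nn_module_stack"] == src.meta["nn_module_stack"]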

backends/qualcomm/quantizer/quantizer.py

Lines changed: 103 additions & 51 deletions
@@ -3,9 +3,11 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import importlib
+from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
-from typing import Callable, Optional, Sequence, Set
+from typing import Callable, Dict, Optional, Sequence, Set
 
 import torch
 from executorch.backends.qualcomm._passes import (
@@ -66,7 +68,7 @@ class QuantDtype(IntEnum):
     use_8a8w = 3
 
 
-quant_config_dict = {
+QUANT_CONFIG_DICT = {
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
@@ -112,18 +114,60 @@ class QuantDtype(IntEnum):
 }
 
 
+@dataclass
+class ModuleQConfig:
+    quant_dtype: QuantDtype = QuantDtype.use_8a8w
+    is_qat: bool = False
+    is_conv_per_channel: bool = False
+    is_linear_per_channel: bool = False
+    act_observer: Optional[
+        torch.ao.quantization.observer.UniformQuantizationObserverBase
+    ] = None
+
+    def __post_init__(self):
+        if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
+            raise RuntimeError(
+                f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not supported"
+            )
+        quant_config_func, per_channel_quant_config_func = QUANT_CONFIG_DICT[
+            (self.quant_dtype, self.is_qat)
+        ]
+        self.quant_config = (
+            quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else quant_config_func()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_channel_quant_config_func()
+        )
+        self.use_per_channel_weight_quant_ops = set()
+        if self.is_conv_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.conv1d.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv_transpose2d.input,
+                }
+            )
+        if self.is_linear_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.linear.default,
+                }
+            )
+
+
 class QnnQuantizer(Quantizer):
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())
 
     def __init__(self):
         super().__init__()
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()
 
-        self.is_qat = False
-        self.quant_dtype = QuantDtype.use_8a8w
-        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
-        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
-        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
+        self.default_quant_config = ModuleQConfig()
+        self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
 
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
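
A small sketch of constructing the new dataclass directly (import path inferred from this file's location; only enum members visible in this diff are used):

from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QuantDtype,
)

# __post_init__ validates (quant_dtype, is_qat) against QUANT_CONFIG_DICT,
# materializes quant_config / per_channel_quant_config from the factory pair,
# and fills use_per_channel_weight_quant_ops from the two boolean flags.
qconfig = ModuleQConfig(QuantDtype.use_16a16w, is_conv_per_channel=True)
print(qconfig.use_per_channel_weight_quant_ops)
# -> {aten.conv1d.default, aten.conv2d.default, aten.conv_transpose2d.input}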
@@ -133,37 +177,55 @@ def _annotate(self, gm: GraphModule) -> None:
             if node.name in self.discard_nodes:
                 continue
 
-            quant_config = self._get_quant_config(node.target)
+            quant_config = self._get_quant_config(node)
             if quant_config:
                 OP_ANNOTATOR[node.target](node, quant_config)
 
     def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)
 
-    def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
+    def _get_submodule(self, node: torch.fx.Node):
+        """
+        An example of nn_module_stack:
+        {
+            'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+            'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+        }
+        """
+
+        nn_module_stack = node.meta.get("nn_module_stack")
+        if nn_module_stack:
+            module_source_str, module_str = list(nn_module_stack.values())[-1][
+                -1
+            ].rsplit(".", 1)
+            module_source = importlib.import_module(module_source_str)
+            return getattr(module_source, module_str)
+        return None
+
+    def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
-        Priority:
-        1. is one of use_per_channel_weight_quant_ops
-        2. quant config
+        How to pick:
+        1. Choose the submodule-specific config if one is registered.
+        2. Else pick the per-channel config if the op is in use_per_channel_weight_quant_ops.
+        3. Otherwise fall back to the default quant config.
         """
+        op = node.target
         if isinstance(op, str):
             return
 
-        if op in self.use_per_channel_weight_quant_ops:
-            return self.per_channel_quant_config
+        config = self.module_qconfig_dict.get(
+            self._get_submodule(node), self.default_quant_config
+        )
+
+        if op in config.use_per_channel_weight_quant_ops:
+            return config.per_channel_quant_config
 
         if op in self.quant_ops:
-            return self.quant_config
+            return config.quant_config
 
         print(f"No quant config is implemented for op, {op}")
 
-    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_channel_weight_quant_ops.update(ops)
-        else:
-            self.use_per_channel_weight_quant_ops.difference_update(ops)
-
     def add_custom_quant_annotations(
         self, custom_quant_annotations: Sequence[Callable]
     ) -> None:
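
Note that _get_submodule resolves the innermost nn_module_stack entry back to a class via importlib, so module_qconfig_dict is effectively keyed by module classes rather than instances. Roughly, using the docstring's example stack:

import importlib

nn_module_stack = {
    "L__self__": ("", "executorch.backends.qualcomm.tests.models.SubModules"),
    "L__self___add": ("add", "executorch.backends.qualcomm.tests.models.Add"),
}

# The innermost entry is last; its qualified class name is the tuple's last element.
module_source_str, module_str = list(nn_module_stack.values())[-1][-1].rsplit(".", 1)
submodule_cls = getattr(importlib.import_module(module_source_str), module_str)
# -> the Add class, which is what module_qconfig_dict keys on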
@@ -185,39 +247,29 @@ def annotate(self, model: GraphModule) -> GraphModule:
     def get_supported_ops(self) -> Set[OpOverload]:
         return self.SUPPORTED_OPS
 
-    def set_quant_config(
-        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
+    def set_default_quant_config(
+        self,
+        quant_dtype: QuantDtype,
+        is_qat=False,
+        is_conv_per_channel=False,
+        is_linear_per_channel=False,
+        act_observer=None,
     ) -> None:
-        self.quant_dtype = quant_dtype
-        self.is_qat = is_qat
-        if (quant_dtype, is_qat) not in quant_config_dict:
-            raise RuntimeError(
-                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
-            )
-
-        quant_config_fuc, per_channel_quant_config_fuc = quant_config_dict[
-            (quant_dtype, is_qat)
-        ]
-        self.quant_config = (
-            quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else quant_config_fuc()
-        )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else per_channel_quant_config_fuc()
+        self.default_quant_config = ModuleQConfig(
+            quant_dtype,
+            is_qat,
+            is_conv_per_channel,
+            is_linear_per_channel,
+            act_observer,
         )
 
-    def set_per_channel_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
-        self._update_per_channel_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_linear_quant(self, enable: bool) -> None:
-        linear_ops = {
-            torch.ops.aten.linear.default,
-        }
-        self._update_per_channel_weight_quant_ops(linear_ops, enable)
+    def set_submodule_quant_config(
+        self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+    ) -> None:
+        """
+        Set a quant config specific to a submodule.
+        """
+        self.module_qconfig_dict[submodule] = module_qconfig
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = ReduceDynamicRange()(model).graph_module
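
Putting the new API together: the old set_quant_config / set_per_channel_* toggles collapse into one default ModuleQConfig plus optional per-submodule overrides. A minimal usage sketch (Add comes from the test models; import paths inferred from the file locations):

from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
)
from executorch.backends.qualcomm.tests.models import Add

quantizer = QnnQuantizer()
# Graph-wide default: 8-bit activations/weights, per-channel conv and linear.
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)
# Override: nodes whose innermost nn_module_stack entry resolves to Add use 16a16w.
quantizer.set_submodule_quant_config(Add, ModuleQConfig(QuantDtype.use_16a16w))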

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 0 deletions
@@ -1398,6 +1398,18 @@ def forward(self, x):
         return 10 - x
 
 
+class SimpleSubModules(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.add = Add()
+        self.sub = Sub()
+
+    def forward(self, a, b, c, d):
+        lhs = self.add(a, b)
+        rhs = self.sub(c, d)
+        return torch.mul(lhs, rhs)
+
+
 class SumIntList(torch.nn.Module):
     def __init__(self):
         super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 26 additions & 8 deletions
@@ -15,6 +15,7 @@
 import torch
 from executorch.backends.qualcomm.tests.utils import (
     generate_context_binary,
+    ModuleQConfig,
     QnnPartitioner,
     QnnQuantizer,
     QuantDtype,
@@ -1219,8 +1220,8 @@ def test_qnn_backend_element_wise_add(self):
         for module in comb[QCOM_MODULE]:
             for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                 with self.subTest(i=index):
-                    module = self.get_qdq_module(module, sample_input)
-                    self.lower_module_and_test_output(module, sample_input)
+                    gm = self.get_qdq_module(module, sample_input)
+                    self.lower_module_and_test_output(gm, sample_input)
                     index += 1
 
     def test_qnn_backend_element_wise_ceil(self):
@@ -1251,8 +1252,8 @@ def test_qnn_backend_element_wise_div(self):
         for module in comb[QCOM_MODULE]:
             for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                 with self.subTest(i=index):
-                    module = self.get_qdq_module(module, sample_input)
-                    self.lower_module_and_test_output(module, sample_input)
+                    gm = self.get_qdq_module(module, sample_input)
+                    self.lower_module_and_test_output(gm, sample_input)
                     index += 1
 
     def test_qnn_backend_element_wise_mul(self):
@@ -1279,8 +1280,8 @@ def test_qnn_backend_element_wise_mul(self):
         for module in comb[QCOM_MODULE]:
             for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                 with self.subTest(i=index):
-                    module = self.get_qdq_module(module, sample_input)
-                    self.lower_module_and_test_output(module, sample_input)
+                    gm = self.get_qdq_module(module, sample_input)
+                    self.lower_module_and_test_output(gm, sample_input)
                     index += 1
 
     def test_qnn_backend_element_wise_or(self):
@@ -1339,8 +1340,8 @@ def test_qnn_backend_element_wise_sub(self):
         for module in comb[QCOM_MODULE]:
             for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                 with self.subTest(i=index):
-                    module = self.get_qdq_module(module, sample_input)
-                    self.lower_module_and_test_output(module, sample_input)
+                    gm = self.get_qdq_module(module, sample_input)
+                    self.lower_module_and_test_output(gm, sample_input)
                     index += 1
 
     def test_qnn_backend_embedding(self):
@@ -1985,6 +1986,23 @@ def test_qnn_backend_simple_model(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_submodules(self):
+        module = SimpleSubModules()  # noqa: F405
+        sample_input = (
+            torch.rand(1, 3, 8, 8),
+            torch.rand(1, 3, 8, 8),
+            torch.rand(1, 3, 8, 8),
+            torch.rand(1, 3, 8, 8),
+        )
+
+        submodule_quant_config = {
+            Add: ModuleQConfig(QuantDtype.use_16a16w)  # noqa: F405
+        }
+        module = self.get_qdq_module(
+            module, sample_input, submodule_quant_config=submodule_quant_config
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_topk_and_index(self):
         module = TopKandIndex()  # noqa: F405
         sample_input = (torch.randn(3, 10),)
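
The new test drives the override end to end: get_qdq_module accepts a submodule_quant_config mapping (the tests/utils.py change wiring it through is among the five files not expanded above), and presumably registers each class-to-ModuleQConfig pair via set_submodule_quant_config. The expected effect is that the Add branch is annotated with 16a16w while Sub and the final mul keep the 8a8w default.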
