Commit 89673b6

Author: Chun-I Tsai
Fix based on comments
- Change to string-based way to set up qconfig for submodule
1 parent: 0928239

6 files changed: 89 additions, 59 deletions


backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 3 deletions
@@ -11,7 +11,6 @@
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_DTYPE,
     QCOM_ENCODING,
-    QCOM_NN_MODULE_STACK,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch._subclasses import FakeTensor
@@ -130,8 +129,8 @@ def copy_nn_module_stack(src, target):
     """
     Copy meta["nn_module_stack"] from src node to target node if existing.
     """
-    if value := src.meta.get(QCOM_NN_MODULE_STACK):
-        target.meta[QCOM_NN_MODULE_STACK] = value
+    if value := src.meta.get("nn_module_stack"):
+        target.meta["nn_module_stack"] = value


 def is_float_tensor(node: torch.fx.Node) -> bool:
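
Note: with the QCOM_NN_MODULE_STACK constant removed, the helper reads the standard "nn_module_stack" key directly. A minimal sketch of the helper's behavior; the fx graph construction below is illustrative, not part of this commit:

import torch

def copy_nn_module_stack(src, target):
    # Same body as the patched helper: propagate provenance metadata if present.
    if value := src.meta.get("nn_module_stack"):
        target.meta["nn_module_stack"] = value

# Build two bare fx nodes just to demonstrate the copy.
g = torch.fx.Graph()
src = g.placeholder("x")
dst = g.call_function(torch.add, (src, src))
src.meta["nn_module_stack"] = {
    "L__self___add": ("add", "executorch.backends.qualcomm.tests.models.Add")
}
copy_nn_module_stack(src, dst)
assert dst.meta["nn_module_stack"] == src.meta["nn_module_stack"]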

backends/qualcomm/quantizer/quantizer.py

Lines changed: 62 additions & 34 deletions
@@ -3,11 +3,10 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import importlib
 from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
-from typing import Callable, Dict, Optional, Sequence, Set, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

 import torch
 from executorch.backends.qualcomm._passes import (
@@ -153,9 +152,11 @@ def __post_init__(self):
             raise RuntimeError(
                 f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not support"
             )
-        quant_config_func, per_channel_quant_config_func, per_block_quant_config_func = QUANT_CONFIG_DICT[
-            (self.quant_dtype, self.is_qat)
-        ]
+        (
+            quant_config_func,
+            per_channel_quant_config_func,
+            per_block_quant_config_func,
+        ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
         self.quant_config = (
             quant_config_func(act_observer=self.act_observer)
             if self.act_observer
@@ -197,7 +198,9 @@ def __init__(self):
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

         self.default_quant_config = ModuleQConfig()
-        self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
+        self.submodule_qconfig_list: List[
+            Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig]
+        ] = []
         self.block_size_map = {}

         self.custom_quant_annotations: Sequence[Callable] = []
@@ -216,44 +219,30 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)

-    def _get_submodule(self, node: torch.fx.Node):
-        """
-        An example of nn_module_stack
-        {
-            'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
-            'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
-        }
-        """
-
-        nn_module_stack = node.meta.get("nn_module_stack")
-        if nn_module_stack:
-            module_source_str, module_str = list(nn_module_stack.values())[-1][
-                -1
-            ].rsplit(".", 1)
-            module_source = importlib.import_module(module_source_str)
-            return getattr(module_source, module_str)
-        return None
+    def _get_submodule_qconfig(self, node: torch.fx.Node):
+        for func, qconfig in self.submodule_qconfig_list:
+            if func(node):
+                return qconfig
+        return self.default_quant_config

     def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
         How to pick:
-        1. is one of use_per_block_weight_quant_ops
-        2. Choose specific submodule config if given.
+        1. is one of per_block_quant_config
+        2. Pick specific submodule config if given.
         3. Pick one if op belongs to use_per_channel_weight_quant_ops
-        4. If not 2, pick normal quant config
+        4. If not 3, pick normal quant config
         """
         op = node.target
         if isinstance(op, str):
             return

-        if block_size := self.block_size_map.get(op.name):
+        if block_size := self.block_size_map.get(node.name):
            config = self.default_quant_config.per_block_quant_config
            config.block_size = block_size
            return config

-        config = self.module_qconfig_dict.get(
-            self._get_submodule(node), self.default_quant_config
-        )
+        config = self._get_submodule_qconfig(node)

         if op in config.use_per_channel_weight_quant_ops:
             return config.per_channel_quant_config
@@ -303,13 +292,14 @@ def set_default_quant_config(
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map

-    def set_submodule_quant_config(
-        self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+    def set_submodule_qconfig_list(
+        self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]]
     ) -> None:
         """
-        Set the quant config specific for a submodule
+        Set specific quant config from a callback function.
+        If a node fits more than one callback, only apply the first one.
         """
-        self.module_qconfig_dict[submodule] = module_qconfig
+        self.submodule_qconfig_list = submodule_qconfig_list

     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = ReduceDynamicRange()(model).graph_module
@@ -326,3 +316,41 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:

     def validate(self, model: GraphModule) -> None:
         pass
+
+
+def get_submodule_type_predicate(module_type_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for _, type_name in nn_module_stack.values():
+                if module_type_str in type_name:
+                    return True
+        return False
+
+    return predicate
+
+
+def get_submodule_name_predicate(module_name_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for name in nn_module_stack.keys():
+                if module_name_str in name:
+                    return True
+        return False
+
+    return predicate
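
The new API replaces module-class keys with (predicate, ModuleQConfig) pairs that are tried in order; the first matching entry wins, per set_submodule_qconfig_list's docstring. A usage sketch: the "Add" type mirrors the test below, while the "conv1" instance name is hypothetical:

from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
    get_submodule_name_predicate,
    get_submodule_type_predicate,
)

quantizer = QnnQuantizer()
quantizer.set_submodule_qconfig_list(
    [
        # Match by instance name in nn_module_stack keys ("conv1" is hypothetical).
        (get_submodule_name_predicate("conv1"), ModuleQConfig(QuantDtype.use_16a16w)),
        # Match by type name in nn_module_stack values; only reached when the
        # first predicate did not already claim the node.
        (get_submodule_type_predicate("Add"), ModuleQConfig(QuantDtype.use_16a16w)),
    ]
)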

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 13 additions & 4 deletions
@@ -2118,11 +2118,20 @@ def test_qnn_backend_submodules(self):
             torch.rand(1, 3, 8, 8),
         )

-        submodule_quant_config = {
-            Add: ModuleQConfig(QuantDtype.use_16a16w)  # noqa: F405
-        }
+        from executorch.backends.qualcomm.quantizer.quantizer import (
+            get_submodule_type_predicate,
+        )
+
+        submodule_qconfig_list = [
+            (
+                get_submodule_type_predicate("Add"),
+                ModuleQConfig(QuantDtype.use_16a16w),
+            )  # noqa: F405
+        ]
         module = self.get_qdq_module(
-            module, sample_input, submodule_quant_config=submodule_quant_config
+            module,
+            sample_input,
+            submodule_qconfig_list=submodule_qconfig_list,
         )
         self.lower_module_and_test_output(module, sample_input)
backends/qualcomm/tests/utils.py

Lines changed: 8 additions & 12 deletions
@@ -17,11 +17,7 @@
 from executorch import exir
 from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
 from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
-from executorch.backends.qualcomm.quantizer.quantizer import (
-    ModuleQConfig,
-    QnnQuantizer,
-    QuantDtype,
-)
+from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype
 from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_DTYPE,
@@ -531,8 +527,9 @@ def get_qdq_module(
         dynamic_shapes: Dict = None,
         bypass_check: bool = False,
         block_size_map: Dict[str, Tuple] = None,
-        submodule_quant_config: Optional[Dict[torch.nn.Module, ModuleQConfig]] = None,
+        submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
     ) -> torch.fx.GraphModule:
+        module = module.eval()
         m = torch.export.export(
             module, inputs, dynamic_shapes=dynamic_shapes, strict=True
         ).module()
@@ -542,7 +539,7 @@ def get_qdq_module(
             custom_annotations=custom_quant_annotations,
             per_channel_conv=is_conv_per_channel,
             per_channel_linear=is_linear_per_channel,
-            submodule_quant_config = submodule_quant_config,
+            submodule_qconfig_list=submodule_qconfig_list,
         )
         if block_size_map is not None:
             quantizer.set_block_size_map(block_size_map)
@@ -570,7 +567,7 @@ def get_prepared_qat_module(
         is_linear_per_channel: Optional[bool] = False,
         custom_quant_annotations: Tuple[Callable] = (),
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
-        submodule_quant_config: Optional[Dict[str, ModuleQConfig]] = None,
+        submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
     ) -> torch.fx.GraphModule:
         m = torch.export.export_for_training(module, inputs).module()

@@ -580,12 +577,11 @@ def get_prepared_qat_module(
             per_channel_conv=is_conv_per_channel,
             per_channel_linear=is_linear_per_channel,
             is_qat=True,
-            submodule_quant_config=submodule_quant_config
+            submodule_qconfig_list=submodule_qconfig_list,
         )

-        submodule_quant_config = submodule_quant_config or {}
-        for submodule, module_qconfig in submodule_quant_config.items():
-            quantizer.set_submodule_quant_config(submodule, module_qconfig)
+        submodule_qconfig_list = submodule_qconfig_list or []
+        quantizer.set_submodule_qconfig_list(submodule_qconfig_list)

         prepared = prepare_qat_pt2e(m, quantizer)
         return torch.ao.quantization.move_exported_model_to_train(prepared)

backends/qualcomm/utils/constants.py

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@
 QCOM_INSERTED_PERMUTE = "qnn_permute"
 QCOM_LAYOUT_CHANGE = "layout_change"
 QCOM_NUM_BLOCKS_PER_AXIS = "num_blocks_per_axis"
-QCOM_NN_MODULE_STACK = "nn_module_stack"
 QCOM_OFFSET = "offset"
 QCOM_ORIG_DTYPE = "orig_dtype"
 QCOM_QUANTIZED_IO = "q_tensor_io"

examples/qualcomm/utils.py

Lines changed: 4 additions & 5 deletions
@@ -14,7 +14,7 @@
 import tempfile
 from pathlib import Path

-from typing import Callable, Dict, List, Optional
+from typing import Callable, List, Optional, Tuple

 import numpy as np

@@ -281,7 +281,7 @@ def make_quantizer(
     per_channel_linear=False,
     act_observer=MovingAverageMinMaxObserver,
     is_qat=False,
-    submodule_quant_config: Optional[Dict[str, ModuleQConfig]] = None,
+    callback_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
 ):
     quantizer = QnnQuantizer()
     quantizer.add_custom_quant_annotations(custom_annotations)
@@ -292,9 +292,8 @@ def make_quantizer(
         is_linear_per_channel=per_channel_linear,
         act_observer=act_observer,
     )
-    submodule_quant_config = submodule_quant_config or {}
-    for submodule, module_qconfig in submodule_quant_config.items():
-        quantizer.set_submodule_quant_config(submodule, module_qconfig)
+    callback_qconfig_list = callback_qconfig_list or []
+    quantizer.set_submodule_qconfig_list(callback_qconfig_list)
     return quantizer
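
For the examples flow, the same pairs are forwarded through make_quantizer's callback_qconfig_list argument. A usage sketch; the make_quantizer import path is assumed from this file's location and is not stated in the commit:

from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QuantDtype,
    get_submodule_type_predicate,
)
from executorch.examples.qualcomm.utils import make_quantizer

# Nodes under any submodule whose type name contains "Add" get 16a16w;
# everything else keeps the quantizer's default config.
quantizer = make_quantizer(
    callback_qconfig_list=[
        (get_submodule_type_predicate("Add"), ModuleQConfig(QuantDtype.use_16a16w)),
    ],
)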
