Skip to content

Commit fc5c30b

Browse files
Qualcomm AI Engine Direct - Support tile op for different I/O rank
Summary: - Support if the rank of input tensor is less than the rank of output tensor. - make_quantizer kwargs alignment. - Remove module.eval() since calling eval() is not supported for exported models.
1 parent c9c5481 commit fc5c30b

File tree

3 files changed

+56
-19
lines changed

3 files changed

+56
-19
lines changed

backends/qualcomm/_passes/expand_broadcast_tensor_shape.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,38 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8+
9+
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
810
from executorch.exir.dialects._ops import ops as exir_ops
911
from executorch.exir.pass_base import ExportPass, PassResult
1012
from executorch.exir.passes import dead_code_elimination_pass
1113

14+
from .utils import dq_ops, get_quant_attrs
15+
1216

1317
class ExpandBroadcastTensorShape(ExportPass):
1418
"""
1519
Make tensors have same rank for layout-transform to work properly.
1620
"""
1721

18-
def __init__(self):
22+
def __init__(self, edge_program):
1923
super(ExpandBroadcastTensorShape, self).__init__()
2024
self.broadcast_op_targets = [
2125
exir_ops.edge.aten.add.Tensor,
2226
exir_ops.edge.aten.sub.Tensor,
2327
exir_ops.edge.aten.mul.Tensor,
2428
exir_ops.edge.aten.div.Tensor,
29+
# Support if the rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}.
30+
exir_ops.edge.aten.expand_copy.default,
2531
]
32+
self.edge_program = edge_program
2633

2734
def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
2835
for node in graph_module.graph.nodes:
2936
if node.target in self.broadcast_op_targets:
3037
for arg in node.args:
38+
if not isinstance(arg, torch.fx.Node):
39+
continue
3140
input_rank = len(arg.meta["val"].shape)
3241
output_rank = len(node.meta["val"].shape)
3342
if input_rank != output_rank:
@@ -45,6 +54,9 @@ def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
4554
# to be updated correctly and not affect meta of arg
4655
for k, v in arg.meta.items():
4756
reshape_node.meta[k] = v
57+
if arg.target in dq_ops:
58+
quant_attrs = get_quant_attrs(self.edge_program, arg)
59+
reshape_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
4860
reshape_node.meta["val"] = reshape_node.meta["val"].reshape(
4961
new_rank
5062
)

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,11 @@
6969
from collections import defaultdict
7070
from typing import List
7171

72-
from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO
72+
from executorch.backends.qualcomm._passes import (
73+
ExpandBroadcastTensorShape,
74+
FoldQDQ,
75+
TagQuantIO,
76+
)
7377
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
7478
from executorch.backends.qualcomm.debugger.utils import DrawGraph
7579
from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model
@@ -430,10 +434,20 @@ def test_qnn_backend_equal(self):
430434

431435
def test_qnn_backend_expand(self):
432436
modules = [ExpandAs(), ExpandCopy()] # noqa: F405
433-
sample_input = (torch.randn([3, 1]),)
434-
for i, module in enumerate(modules):
435-
with self.subTest(i=i):
436-
self.lower_module_and_test_output(module, sample_input)
437+
sample_inputs = [
438+
(torch.randn([3, 1]),),
439+
(torch.randn([4]),),
440+
]
441+
passes_job = get_capture_program_passes()
442+
passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
443+
index = 0
444+
for module in modules:
445+
for sample_input in sample_inputs:
446+
with self.subTest(i=index):
447+
self.lower_module_and_test_output(
448+
module, sample_input, passes_job=passes_job
449+
)
450+
index += 1
437451

438452
def test_qnn_backend_expm1(self):
439453
sample_input = (torch.randn(3, 4, 5),)
@@ -1506,11 +1520,21 @@ def test_qnn_backend_equal(self):
15061520

15071521
def test_qnn_backend_expand(self):
15081522
modules = [ExpandAs(), ExpandCopy()] # noqa: F405
1509-
sample_input = (torch.randn([3, 1]),)
1510-
for i, module in enumerate(modules):
1511-
with self.subTest(i=i):
1512-
module = self.get_qdq_module(module, sample_input)
1513-
self.lower_module_and_test_output(module, sample_input)
1523+
sample_inputs = [
1524+
(torch.randn([3, 1]),),
1525+
(torch.randn([4]),),
1526+
]
1527+
passes_job = get_capture_program_passes()
1528+
passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
1529+
index = 0
1530+
for module in modules:
1531+
for sample_input in sample_inputs:
1532+
with self.subTest(i=index):
1533+
module = self.get_qdq_module(module, sample_input)
1534+
self.lower_module_and_test_output(
1535+
module, sample_input, passes_job=passes_job
1536+
)
1537+
index += 1
15141538

15151539
def test_qnn_backend_expm1(self):
15161540
sample_input = (torch.randn(3, 4, 5),)

backends/qualcomm/tests/utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import subprocess
1010
import tempfile
1111
import unittest
12-
from typing import Callable, Dict, List, Optional, Tuple
12+
from typing import Callable, Dict, List, Optional, OrderedDict, Tuple
1313

1414
import numpy as np
1515
import torch
@@ -435,6 +435,7 @@ def lower_module_and_test_output(
435435
expected_profile_events: int = -1,
436436
expected_intermediate_events: int = -1,
437437
assert_output_equal: bool = True,
438+
passes_job: Optional[OrderedDict] = None,
438439
skip_node_id_set: set = None,
439440
skip_node_op_set: set = None,
440441
dynamic_shapes: Dict = None,
@@ -444,6 +445,7 @@ def lower_module_and_test_output(
444445
sample_inputs,
445446
self.compiler_specs,
446447
dynamic_shapes=dynamic_shapes,
448+
passes_job=passes_job,
447449
skip_node_id_set=skip_node_id_set,
448450
skip_node_op_set=skip_node_op_set,
449451
)
@@ -504,9 +506,8 @@ def get_qdq_module(
504506
dynamic_shapes: Dict = None,
505507
bypass_check: bool = False,
506508
block_size_map: Dict[str, Tuple] = None,
507-
submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
509+
callback_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
508510
) -> torch.fx.GraphModule:
509-
module = module.eval()
510511
m = torch.export.export(
511512
module, inputs, dynamic_shapes=dynamic_shapes, strict=True
512513
).module()
@@ -516,7 +517,7 @@ def get_qdq_module(
516517
custom_annotations=custom_quant_annotations,
517518
per_channel_conv=is_conv_per_channel,
518519
per_channel_linear=is_linear_per_channel,
519-
submodule_qconfig_list=submodule_qconfig_list,
520+
callback_qconfig_list=callback_qconfig_list,
520521
)
521522
if block_size_map is not None:
522523
quantizer.set_block_size_map(block_size_map)
@@ -544,7 +545,7 @@ def get_prepared_qat_module(
544545
is_linear_per_channel: Optional[bool] = False,
545546
custom_quant_annotations: Tuple[Callable] = (),
546547
quant_dtype: QuantDtype = QuantDtype.use_8a8w,
547-
submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
548+
callback_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
548549
) -> torch.fx.GraphModule:
549550
m = torch.export.export_for_training(module, inputs, strict=True).module()
550551

@@ -554,11 +555,11 @@ def get_prepared_qat_module(
554555
per_channel_conv=is_conv_per_channel,
555556
per_channel_linear=is_linear_per_channel,
556557
is_qat=True,
557-
submodule_qconfig_list=submodule_qconfig_list,
558+
callback_qconfig_list=callback_qconfig_list,
558559
)
559560

560-
submodule_qconfig_list = submodule_qconfig_list or []
561-
quantizer.set_submodule_qconfig_list(submodule_qconfig_list)
561+
callback_qconfig_list = callback_qconfig_list or []
562+
quantizer.set_submodule_qconfig_list(callback_qconfig_list)
562563

563564
prepared = prepare_qat_pt2e(m, quantizer)
564565
return torch.ao.quantization.move_exported_model_to_train(prepared)

0 commit comments

Comments
 (0)