Skip to content

Qualcomm AI Engine Direct - Support tile op for different I/O rank #10054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/qualcomm/_passes/expand_broadcast_tensor_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ def __init__(self):
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.div.Tensor,
# Support if the rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}.
exir_ops.edge.aten.expand_copy.default,
]

def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if node.target in self.broadcast_op_targets:
for arg in node.args:
if not isinstance(arg, torch.fx.Node):
continue
input_rank = len(arg.meta["val"].shape)
output_rank = len(node.meta["val"].shape)
if input_rank != output_rank:
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_passes_dependency_for_capture_program():
ConvertConv1dToConv2d: [FoldQDQ],
DecomposeAny: [RemoveRedundancy],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
ExpandBroadcastTensorShape: [RemoveRedundancy],
ExpandBroadcastTensorShape: [FoldQDQ],
FixedLinearKeepDim: [FoldQDQ],
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],
I64toI32: [RemoveRedundancy],
Expand Down
44 changes: 34 additions & 10 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@
from collections import defaultdict
from typing import List

from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO
from executorch.backends.qualcomm._passes import (
ExpandBroadcastTensorShape,
FoldQDQ,
TagQuantIO,
)
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
from executorch.backends.qualcomm.debugger.utils import DrawGraph
from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model
Expand Down Expand Up @@ -430,10 +434,20 @@ def test_qnn_backend_equal(self):

def test_qnn_backend_expand(self):
    """Lower ExpandAs/ExpandCopy modules to the QNN delegate and compare outputs.

    Covers two input shapes: a broadcastable 2D tensor ([3, 1]) and a
    rank-1 tensor ([4]) whose rank is lower than the expanded output's
    rank — the case the ExpandBroadcastTensorShape pass was extended to
    handle.  The pass is activated explicitly via ``passes_job`` because
    it is not enabled by default in the capture-program pipeline.
    """
    modules = [ExpandAs(), ExpandCopy()]  # noqa: F405
    sample_inputs = [
        (torch.randn([3, 1]),),
        (torch.randn([4]),),
    ]
    passes_job = get_capture_program_passes()
    # Opt in to the rank-expanding broadcast pass under test.
    passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
    index = 0
    for module in modules:
        for sample_input in sample_inputs:
            # Distinct subTest id per (module, input) combination so a
            # failure pinpoints the exact pairing.
            with self.subTest(i=index):
                self.lower_module_and_test_output(
                    module, sample_input, passes_job=passes_job
                )
            index += 1

def test_qnn_backend_expm1(self):
sample_input = (torch.randn(3, 4, 5),)
Expand Down Expand Up @@ -1506,11 +1520,21 @@ def test_qnn_backend_equal(self):

def test_qnn_backend_expand(self):
    """Quantized variant: QDQ-annotate each module, then lower and compare.

    Same coverage as the FP32 test — a [3, 1] broadcastable input and a
    rank-1 [4] input — with the ExpandBroadcastTensorShape pass enabled
    through ``passes_job``.
    """
    modules = [ExpandAs(), ExpandCopy()]  # noqa: F405
    sample_inputs = [
        (torch.randn([3, 1]),),
        (torch.randn([4]),),
    ]
    passes_job = get_capture_program_passes()
    # Opt in to the rank-expanding broadcast pass under test.
    passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
    index = 0
    for module in modules:
        for sample_input in sample_inputs:
            with self.subTest(i=index):
                # Bind the quantized graph module to a fresh name: rebinding
                # ``module`` would feed an already-quantized module back into
                # get_qdq_module on the next sample_input iteration.
                qdq_module = self.get_qdq_module(module, sample_input)
                self.lower_module_and_test_output(
                    qdq_module, sample_input, passes_job=passes_job
                )
            index += 1

def test_qnn_backend_expm1(self):
sample_input = (torch.randn(3, 4, 5),)
Expand Down
5 changes: 3 additions & 2 deletions backends/qualcomm/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import subprocess
import tempfile
import unittest
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, OrderedDict, Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -435,6 +435,7 @@ def lower_module_and_test_output(
expected_profile_events: int = -1,
expected_intermediate_events: int = -1,
assert_output_equal: bool = True,
passes_job: Optional[OrderedDict] = None,
skip_node_id_set: set = None,
skip_node_op_set: set = None,
dynamic_shapes: Dict = None,
Expand All @@ -444,6 +445,7 @@ def lower_module_and_test_output(
sample_inputs,
self.compiler_specs,
dynamic_shapes=dynamic_shapes,
passes_job=passes_job,
skip_node_id_set=skip_node_id_set,
skip_node_op_set=skip_node_op_set,
)
Expand Down Expand Up @@ -506,7 +508,6 @@ def get_qdq_module(
block_size_map: Dict[str, Tuple] = None,
submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
) -> torch.fx.GraphModule:
module = module.eval()
m = torch.export.export(
module, inputs, dynamic_shapes=dynamic_shapes, strict=True
).module()
Expand Down
6 changes: 3 additions & 3 deletions examples/qualcomm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def make_quantizer(
per_channel_linear=False,
act_observer=MovingAverageMinMaxObserver,
is_qat=False,
callback_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
):
quantizer = QnnQuantizer()
quantizer.add_custom_quant_annotations(custom_annotations)
Expand All @@ -273,8 +273,8 @@ def make_quantizer(
is_linear_per_channel=per_channel_linear,
act_observer=act_observer,
)
callback_qconfig_list = callback_qconfig_list or []
quantizer.set_submodule_qconfig_list(callback_qconfig_list)
submodule_qconfig_list = submodule_qconfig_list or []
quantizer.set_submodule_qconfig_list(submodule_qconfig_list)
return quantizer


Expand Down
Loading