Skip to content

Commit df31960

Browse files
committed
Qualcomm AI Engine Direct - issue fix pytorch#2
- pytorch#14048 > add quantized test case with GLU decomposition - pytorch#14049 > add e2e example where constant expansion is applied - pytorch#14050 > add e2e example and source transform for 6D operation - pytorch#14051 > add e2e example and complement missed annotation - pytorch#14052 > add e2e example and dedicated passe for 6D partition
1 parent c1b7ec5 commit df31960

21 files changed

+1140
-135
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .decompose_col_im import DecomposeColIm
1818
from .decompose_einsum import DecomposeEinsum
1919
from .decompose_expm1 import DecomposeExpM1
20+
from .decompose_glu import DecomposeGlu
2021
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
2122
from .decompose_minmaxdim import DecomposeMinMaxDim
2223
from .decompose_roll import DecomposeRoll
@@ -57,6 +58,7 @@
5758
DecomposeColIm,
5859
DecomposeEinsum,
5960
DecomposeExpM1,
61+
DecomposeGlu,
6062
DecomposeLinalgVectorNorm,
6163
DecomposeMinMaxDim,
6264
DecomposeRoll,

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
QCOM_SCALE,
2020
QCOM_ZERO_POINT,
2121
)
22+
from executorch.exir.dialects._ops import ops as exir_ops
2223
from executorch.exir.pass_base import ExportPass, PassResult
2324

2425
from .utils import get_quant_attrs
@@ -38,6 +39,9 @@ def __init__(
3839
super(AnnotateQuantAttrs, self).__init__()
3940
self.edge_program = edge_program
4041
self.skip_advanced_requant = skip_advanced_requant
42+
self.skip_requant_allowlist = {
43+
exir_ops.edge.aten.sigmoid.default,
44+
}
4145

4246
def _annotate_source_nodes(
4347
self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
@@ -80,6 +84,10 @@ def _annotate_requant(self, n):
8084
# node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
8185
# We store {node2: quant_attr in dq_int32} in node1.meta
8286
if n.target in q_ops and n.args[0].target not in dq_ops:
87+
# for some fixed scale op, there is no need to requantize it
88+
if n.args[0].target in self.skip_requant_allowlist:
89+
return
90+
8391
dq_nodes = self._find_last_dq_nodes(n)
8492
q_attrs = get_quant_attrs(self.edge_program, n)
8593
for dq_node in dq_nodes:

backends/qualcomm/_passes/decompose_any.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from executorch.exir import to_edge
99
from executorch.exir.pass_base import ExportPass, PassResult
1010

11+
from .utils import merge_decomposed_graph
12+
1113

1214
class Any(torch.nn.Module):
1315
def __init__(self, dim, keepdim):
@@ -49,26 +51,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4951
# remap is used to map original node values to new node values,
5052
# which ensures that reference to nodes are correctly updated in the new graph
5153
remap = {"x": node.args[0]}
52-
53-
for decomposed_node in decomposed_module.graph.nodes:
54-
# no need to copy existent 'output'
55-
if decomposed_node.op == "output":
56-
for user in node.users.copy():
57-
# remap
58-
user.replace_input_with(
59-
node,
60-
remap[decomposed_node.args[0][0]],
61-
)
62-
# no need to copy existent placeholders
63-
elif decomposed_node.op == "placeholder":
64-
# replace node map from string to graph node
65-
remap[decomposed_node] = remap.pop(decomposed_node.name)
66-
else:
67-
remap[decomposed_node] = graph.node_copy(
68-
decomposed_node,
69-
arg_transform=lambda x, remap=remap: remap[x],
70-
)
71-
54+
merge_decomposed_graph(
55+
remap=remap,
56+
target_node=node,
57+
target_graph=graph,
58+
decomposed_graph_module=decomposed_module,
59+
)
7260
graph.erase_node(node)
7361

7462
graph.eliminate_dead_code()

backends/qualcomm/_passes/decompose_cdist.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import torch
88
from executorch.exir.pass_base import ExportPass, PassResult
99

10+
from .utils import merge_decomposed_graph
11+
1012

1113
class CDist(torch.nn.Module):
1214
def __init__(self):
@@ -54,26 +56,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
5456
# remap is used to map original node values to new node values,
5557
# which ensures that reference to nodes are correctly updated in the new graph
5658
remap = {"x": node.args[0], "y": node.args[1]}
57-
58-
for decomposed_node in decomposed_module.graph.nodes:
59-
# no need to copy existent 'output'
60-
if decomposed_node.op == "output":
61-
for user in node.users.copy():
62-
# remap
63-
user.replace_input_with(
64-
node,
65-
remap[decomposed_node.args[0][0]],
66-
)
67-
# no need to copy existent placeholders
68-
elif decomposed_node.op == "placeholder":
69-
# replace node map from string to graph node
70-
remap[decomposed_node] = remap.pop(decomposed_node.name)
71-
else:
72-
remap[decomposed_node] = graph.node_copy(
73-
decomposed_node,
74-
arg_transform=lambda x, remap=remap: remap[x],
75-
)
76-
59+
merge_decomposed_graph(
60+
remap=remap,
61+
target_node=node,
62+
target_graph=graph,
63+
decomposed_graph_module=decomposed_module,
64+
)
7765
graph.erase_node(node)
7866

7967
graph.eliminate_dead_code()

backends/qualcomm/_passes/decompose_einsum.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from executorch.exir.pass_base import ExportPass, PassResult
99
from torch.fx.experimental.proxy_tensor import make_fx
1010

11-
from .utils import copy_nn_module_stack
11+
from .utils import merge_decomposed_graph
1212

1313

1414
class DecomposeEinsum(ExportPass):
@@ -37,30 +37,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
3737
for i, arg in enumerate(node.args[1]):
3838
remap[f"arg1_{i+1}"] = arg
3939

40-
for decomposed_node in decomposed_module.graph.nodes:
41-
copy_nn_module_stack(node, decomposed_node)
42-
# This is the arg[0] equation string, which is not required anymore after decomposition
43-
if "arg0" in decomposed_node.name:
44-
continue
45-
46-
# no need to copy existent 'output'
47-
if decomposed_node.op == "output":
48-
for user in node.users.copy():
49-
# remap
50-
user.replace_input_with(
51-
node,
52-
remap[decomposed_node.args[0][0]],
53-
)
54-
# no need to copy existent placeholders
55-
elif decomposed_node.op == "placeholder":
56-
# replace node map from string to graph node
57-
remap[decomposed_node] = remap.pop(decomposed_node.name)
58-
else:
59-
remap[decomposed_node] = graph.node_copy(
60-
decomposed_node,
61-
arg_transform=lambda x, remap=remap: remap[x],
62-
)
63-
40+
merge_decomposed_graph(
41+
remap=remap,
42+
target_node=node,
43+
target_graph=graph,
44+
decomposed_graph_module=decomposed_module,
45+
predicate=lambda decomp_node: "arg0" not in decomp_node.name,
46+
)
6447
graph.erase_node(node)
6548

6649
graph.eliminate_dead_code()
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
from .utils import merge_decomposed_graph
11+
12+
13+
# this wrapper is required for IO name mapping with decomposed graph
14+
class Glu(torch.nn.Module):
15+
def __init__(self, dim=-1):
16+
super().__init__()
17+
self.glu = torch.nn.GLU(dim=dim)
18+
19+
def forward(self, x):
20+
return self.glu(x)
21+
22+
23+
class DecomposeGlu(ExportPass):
24+
"""
25+
Decompose glu for quantization annotation to work properly.
26+
"""
27+
28+
def __init__(self) -> None:
29+
super().__init__()
30+
31+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
32+
graph = graph_module.graph
33+
for node in graph.nodes:
34+
if node.target == torch.ops.aten.glu.default:
35+
ep = torch.export.export(
36+
Glu(dim=-1 if len(node.args) < 2 else node.args[1]),
37+
(node.args[0].meta["val"],),
38+
)
39+
decomposed_module = ep.run_decompositions().graph_module
40+
41+
with graph.inserting_before(node):
42+
# remap is used to map original node values to new node values,
43+
# which ensures that reference to nodes are correctly updated in the new graph
44+
remap = {"x": node.args[0]}
45+
merge_decomposed_graph(
46+
remap=remap,
47+
target_node=node,
48+
target_graph=graph,
49+
decomposed_graph_module=decomposed_module,
50+
)
51+
graph.erase_node(node)
52+
53+
graph.eliminate_dead_code()
54+
graph_module.recompile()
55+
return PassResult(graph_module, True)

backends/qualcomm/_passes/decompose_linalg_vector_norm.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from executorch.exir import to_edge
99
from executorch.exir.pass_base import ExportPass, PassResult
1010

11-
from .utils import copy_nn_module_stack
11+
from .utils import merge_decomposed_graph
1212

1313

1414
class LinalgVectorNorm(torch.nn.Module):
@@ -62,27 +62,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
6262
# remap is used to map original node values to new node values,
6363
# which ensures that reference to nodes are correctly updated in the new graph
6464
remap = {"x": node.args[0]}
65-
66-
for decomposed_node in decomposed_module.graph.nodes:
67-
copy_nn_module_stack(node, decomposed_node)
68-
# no need to copy existent 'output'
69-
if decomposed_node.op == "output":
70-
for user in node.users.copy():
71-
# remap
72-
user.replace_input_with(
73-
node,
74-
remap[decomposed_node.args[0][0]],
75-
)
76-
# no need to copy existent placeholders
77-
elif decomposed_node.op == "placeholder":
78-
# replace node map from string to graph node
79-
remap[decomposed_node] = remap.pop(decomposed_node.name)
80-
else:
81-
remap[decomposed_node] = graph.node_copy(
82-
decomposed_node,
83-
arg_transform=lambda x, remap=remap: remap[x],
84-
)
85-
65+
merge_decomposed_graph(
66+
remap=remap,
67+
target_node=node,
68+
target_graph=graph,
69+
decomposed_graph_module=decomposed_module,
70+
)
8671
graph.erase_node(node)
8772

8873
graph.eliminate_dead_code()

backends/qualcomm/_passes/decompose_roll.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from executorch.exir.pass_base import ExportPass, PassResult
99

10-
from .utils import copy_nn_module_stack
10+
from .utils import merge_decomposed_graph
1111

1212

1313
class SliceCopy(torch.nn.Module):
@@ -65,27 +65,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
6565
# remap is used to map original node values to new node values,
6666
# which ensures that reference to nodes are correctly updated in the new graph
6767
remap = {"x": input_node}
68-
69-
for decomposed_node in decomposed_module.graph.nodes:
70-
copy_nn_module_stack(node, decomposed_node)
71-
# no need to copy existent 'output'
72-
if decomposed_node.op == "output":
73-
for user in node.users.copy():
74-
# remap
75-
user.replace_input_with(
76-
node,
77-
remap[decomposed_node.args[0][0]],
78-
)
79-
# no need to copy existent placeholders
80-
elif decomposed_node.op == "placeholder":
81-
# replace node map from string to graph node
82-
remap[decomposed_node] = remap.pop(decomposed_node.name)
83-
else:
84-
remap[decomposed_node] = graph.node_copy(
85-
decomposed_node,
86-
arg_transform=lambda x, remap=remap: remap[x],
87-
)
88-
68+
merge_decomposed_graph(
69+
remap=remap,
70+
target_node=node,
71+
target_graph=graph,
72+
decomposed_graph_module=decomposed_module,
73+
)
8974
graph.erase_node(node)
9075

9176
graph.eliminate_dead_code()

backends/qualcomm/_passes/decompose_wrap_with_autocast.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111
from executorch.exir.pass_base import ExportPass, PassResult
1212

13-
from .utils import copy_nn_module_stack
13+
from .utils import merge_decomposed_graph
1414

1515

1616
class DecomposeWrapWithAutocast(ExportPass):
@@ -52,7 +52,7 @@ def _replace(self, gm: torch.fx.GraphModule) -> None:
5252
graph = gm.graph
5353
for node in graph.nodes:
5454
if isinstance(node.target, torch._higher_order_ops.wrap.WrapWithAutocast):
55-
submod, submod_name = self._get_submod(gm, node)
55+
submod, _ = self._get_submod(gm, node)
5656
n_args = node.args
5757
input_submod = n_args[4]
5858
decomposed_module = submod
@@ -61,22 +61,13 @@ def _replace(self, gm: torch.fx.GraphModule) -> None:
6161
# which ensures that reference to nodes are correctly updated in the new graph
6262
# remap = {"expand_1": node.args[5], "to_4": node.args[6]}
6363
remap = {n_args[i].name: n_args[i] for i in range(5, len(n_args))}
64-
65-
for decomposed_node in decomposed_module.graph.nodes:
66-
copy_nn_module_stack(node, decomposed_node)
67-
# no need to copy existent 'output'
68-
if decomposed_node.op == "output":
69-
self._replace_output(node, decomposed_node, remap)
70-
# no need to copy existent placeholders
71-
elif decomposed_node.op == "placeholder":
72-
# replace node map from string to graph node
73-
remap[decomposed_node] = remap.pop(decomposed_node.name)
74-
else:
75-
remap[decomposed_node] = graph.node_copy(
76-
decomposed_node,
77-
arg_transform=lambda x, remap=remap: remap[x],
78-
)
79-
64+
merge_decomposed_graph(
65+
remap=remap,
66+
target_node=node,
67+
target_graph=graph,
68+
decomposed_graph_module=decomposed_module,
69+
output_processor=self._replace_output,
70+
)
8071
graph.erase_node(node)
8172

8273
graph.erase_node(input_submod)

0 commit comments

Comments
 (0)