Skip to content

Qualcomm AI Engine Direct - XR model enablement pipe_clean #8299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions backends/qualcomm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,55 @@
from .annotate_decomposed import AnnotateDecomposed
from .annotate_quant_attrs import AnnotateQuantAttrs
from .constant_i64_to_i32 import ConstantI64toI32
from .convert_binary_op_with_scalar import ConvertBinaryOpsWithScalar
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
from .convert_prelu import ConvertPReLU
from .convert_to_linear import ConvertToLinear
from .decompose_any import DecomposeAny
from .decompose_einsum import DecomposeEinsum
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
from .decompose_silu import DecomposeSilu
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
from .fold_qdq import FoldQDQ
from .fuse_consecutive_transpose import FuseConsecutiveTranspose
from .insert_io_qdq import InsertIOQDQ
from .insert_requantize import InsertRequantize
from .layout_transform import LayoutTransform
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
from .recompose_rms_norm import RecomposeRmsNorm
from .reduce_dynamic_range import ReduceDynamicRange
from .remove_redundancy import RemoveRedundancy
from .replace_index_put_input import ReplaceIndexPutInput
from .replace_inf_buffer import ReplaceInfBuffer
from .tensor_i64_to_i32 import TensorI64toI32


__all__ = [
AnnotateAndQuantScalar,
AnnotateDecomposed,
AnnotateQuantAttrs,
ConstantI64toI32,
ConvertBmmToMatmul,
ConvertBinaryOpsWithScalar,
ConvertInterpolateWithUpsample2D,
ConvertPReLU,
ConvertToLinear,
DecomposeAny,
DecomposeEinsum,
DecomposeLinalgVectorNorm,
DecomposeSilu,
ExpandBroadcastTensorShape,
FoldQDQ,
ConstantI64toI32,
TensorI64toI32,
FuseConsecutiveTranspose,
InsertIOQDQ,
InsertRequantize,
LayoutTransform,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
ReduceDynamicRange,
RemoveRedundancy,
ReplaceIndexPutInput,
ReplaceInfBuffer,
TensorI64toI32,
]
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/convert_to_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class ConvertToLinear(ExportPass):
mm = exir_ops.edge.aten.mm.default

addmm_patterns = [
{view_copy: 1, permute_copy: 1, addmm: 1},
{view_copy: 2, permute_copy: 1, addmm: 1},
{permute_copy: 1, addmm: 1},
]
Expand Down
76 changes: 76 additions & 0 deletions backends/qualcomm/_passes/decompose_any.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir import to_edge
from executorch.exir.pass_base import ExportPass, PassResult


class Any(torch.nn.Module):
    """Reference module expressing ``any`` as cast -> sum -> not-equal.

    Args:
        dim: Dimension(s) to reduce. A ``list`` is converted to a ``tuple``;
            ``None`` means the input is flattened and reduced over its single
            remaining dimension.
        keepdim: Forwarded to ``torch.sum``.
    """

    def __init__(self, dim, keepdim):
        super().__init__()
        self.dim = tuple(dim) if isinstance(dim, list) else dim
        self.keepdim = keepdim

    def forward(self, x):
        """Return a bool tensor that is True where the reduced slice has a nonzero."""
        # Resolve the reduction dim locally. Writing it back to ``self.dim``
        # would make a second call skip the flatten and reduce only dim 0.
        dim = self.dim
        if dim is None:
            x = torch.flatten(x)
            dim = 0

        # Normalize every element to 0/1 so the sum counts truthy elements.
        x = x.to(torch.bool).to(torch.int32)
        x = torch.sum(x, dim=dim, keepdim=self.keepdim, dtype=torch.int32)
        # The comparison operand is a 1-element tensor, so a 0-d sum is
        # broadcast to shape (1,); kept as-is to leave the emitted ops unchanged.
        return torch.not_equal(x, torch.zeros(1, dtype=torch.int32))


class DecomposeAny(ExportPass):
    """
    Replace every node whose target contains "any.dim" with the nodes of an
    exported, mathematically equivalent ``Any`` module.
    """

    def __init__(self, quantization_capture=False) -> None:
        # ``quantization_capture`` is accepted but not used by this pass.
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Inline the decomposition in place and return the updated module."""
        graph = graph_module.graph
        for node in graph.nodes:
            if "any.dim" not in str(node.target):
                continue

            arg_count = len(node.args)
            dim = node.args[1] if arg_count > 1 else None
            keepdim = node.args[2] if arg_count > 2 else False

            # Export the reference module against the input's recorded meta value.
            sample_inputs = (node.args[0].meta["val"],)
            exported = torch.export.export(Any(dim, keepdim), sample_inputs)
            decomposed_module = to_edge(exported).exported_program()

            with graph.inserting_before(node):
                # Maps nodes of the decomposed graph to their counterparts in
                # the target graph. It is seeded with the placeholder *name*,
                # which is swapped for the placeholder node once it is seen.
                remap = {"x": node.args[0]}

                for inner in decomposed_module.graph.nodes:
                    if inner.op == "placeholder":
                        # Existing input: re-key from its name to the node.
                        remap[inner] = remap.pop(inner.name)
                    elif inner.op == "output":
                        # Existing output: point every consumer of the original
                        # node at the copied result instead.
                        for consumer in list(node.users):
                            consumer.replace_input_with(
                                node,
                                remap[inner.args[0][0]],
                            )
                    else:
                        remap[inner] = graph.node_copy(
                            inner,
                            arg_transform=remap.__getitem__,
                        )

                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
2 changes: 1 addition & 1 deletion backends/qualcomm/_passes/decompose_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

with graph.inserting_before(node):
# remap is used to map original node values to new node values,
# which ensures that reference to nodes are correclty updated in the new graph
# which ensures that reference to nodes are correctly updated in the new graph
remap = {}
# Different from other nodes, einsum args[0] is the einsum equation,
# while input nodes are stored in args[1]
Expand Down
85 changes: 85 additions & 0 deletions backends/qualcomm/_passes/decompose_linalg_vector_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir import to_edge
from executorch.exir.pass_base import ExportPass, PassResult


class LinalgVectorNorm(torch.nn.Module):
    """Reference module computing a vector p-norm as ``sum(|x|**p) ** (1/p)``.

    Args:
        exp: The norm order ``p``; used as the exponent and, inverted, as the
            root. NOTE(review): orders such as 0 or +/-inf are not handled by
            this formula - confirm callers never pass them.
        dim: Dimension(s) to reduce. ``None`` flattens the input first; an
            ``int`` is used as-is; any other iterable is converted to a tuple.
        keepdim: Forwarded to ``torch.sum``.
    """

    def __init__(self, exp, dim, keepdim):
        super().__init__()
        self.exp = exp
        # ``tuple(dim)`` raises TypeError for a bare int, so pass ints through.
        if dim is None or isinstance(dim, int):
            self.dim = dim
        else:
            self.dim = tuple(dim)
        self.keepdim = keepdim

    def forward(self, x):
        """Return the order-``exp`` norm of ``x`` over the configured dims."""
        # Resolve the reduction dim locally. Writing it back to ``self.dim``
        # would make a second call skip the flatten and reduce only dim 0.
        dim = self.dim
        if dim is None:
            x = torch.flatten(x)
            dim = 0

        x = torch.abs(x)
        x = torch.pow(x, self.exp)
        x = torch.sum(x, dim=dim, keepdim=self.keepdim)
        return torch.pow(x, 1.0 / self.exp)


class DecomposeLinalgVectorNorm(ExportPass):
    """
    Replace every node whose target contains "linalg_vector_norm" with the
    nodes of an exported, mathematically equivalent ``LinalgVectorNorm`` module.
    """

    # NOTE(review): a reviewer remarked that this decomposition takes a lot of
    # engineering work and that the ExecuTorch side should look into making it
    # simpler.

    def __init__(self, quantization_capture=False) -> None:
        super().__init__()
        # Selects which form of the exported reference module is inlined.
        self.quantization_capture = quantization_capture

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Inline the decomposition in place and return the updated module."""
        graph = graph_module.graph
        for node in graph.nodes:
            if "linalg_vector_norm" not in str(node.target):
                continue

            arg_count = len(node.args)
            norm_order = node.args[1] if arg_count > 1 else 2.0
            dim = node.args[2] if arg_count > 2 else None
            keepdim = node.args[3] if arg_count > 3 else False

            # Export the reference module against the input's recorded meta value.
            reference = LinalgVectorNorm(norm_order, dim, keepdim)
            sample_inputs = (node.args[0].meta["val"],)
            exported = torch.export.export(reference, sample_inputs)
            if self.quantization_capture:
                decomposed_module = exported.module()
            else:
                decomposed_module = to_edge(exported).exported_program()

            with graph.inserting_before(node):
                # Maps nodes of the decomposed graph to their counterparts in
                # the target graph. It is seeded with the placeholder *name*,
                # which is swapped for the placeholder node once it is seen.
                remap = {"x": node.args[0]}

                for inner in decomposed_module.graph.nodes:
                    if inner.op == "placeholder":
                        # Existing input: re-key from its name to the node.
                        remap[inner] = remap.pop(inner.name)
                    elif inner.op == "output":
                        # Existing output: point every consumer of the original
                        # node at the copied result instead.
                        for consumer in list(node.users):
                            consumer.replace_input_with(
                                node,
                                remap[inner.args[0][0]],
                            )
                    else:
                        remap[inner] = graph.node_copy(
                            inner,
                            arg_transform=remap.__getitem__,
                        )

                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
5 changes: 5 additions & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.adaptive_avg_pool2d.default,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.instance_norm.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten._native_batch_norm_legit.no_stats,
exir_ops.edge.aten.native_group_norm.default,
exir_ops.edge.aten.pixel_shuffle.default,
exir_ops.edge.aten.pixel_unshuffle.default,
Expand All @@ -54,6 +56,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.eq.Scalar,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.ge.Scalar,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gelu.default,
Expand All @@ -75,6 +78,8 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.ne.Scalar,
exir_ops.edge.aten.ne.Tensor,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.prelu.default,
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/tensor_i64_to_i32.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _cast_to_int32(self, core_ep: ExirExportedProgram):
for n in core_ep.exported_program.graph.nodes:
# Keep track of original output dtype so we ensure the dtype of the graph is consistent with nn.Module
if is_graph_output(n):
if isinstance(n.meta["val"], tuple):
if isinstance(n.meta["val"], (tuple, list)):
dtype_list = [tensor.dtype for tensor in n.meta["val"]]
n.meta[QCOM_ORIG_DTYPE] = dtype_list
else:
Expand Down Expand Up @@ -76,7 +76,7 @@ def _preserve_output_dtype(
copy_op = exir_ops.edge.aten._to_copy.default
for n in graph_module.graph.nodes:
if is_graph_output(n) and QCOM_ORIG_DTYPE in n.meta:
if isinstance(n.meta["val"], tuple):
if isinstance(n.meta["val"], (tuple, list)):
for i, dtype in enumerate(n.meta[QCOM_ORIG_DTYPE]):
# TODO: Enable this in future to support OP such as topK
if n.meta["val"][i].dtype != dtype:
Expand Down
30 changes: 17 additions & 13 deletions backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def get_passes_dependency_for_capture_program():
ConvertInterpolateWithUpsample2D,
ConvertPReLU,
ConvertToLinear,
DecomposeAny,
DecomposeLinalgVectorNorm,
ExpandBroadcastTensorShape,
FoldQDQ,
LayoutTransform,
Expand All @@ -76,14 +78,10 @@ def get_passes_dependency_for_capture_program():
)

return {
RecomposePixelUnshuffle: [RemoveRedundancy],
RecomposeRmsNorm: [RemoveRedundancy],
ConvertToLinear: [RecomposePixelUnshuffle],
ConvertPReLU: [RemoveRedundancy],
ConvertBmmToMatmul: [ConvertToLinear],
ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
ConstantI64toI32: [RemoveRedundancy],
TensorI64toI32: [RemoveRedundancy],
AnnotateAndQuantScalar: [
AnnotateQuantAttrs,
],
AnnotateDecomposed: [RemoveRedundancy],
AnnotateQuantAttrs: [
RecomposePixelUnshuffle,
RecomposeRmsNorm,
Expand All @@ -92,16 +90,22 @@ def get_passes_dependency_for_capture_program():
ConvertBmmToMatmul,
ConvertInterpolateWithUpsample2D,
],
AnnotateAndQuantScalar: [
AnnotateQuantAttrs,
],
AnnotateDecomposed: [RemoveRedundancy],
FoldQDQ: [AnnotateQuantAttrs, AnnotateAndQuantScalar, AnnotateDecomposed],
ConstantI64toI32: [ConvertInterpolateWithUpsample2D],
ConvertBmmToMatmul: [ConvertToLinear],
ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
ConvertPReLU: [RemoveRedundancy],
ConvertToLinear: [RecomposePixelUnshuffle],
DecomposeAny: [RemoveRedundancy],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
ExpandBroadcastTensorShape: [RemoveRedundancy],
FoldQDQ: [AnnotateQuantAttrs, AnnotateAndQuantScalar, AnnotateDecomposed],
LayoutTransform: [
AnnotateQuantAttrs,
AnnotateAndQuantScalar,
ExpandBroadcastTensorShape,
],
RecomposePixelUnshuffle: [RemoveRedundancy],
RecomposeRmsNorm: [RemoveRedundancy],
ReplaceIndexPutInput: [LayoutTransform],
TensorI64toI32: [RemoveRedundancy],
}
6 changes: 6 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
op_embedding,
op_eq,
op_expand,
op_full,
op_full_like,
op_ge,
op_gelu,
Expand All @@ -35,6 +36,7 @@
op_hardtanh,
op_index,
op_index_put,
op_instance_norm,
op_layer_norm,
op_le,
op_linear,
Expand All @@ -48,6 +50,7 @@
op_mean_dim,
op_min,
op_mul,
op_ne,
op_neg,
op_pad,
op_pow,
Expand Down Expand Up @@ -101,6 +104,7 @@
op_embedding,
op_eq,
op_expand,
op_full,
op_full_like,
op_ge,
op_gelu,
Expand All @@ -111,6 +115,7 @@
op_hardsigmoid,
op_index,
op_index_put,
op_instance_norm,
op_layer_norm,
op_le,
op_linear,
Expand All @@ -125,6 +130,7 @@
op_min,
op_mul,
op_neg,
op_ne,
op_pad,
op_pow,
op_prelu,
Expand Down
3 changes: 3 additions & 0 deletions backends/qualcomm/builders/op_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def define_node(
step = node.args[2] if len(node.args) > 2 else 1
out_tensor = torch.arange(start, end, step)

# Since the constant value of the current op can be derived at the AoT stage,
# we only build a static tensor here so that consumers of the current node
# can correctly reference the data.
self.define_tensor(
node,
node,
Expand Down
Loading
Loading