Qualcomm AI Engine Direct - xr model enablement (mld_f) #10546

Merged · 3 commits · May 23, 2025

2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
@@ -20,6 +20,7 @@
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
from .fixed_linear_keep_dim import FixedLinearKeepDim
from .fold_qdq import FoldQDQ
from .fuse_consecutive_cast import FuseConsecutiveCast
from .fuse_consecutive_transpose import FuseConsecutiveTranspose
from .i64_to_i32 import I64toI32
from .insert_io_qdq import InsertIOQDQ
@@ -54,6 +55,7 @@
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
FoldQDQ,
FuseConsecutiveCast,
FuseConsecutiveTranspose,
I64toI32,
InsertIOQDQ,
116 changes: 116 additions & 0 deletions backends/qualcomm/_passes/fuse_consecutive_cast.py
@@ -0,0 +1,116 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass


class FuseConsecutiveCast(ExportPass):
"""
This pass fuses consecutive cast nodes into one (or removes them entirely) to
reduce runtime overhead.
To simplify the fusion logic, we first ensure each cast node's output feeds at most
one cast node by cloning casts.
Example:
Before clone cast:
relu -> cast1 ─> cast2
|──────> cast3

After clone cast:
relu ─> cast1 ──────> cast2
|───> cast4(new) ─> cast3
"""

def __init__(self):
super().__init__()
self.op_map = {
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.aten._to_copy.default,
}
self.visited = set()
self.nodes = []

def _canonicalize_cast(
self, graph_module: torch.fx.GraphModule
) -> torch.fx.GraphModule:
# replace all i64 cast nodes with i32 version
graph = graph_module.graph
for n in graph_module.graph.nodes:
if n.target in self.op_map and n.meta["val"].dtype == torch.int64:
users = list(n.users)
for user in users:
# bypass graph output node to meet original convention
if user.op == "output":
continue

with graph.inserting_after(n):
cast_node = graph.create_node(
"call_function",
exir_ops.edge.aten._to_copy.default,
n.args,
kwargs={"dtype": torch.int32},
)
cast_node.meta = n.meta
cast_node.meta["val"] = cast_node.meta["val"].to(torch.int32)
user.replace_input_with(n, cast_node)

graph.eliminate_dead_code()

# clone nodes for future fusion
for n in graph_module.graph.nodes:
# make sure we're handling cast node instead of convert node
if n.target in self.op_map and n.kwargs.get("dtype", None) is not None:
users = [user for user in list(n.users) if user.target in self.op_map]
if len(users) > 1:
for i in range(1, len(users)):
with graph.inserting_after(n):
clone_cast_node = graph.create_node(
"call_function",
exir_ops.edge.aten._to_copy.default,
n.args,
kwargs=n.kwargs,
)
clone_cast_node.meta = n.meta
users[i].replace_input_with(n, clone_cast_node)

def _traverse(self, node):
if node in self.visited or node.target not in self.op_map:
return

self.nodes.append(node)
self.visited.add(node)
next_users = [n for n in list(node.users) if n.target in self.op_map]

assert (
len(next_users) <= 1
), "Each cast node should have at most 1 cast output node after _clone_cast"
if not next_users:
return
else:
self._traverse(list(node.users)[0])

def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
for n in graph_module.graph.nodes:
self._traverse(n)
# TODO: decide how to handle the following scenario (won't happen for a quantized graph)
# fp -> to(i32) -> to(fp)
if len(self.nodes) > 1:
input_node, output_node = self.nodes[0], self.nodes[-1]
output_node.replace_input_with(output_node.args[0], input_node.args[0])

# clear current stack
self.nodes = []

def call(self, graph_module: torch.fx.GraphModule):
self._canonicalize_cast(graph_module)
self._fuse(graph_module)
graph_module.recompile()
dead_code_elimination_pass(graph_module)
return PassResult(graph_module, True)
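For context, the pass rewires the last cast in a chain to consume the chain's original input, so the fused graph performs a single cast to the final dtype. A minimal eager-mode sketch of that equivalence (illustrative values assumed, not taken from this PR):

```python
import torch

x = torch.randint(0, 10, (4,), dtype=torch.int64)

chained = x.to(torch.int32).to(torch.float32)  # two consecutive casts
fused = x.to(torch.float32)                    # single cast to the final dtype

# the intermediate int32 cast is lossless for these values, so both paths agree
assert torch.equal(chained, fused)
```

The TODO in `_fuse` flags the one pattern where this equivalence can break, fp -> to(i32) -> to(fp), since the intermediate integer cast discards the fractional part; the comment notes it does not occur in quantized graphs.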
29 changes: 29 additions & 0 deletions backends/qualcomm/_passes/i64_to_i32.py
@@ -31,6 +31,14 @@ class I64toI32(ExportPass):
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.scalar_tensor.default,
}
# This dict ensures that the listed args of these ops stay int64, as required by PyTorch.
# For example, the scatter op only accepts args[2], the index, as int64.
# Key: op whose input must be cast to i64
# Value: indices of the args that need the casting op inserted
I64_IN_OPS = {
exir_ops.edge.aten.gather.default: [2],
exir_ops.edge.aten.scatter.src: [2],
}
copy_op = exir_ops.edge.aten._to_copy.default

def __init__(
@@ -141,11 +149,32 @@ def _cast_constant_to_int32(self, graph_module: torch.fx.GraphModule):
n.replace_all_uses_with(to_dst_node)
to_dst_node.args = (n,)

def _cast_op_args_to_i64(self, graph_module: torch.fx.GraphModule):
# inputs will be cast to i32 during call_operator dtype propagation;
# insert an i64 cast node here to prevent PyTorch's operator validation from failing
for node in graph_module.graph.nodes:
if node.target in self.I64_IN_OPS:
with graph_module.graph.inserting_before(node):
arg_indices = self.I64_IN_OPS[node.target]
for arg_index in arg_indices:
input_node = node.args[arg_index]
cast_i64_node = graph_module.graph.create_node(
"call_function",
self.copy_op,
(input_node,),
{"dtype": torch.int64},
)
cast_i64_node.meta["val"] = node.meta["val"].to(torch.int64)
args_list = list(node.args)
args_list[arg_index] = cast_i64_node
node.args = tuple(args_list)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# Record the original output dtype so that, if the user expects int64 output,
# it can be converted back to int64 after being cast from int64 to int32.
self._record_original_output_dtype(graph_module)
self._cast_constant_to_int32(graph_module)
self._cast_op_args_to_i64(graph_module)
graph_module = super().call(graph_module).graph_module
self._preserve_output_dtype(graph_module)
graph_module.recompile()
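For context on the PyTorch restriction that `I64_IN_OPS` works around, a small standalone sketch (illustrative only, not part of the PR): `gather` and `scatter` reject non-int64 index tensors, which is why an i64 cast is re-inserted in front of args[2] even though the rest of the graph is lowered to i32.

```python
import torch

src = torch.arange(6.0).reshape(2, 3)
idx64 = torch.tensor([[0, 2, 1], [1, 0, 2]], dtype=torch.int64)
print(torch.gather(src, 1, idx64))   # works: index is int64

idx32 = idx64.to(torch.int32)
try:
    torch.gather(src, 1, idx32)      # PyTorch requires an int64 index tensor
except RuntimeError as err:
    print("int32 index rejected:", err)
```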
3 changes: 3 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -25,6 +25,7 @@
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
FoldQDQ,
FuseConsecutiveCast,
FuseConsecutiveTranspose,
I64toI32,
InsertIOQDQ,
@@ -182,6 +183,7 @@ def transform_for_to_edge_pipeline(

# Before quantizer
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(RemoveRedundancy(quantization_capture=True))
self.add_pass(ReduceDynamicRange())
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
self.add_pass(ReplaceArangeArgs())
@@ -214,5 +216,6 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
self.add_pass(InsertRequantize())
self.add_pass(InsertIOQDQ(exported_program))
self.add_pass(LayoutTransform(exported_program, insert_permute=True))
self.add_pass(FuseConsecutiveCast())
self.add_pass(FuseConsecutiveTranspose())
return self._transform(exported_program.graph_module)
17 changes: 15 additions & 2 deletions backends/qualcomm/_passes/remove_redundancy.py
@@ -14,9 +14,9 @@ class RemoveRedundancy(ExportPass):
Trim certain operators to reduce unnecessary overhead.
"""

def __init__(self):
def __init__(self, quantization_capture=False):
super(RemoveRedundancy, self).__init__()
self.redundant_ops = {
self.redundant_ops_general = {
torch.clone: self._default_condition,
torch.ops.aten.clone.default: self._default_condition,
exir_ops.edge.aten.clone.default: self._default_condition,
@@ -27,7 +27,16 @@ def __init__(self):
exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,
# remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition,
torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
}
self.redundant_ops_annotation = {
torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
}
self.redundant_ops = (
self.redundant_ops_annotation
if quantization_capture
else self.redundant_ops_general
)

def _dim_order_op_condition(self, node):
dim_order = node.kwargs.get("dim_order")
@@ -49,6 +58,10 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
continue

to_be_remove = n
# assert_tensor_metadata op has no user
if len(n.users.keys()) == 0:
n.args = ()
# normal case
for user_n in list(n.users.keys()):
user_n.replace_input_with(n, n.args[0])
graph_module.graph.erase_node(to_be_remove)
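A rough usage sketch of the new constructor flag (the signature comes from this diff; the surrounding usage is illustrative): at annotation time only the `_assert_tensor_metadata` assertions are stripped, while the other pipelines keep using the general redundant-op set.

```python
from executorch.backends.qualcomm._passes.remove_redundancy import RemoveRedundancy

# annotation pipeline: only removes aten._assert_tensor_metadata nodes
annotation_pass = RemoveRedundancy(quantization_capture=True)

# to_edge / preprocess pipelines: removes clones, redundant copies, etc.
general_pass = RemoveRedundancy()
```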
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -32,6 +32,7 @@
op_expand,
op_full,
op_full_like,
op_gather,
op_ge,
op_gelu,
op_group_norm,
@@ -120,6 +121,7 @@
op_expand,
op_full,
op_full_like,
op_gather,
op_ge,
op_gelu,
op_group_norm,
53 changes: 10 additions & 43 deletions backends/qualcomm/builders/op_argmin.py
@@ -10,8 +10,8 @@
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP, register_node_visitor
from .qnn_constants import OpArgmin, OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW
from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
@@ -26,7 +26,6 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
op_wrapper_list = []
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
output_tensor = self.get_tensor(node, node)
@@ -38,26 +37,14 @@
nodes_to_wrappers,
)
argmin_input_tensors = [argmin_inp_tensor_wrapper]

# arg output is index, do not quantize it.
node.meta.pop("quant_attrs", None)
input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
input_node, node
)

argmin_intermediate_tensor_wrapper = self.define_custom_tensor_wrapper(
node_name=node.name + "_cast",
tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
dtype=QNN_TENSOR_TYPE_MAP[torch.int32],
quant_encoding=input_quant_encoding,
quant_configs=input_quant_configs,
dims=output_tensor.size(),
tensor=output_tensor,
is_fake_tensor=True,
nodes_to_wrappers=nodes_to_wrappers,
argmin_out_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor.to(torch.int32),
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

argmin_output_tensors = [argmin_intermediate_tensor_wrapper]
argmin_output_tensors = [argmin_out_tensor_wrapper]

dim = cast(int, node.args[1])
if dim < 0:
@@ -87,24 +74,4 @@
{QCOM_DATA: keep_dims},
)

op_wrapper_list.append(argmin_op)

cast_op = PyQnnWrapper.PyQnnOpWrapper(
node.name + "_cast",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpCast.op_name,
)

output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

cast_op.AddInputTensors([argmin_intermediate_tensor_wrapper])
cast_op.AddOutputTensors([output_tensor_wrapper])
op_wrapper_list.append(cast_op)

return op_wrapper_list
return argmin_op
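Background on why the explicit `OpCast` could be dropped (a sketch of the dtype mismatch being handled; the eager-mode snippet is illustrative): PyTorch's `argmin` yields int64 indices, while the QNN Argmin op produces int32, so the builder now declares the output tensor as int32 directly and leaves any remaining dtype reconciliation to the cast-handling passes (I64toI32, FuseConsecutiveCast) touched elsewhere in this PR.

```python
import torch

scores = torch.randn(2, 3)
idx = torch.argmin(scores, dim=1)
print(idx.dtype)  # torch.int64 on the PyTorch side; the QNN Argmin op emits int32
```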