Qualcomm AI Engine Direct - Optimize the performance for AR-N model #9079

Merged 4 commits on Mar 13, 2025
41 changes: 16 additions & 25 deletions backends/qualcomm/_passes/fuse_consecutive_transpose.py
Expand Up @@ -55,12 +55,6 @@ def _clone_transpose(
clone_permute_node.meta = n.meta
users[i].replace_input_with(n, clone_permute_node)

def _is_dispensable(self, axis_order):
for index, value in enumerate(axis_order):
if index != value:
return False
return True

def _traverse(self, node):
if node in self.visited or node.target not in self.op_map:
return
@@ -87,25 +81,22 @@ def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
axis_order = torch.arange(len(input_shape)).tolist()
for node in self.nodes:
axis_order = [axis_order[i] for i in node.args[1]]
# If axis order is just [0,1,2,3], we ignore permute node
if self._is_dispensable(axis_order):
for user in output_node.users.copy():
user.replace_input_with(output_node, n.args[0])
else:
with graph.inserting_after(input_node):
permute_op = exir_ops.edge.aten.permute_copy.default
permute_node = graph.create_node(
"call_function", permute_op, (input_node, axis_order)
)
users = output_node.users.copy()
for user in users:
user.replace_input_with(output_node, permute_node)

# copy metadata
permute_node.meta = output_node.meta
# Without "qnn_permute", we might obtain wrong input shape
if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
permute_node.meta[QCOM_INSERTED_PERMUTE] = True

# Preserve the [0,1,2,3] permute node to ensure the next node gets the right axis order.
with graph.inserting_after(input_node):
permute_op = exir_ops.edge.aten.permute_copy.default
permute_node = graph.create_node(
"call_function", permute_op, (input_node, axis_order)
)
users = output_node.users.copy()
for user in users:
user.replace_input_with(output_node, permute_node)

# copy metadata
permute_node.meta = output_node.meta
# Without "qnn_permute", we might obtain wrong input shape
if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
permute_node.meta[QCOM_INSERTED_PERMUTE] = True

# clear current stack
self.nodes = []
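For context, this pass folds a chain of consecutive permutes into a single axis order; after this change the composed permute is always materialized, even when the order works out to the identity, so QNN-inserted permutes downstream still see the layout they expect. A minimal, standalone sketch of the composition arithmetic the pass relies on (hypothetical shapes, not code from the PR):

import torch

x = torch.randn(1, 8, 16, 64)      # hypothetical 4-D activation
p1 = [0, 2, 3, 1]                   # first permute in the chain
p2 = [0, 3, 1, 2]                   # second permute in the chain

# Compose the same way _fuse does: start from the identity order and
# fold in each permute's axis order.
axis_order = list(range(x.dim()))
for p in (p1, p2):
    axis_order = [axis_order[i] for i in p]

assert torch.equal(x.permute(axis_order), x.permute(p1).permute(p2))
print(axis_order)  # [0, 1, 2, 3] here, since p2 happens to undo p1
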
16 changes: 11 additions & 5 deletions backends/qualcomm/_passes/recompose_rms_norm.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -16,8 +17,9 @@ class RecomposeRmsNorm(ExportPass):
Merge decomposed operators back to one super node.
"""

def __init__(self):
super().__init__()
def __init__(self, edge_program: torch.export.ExportedProgram):
Contributor: We can follow #8505 to get rid of some recompose logic and reduce engineering effort there.

Collaborator (Author): Thanks for the information. I will try it.

super(RecomposeRmsNorm, self).__init__()
self.edge_program = edge_program

def _get_eps_node(self, nodes):
# eps: one of inputs of add node
@@ -47,11 +49,15 @@ def call(self, graph_module: torch.fx.GraphModule):
input_node = inp_0 if len(inp_0.users) == 2 else inp_1
else:
raise RuntimeError(
f"Found a edge case of rms_node partitoin {src_partition}, which has {input_len} inputs"
f"Found a edge case of rms_node partition {src_partition}, which has {input_len} inputs"
)

output_node = src_partition.output_nodes[0]
eps_node = self._get_eps_node(src_partition.nodes)
eps = self._get_eps_node(src_partition.nodes)
if isinstance(eps, torch.fx.Node) and is_parameter(
eps, self.edge_program
):
eps = get_parameter(eps, self.edge_program).item()
gamma_node = self._get_gamma_node(output_node)

with graph.inserting_before(output_node):
Expand All @@ -64,7 +70,7 @@ def call(self, graph_module: torch.fx.GraphModule):
input_node,
list(gamma_node.meta["val"].shape),
gamma_node,
eps_node,
eps,
),
)
users = output_node.users.copy()
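In short, this pass rebuilds the decomposed RMSNorm pattern (pow → mean → add eps → rsqrt → mul → mul gamma) into a single rms_norm edge op, and the new edge_program argument lets it resolve eps to a plain float when it arrives as a lifted parameter node. A small numerical sketch of the equivalence the recomposition preserves (assuming a recent PyTorch that ships F.rms_norm; not code from the pass):

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 64)
gamma = torch.randn(64)
eps = 1e-5   # the value now read out of the parameter via get_parameter(...).item()

# Decomposed form that get_source_partitions matches in the graph.
decomposed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * gamma
# Single fused op the pass recomposes it into.
fused = F.rms_norm(x, (64,), weight=gamma, eps=eps)
assert torch.allclose(decomposed, fused, atol=1e-5)
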
17 changes: 7 additions & 10 deletions backends/qualcomm/builders/op_rms_norm.py
@@ -12,7 +12,11 @@

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
from executorch.backends.qualcomm.utils.constants import (
QCOM_DATA,
QCOM_QUANT_ATTRS,
QCOM_ZERO_POINT,
)
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor, register_node_visitor
@@ -66,7 +70,7 @@ def define_node(
nodes_to_wrappers,
)

# Fake node, nn module seems to be inconsistant with document
# Fake node, nn module seems to be inconsistent with document
bias_tensor = torch.zeros(weight_tensor.shape)
bias_node = torch.fx.Node(
node.graph,
@@ -78,6 +82,7 @@
)
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
bias_node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] = 0
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
@@ -87,14 +92,6 @@
)

epsilon = node.args[3]
if isinstance(epsilon, torch.fx.Node):
epsilon = get_parameter(epsilon, self.edge_program)
epsilon = (
epsilon
if isinstance(epsilon, float)
else torch.finfo(epsilon.dtype).eps
)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
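The builder fabricates an all-zero bias tensor for the QNN RmsNorm op and copies the node's quant attrs onto it, now pinning the zero point to 0 (the epsilon handling moves into the recompose pass above). One plausible reading, which is my assumption rather than something stated in the diff, is that the zero payload is treated as already-quantized data, and only a zero zero-point dequantizes it back to exact zeros:

import torch

scale, zero_point = 0.02, 7                  # hypothetical copied quant attrs
q_bias = torch.zeros(64, dtype=torch.int8)   # raw zero payload handed to QNN

dequant_with_copied_zp = (q_bias.float() - zero_point) * scale  # -0.14 everywhere
dequant_with_zero_zp = (q_bias.float() - 0) * scale             # exactly 0.0
print(dequant_with_copied_zp[0].item(), dequant_with_zero_zp[0].item())
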
22 changes: 22 additions & 0 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -539,6 +539,28 @@ def compile(args, pte_filename, tokenizer):
if "model" in state_dict:
state_dict = state_dict["model"]

# Convert to the HuggingFace weight layout to improve the performance of RoPE in the HTP backend.
def permute(w, heads):
dim_0 = w.size(0)
dim_1 = w.size(1)
return (
w.view(heads, dim_0 // heads // 2, 2, dim_1)
.transpose(1, 2)
.reshape(dim_0, dim_1)
)

n_heads = llama_instance_list[0].n_heads
n_kv_heads = llama_instance_list[0].n_kv_heads
n_layers = llama_instance_list[0].n_layers

for layer_i in range(n_layers):
state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute(
state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads
)
state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute(
state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads
)

for llama_instance in llama_instance_list:
llama_instance.load_state_dict(
state_dict,
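The added permute regroups the rows of each attention head in wq/wk from the original pair-interleaved order (real0, imag0, real1, imag1, …) into two contiguous halves (all real parts, then all imaginary parts), which is the layout the reworked apply_rotary_emb_single in static_llama.py expects. A standalone check of the row reordering with hypothetical small sizes:

import torch

heads, head_dim, dim_in = 2, 4, 8
dim_0 = heads * head_dim

def permute(w, heads):  # same view/transpose/reshape as in llama.py
    d0, d1 = w.size(0), w.size(1)
    return w.view(heads, d0 // heads // 2, 2, d1).transpose(1, 2).reshape(d0, d1)

w = torch.arange(dim_0 * dim_in, dtype=torch.float32).reshape(dim_0, dim_in)
w_hf = permute(w, heads)

# Head 0: rows (0, 1, 2, 3) = (r0, i0, r1, i1) become (r0, r1, i0, i1).
assert torch.equal(w_hf[:4], w[[0, 2, 1, 3]])
# Head 1: same regrouping for rows 4..7.
assert torch.equal(w_hf[4:], w[[4, 6, 5, 7]])
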
34 changes: 22 additions & 12 deletions examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -19,12 +19,14 @@
def apply_rotary_emb_single(
x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
x_r, x_i = x[..., ::2], x[..., 1::2]

# brodcast for batch_prefill mode input x
# The HuggingFace implementation of RoPE processes the query and key in two halves instead of in an interleaved way.
# The main difference is the stride in the StrideSlice op: for the interleaved way the stride is two, which is not friendly for the HTP backend.
# Ref: https://github.com/huggingface/transformers/issues/25199
x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
# broadcast for batch_prefill mode input x
if x.dim() == 4:
freqs_cos = freqs_cos[None, :, None, :]
freqs_sin = freqs_sin[None, :, None, :]
freqs_cos = freqs_cos[None, None, :, :]
freqs_sin = freqs_sin[None, None, :, :]
x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos
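The rewritten apply_rotary_emb_single slices each head into two contiguous halves (stride-1 slices) instead of even/odd interleaved slices (stride 2), which the comment above flags as the HTP-unfriendly part, and the freqs broadcasting is adjusted to the new activation layout. A standalone sketch, of my own construction, of why the half-split math matches the interleaved original once the inputs have been deinterleaved by the weight permute in llama.py:

import torch

head_dim, pos = 8, 3
theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
cos, sin = torch.cos(pos * theta), torch.sin(pos * theta)
x = torch.randn(head_dim)

# Interleaved (original): pairs (x[2i], x[2i+1]) rotate together, stride-2 slices.
out_even = x[0::2] * cos - x[1::2] * sin
out_odd = x[0::2] * sin + x[1::2] * cos
out_interleaved = torch.stack([out_even, out_odd], dim=-1).flatten()

# Half-split (HuggingFace style): same values, deinterleaved up front by the
# weight permute, so only stride-1 slices are needed at runtime.
x_hs = torch.cat([x[0::2], x[1::2]])
out_r = x_hs[: head_dim // 2] * cos - x_hs[head_dim // 2 :] * sin
out_i = x_hs[: head_dim // 2] * sin + x_hs[head_dim // 2 :] * cos
out_half_split = torch.cat([out_r, out_i])

# Identical rotation, just a different (HTP-friendlier) memory layout.
assert torch.allclose(out_half_split,
                      torch.cat([out_interleaved[0::2], out_interleaved[1::2]]))
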

@@ -104,25 +106,33 @@ def forward_sha(
v_caches: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, seq_len, _ = hidden_states.shape
# In the HTP backend, the input axis order for the convolution operation is
# more efficient with [1, 1, seq_len, dim] compared to [1, seq_len, 1, dim].
hidden_states = torch.reshape(
hidden_states, (bsz, seq_len, 1, self.dim)
).transpose(1, 3)
q = [
wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
wq_sha(hidden_states)
.permute(0, 2, 3, 1)
.reshape(bsz, seq_len, self.head_dim)
for wq_sha in self.wq_sha
]
k = [
wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
wk_sha(hidden_states)
.permute(0, 2, 3, 1)
.reshape(bsz, seq_len, self.head_dim)
for wk_sha in self.wk_sha
]
v = [
wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
wv_sha(hidden_states)
.permute(0, 2, 3, 1)
.reshape(bsz, seq_len, self.head_dim)
for wv_sha in self.wv_sha
]
for i in range(len(q)):
q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
for i in range(len(k)):
k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)
k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)

output_y = []
kh, vh = [], []
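
In the updated forward_sha, the hidden states stay in a [bsz, dim, 1, seq_len] tensor (which the HTP backend sees as the preferred [1, 1, seq_len, dim] layout), and each head's projection output is brought back to [bsz, seq_len, head_dim] with a single permute plus reshape. A small sketch, assuming the per-head wq/wk/wv modules are 1x1 Conv2d projections (hypothetical sizes), of why this conv-based path matches a plain linear projection:

import torch
import torch.nn as nn

bsz, seq_len, dim, head_dim = 1, 16, 64, 8
hidden = torch.randn(bsz, seq_len, dim)

linear = nn.Linear(dim, head_dim, bias=False)
conv = nn.Conv2d(dim, head_dim, kernel_size=1, bias=False)
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(head_dim, dim, 1, 1))

# Same reshape/transpose pattern as forward_sha.
x = hidden.reshape(bsz, seq_len, 1, dim).transpose(1, 3)      # (bsz, dim, 1, seq_len)
q_conv = conv(x).permute(0, 2, 3, 1).reshape(bsz, seq_len, head_dim)
q_linear = linear(hidden)
assert torch.allclose(q_conv, q_linear, atol=1e-5)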
@@ -249,10 +259,10 @@ def prepare_feedfoward_conv(self):

def forward_feedfoward_conv(self, x):
bsz, _, _ = x.size()
x = torch.reshape(x, (bsz, -1, self.dim, 1))
x = x.transpose(1, 2) # Transpose right before and after Conv
x = torch.reshape(x, (bsz, -1, 1, self.dim))
x = x.transpose(1, 3) # Transpose right before and after Conv
x = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x))
x = x.transpose(1, 2)
x = x.transpose(1, 3)
x = torch.reshape(x, (bsz, -1, self.dim))
return x
