Qualcomm AI Engine Direct - Add llama sha transforming pass
Differential Revision: D64435128

Pull Request resolved: #6211
chunit-quic authored Nov 11, 2024
1 parent 623a9a6 commit 576e96c
Showing 5 changed files with 267 additions and 16 deletions.
1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
@@ -82,6 +82,7 @@ runtime.python_library(
"export_llama_lib.py",
"model.py",
"source_transformation/apply_spin_quant_r1_r2.py",
"source_transformation/attention.py",
"source_transformation/lora.py",
"source_transformation/pre_quantization.py",
"source_transformation/prune_vocab.py",
3 changes: 3 additions & 0 deletions examples/models/llama/export_llama.py
@@ -7,11 +7,14 @@
# Example script for exporting Llama2 to flatbuffer

import logging
import sys

import torch

from .export_llama_lib import build_args_parser, export_llama

sys.setrecursionlimit(4096)


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
59 changes: 44 additions & 15 deletions examples/models/llama/export_llama_lib.py
@@ -50,6 +50,8 @@
fuse_layer_norms,
get_model_with_r1_r2,
)

from .source_transformation.attention import replace_attention_to_attention_sha
from .source_transformation.quantize import (
get_quant_embedding_transform,
get_quant_weight_transform,
@@ -175,6 +177,12 @@ def build_args_parser() -> argparse.ArgumentParser:
help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
)

parser.add_argument(
"--use_qnn_sha",
action="store_true",
help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
)
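
# Illustrative sketch (not part of this diff): a `store_true` flag defaults to
# False and flips to True only when passed, shown here with a standalone parser.
import argparse

_p = argparse.ArgumentParser()
_p.add_argument("--use_qnn_sha", action="store_true")
assert _p.parse_args([]).use_qnn_sha is False
assert _p.parse_args(["--use_qnn_sha"]).use_qnn_sha is True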

parser.add_argument(
"--calibration_tasks",
nargs="+",
@@ -700,15 +708,24 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
get_custom_quant_ios_dtype,
)

atten = builder_exported_to_edge.model.layers[0].attention
if args.use_qnn_sha:
cache_shape = torch.Size(
(atten.max_batch_size, atten.max_seq_len, atten.head_dim)
)
else:
cache_shape = torch.Size(
(
atten.max_batch_size,
atten.max_seq_len,
atten.n_kv_heads,
atten.head_dim,
)
)
# pyre-ignore
tag_quant_io(
builder_exported_to_edge.edge_manager.exported_program().graph_module,
partial(
get_custom_quant_ios_dtype, # pyre-ignore
builder_exported_to_edge.model.layers[
0
].attention.kv_cache.past_k_caches.shape,
),
partial(get_custom_quant_ios_dtype, cache_shape), # pyre-ignore
)
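
# Illustrative sketch (not part of this diff); the sizes below are assumptions.
# With --use_qnn_sha each head gets its own 3-D cache, so the tagged shape drops
# the n_kv_heads axis; the default path keeps the single 4-D cache shape.
import torch

_bsz, _seq, _n_kv_heads, _head_dim = 1, 128, 8, 64
_sha_shape = torch.Size((_bsz, _seq, _head_dim))               # one buffer per head
_mha_shape = torch.Size((_bsz, _seq, _n_kv_heads, _head_dim))  # single shared buffer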

logging.info("Lowering model using following partitioner(s): ")
@@ -977,15 +994,27 @@ def _get_source_transforms( # noqa
convert_linear_to_conv2d,
)

transforms.append(replace_kv_cache_with_simple_kv_cache)
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
transforms.append(convert_linear_to_conv2d)
if args.use_qnn_sha:
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(
get_model_with_r1_r2(args.optimized_rotation_path)
)
transforms.append(replace_attention_to_attention_sha)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
transforms.append(convert_linear_to_conv2d)
else:
transforms.append(replace_kv_cache_with_simple_kv_cache)
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(
get_model_with_r1_r2(args.optimized_rotation_path)
)
transforms.append(convert_linear_to_conv2d)
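
# Illustrative sketch (not part of this diff): the callables collected above are
# applied to the eager model one after another, so ordering matters (the SHA
# attention swap runs before replace_causal_mask and convert_linear_to_conv2d).
# `apply_transforms` is a hypothetical helper name used only for illustration.
def apply_transforms(model, transforms):
    for transform in transforms:
        model = transform(model)
    return model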

elif args.mps:
# Currently mps doesn't support sdpa op, use the simpler decomposition
1 change: 0 additions & 1 deletion examples/models/llama/llama_transformer.py
@@ -276,7 +276,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
self.max_batch_size = args.max_batch_size
self.max_seq_len = args.max_seq_len
self.dim = args.dim
# self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
219 changes: 219 additions & 0 deletions examples/models/llama/source_transformation/attention.py
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Source transformation replacing Llama multi-head attention (MHA) with multiple single-head attentions (SHA)

import math
from typing import List, Optional, Tuple

import torch
from executorch.examples.models.llama.llama_transformer import Attention
from torch import nn


def apply_rotary_emb_single(
x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
x_r, x_i = x[..., ::2], x[..., 1::2]

x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos

x_out = torch.cat([x_out_r, x_out_i], dim=-1)
return x_out
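
# Illustrative sketch (not part of this diff); shapes are assumptions. x is laid
# out as (batch, seq, head_dim) and freqs_cos/freqs_sin carry one value per
# (position, head_dim // 2) pair, broadcasting over the batch dimension.
_x = torch.randn(1, 4, 8)
_cos, _sin = torch.randn(4, 4), torch.randn(4, 4)
assert apply_rotary_emb_single(_x, _cos, _sin).shape == _x.shape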


class KVCacheSHA(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
):
super().__init__()

# a buffer per head
cache_shape = (max_batch_size, max_seq_length, head_dim)
for i in range(n_heads):
self.register_buffer(
f"past_k_caches_{i}",
torch.zeros(cache_shape, dtype=dtype, device="cpu"),
persistent=False,
)
self.register_buffer(
f"past_v_caches_{i}",
torch.zeros(cache_shape, dtype=dtype, device="cpu"),
persistent=False,
)

def update(
self,
input_pos: torch.Tensor,
k_val: torch.Tensor,
v_val: torch.Tensor,
cache_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
new_k = torch.ops.aten.index_put_(
getattr(self, f"past_k_caches_{cache_idx}"), [None, input_pos], k_val
)
new_v = torch.ops.aten.index_put_(
getattr(self, f"past_v_caches_{cache_idx}"), [None, input_pos], v_val
)
return new_k, new_v

def get_cache(self, head_idx):
return getattr(self, f"past_k_caches_{head_idx}"), getattr(
self, f"past_v_caches_{head_idx}"
)
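
# Illustrative usage sketch (not part of this diff); all sizes are assumptions.
# Each head owns its own (max_batch_size, max_seq_length, head_dim) buffer, and
# `update` writes the new tokens in place at `input_pos` via aten.index_put_.
_cache = KVCacheSHA(max_batch_size=1, max_seq_length=16, n_heads=2, head_dim=8)
_pos = torch.tensor([0])
_k = torch.randn(1, 1, 8)  # (batch, new_tokens, head_dim)
_v = torch.randn(1, 1, 8)
_new_k, _new_v = _cache.update(_pos, _k, _v, cache_idx=0)
_k0, _v0 = _cache.get_cache(0)  # full per-head buffers, each (1, 16, 8)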


class SDPASHA(torch.nn.Module):

def __init__(
self,
max_batch_size: int,
max_seq_length: int,
n_heads: int,
n_rep: int,
head_dim: int,
dim: int,
):
super().__init__()
self.head_dim = head_dim
self.n_rep = n_rep
self.dim = dim
self.kv_cache = KVCacheSHA(
max_batch_size, max_seq_length, n_heads // n_rep, head_dim
)
self.scale_factor = math.sqrt(head_dim)

def forward(
self,
input_pos: torch.Tensor,
qs: List[torch.Tensor],
ks: List[torch.Tensor],
vs: List[torch.Tensor],
mask,
):

transpose_ks = []
for i in range(len(ks)):
new_k, _ = self.kv_cache.update(input_pos, ks[i], vs[i], i)
transpose_ks.append(new_k.transpose(-2, -1).contiguous())

output = []
for i, q in enumerate(qs):
cache_idx = i // self.n_rep
_, v = self.kv_cache.get_cache(cache_idx)

attn_mask = mask[input_pos]

attn_weight = q @ transpose_ks[cache_idx] / self.scale_factor
attn_weight += attn_mask
attn_weight = torch.softmax(attn_weight, dim=-1)
output.append(attn_weight @ v.contiguous())

return torch.cat(output, dim=-1)
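
# Illustrative usage sketch (not part of this diff); sizes and the additive float
# mask are assumptions. In the export flow the bool causal mask registered by
# AttentionSHA is passed instead (and later rewritten by replace_causal_mask).
_sdpa = SDPASHA(max_batch_size=1, max_seq_length=16, n_heads=4, n_rep=2, head_dim=8, dim=32)
_pos = torch.tensor([0])
_qs = [torch.randn(1, 1, 8) for _ in range(4)]  # one tensor per query head
_ks = [torch.randn(1, 1, 8) for _ in range(2)]  # one tensor per KV head
_vs = [torch.randn(1, 1, 8) for _ in range(2)]
_mask = torch.triu(torch.full((16, 16), float("-inf")), diagonal=1)  # additive causal mask
_out = _sdpa(_pos, _qs, _ks, _vs, _mask)  # heads concatenated on the last dim: (1, 1, 32)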


class AttentionSHA(nn.Module):
def __init__(self, attention_mha: nn.Module):
super().__init__()
if not attention_mha.use_kv_cache:
            raise NotImplementedError("bert mode is not supported")

self.n_heads = attention_mha.n_heads
self.n_kv_heads = attention_mha.n_kv_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.dim = attention_mha.dim
self.max_batch_size = attention_mha.max_batch_size
self.max_seq_len = attention_mha.max_seq_len
self.head_dim = attention_mha.dim // self.n_heads
self.SDPA = SDPASHA(
self.max_batch_size,
self.max_seq_len,
self.n_heads,
self.n_rep,
self.head_dim,
self.dim,
)
self.wq = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=False)
for _ in range(self.n_heads)
]
)
self.wk = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=False)
for _ in range(self.n_kv_heads)
]
)
self.wv = nn.ModuleList(
[
nn.Linear(self.dim, self.head_dim, bias=False)
for _ in range(self.n_kv_heads)
]
)

for i in range(self.n_heads):
self.wq[i].weight.data.copy_(
attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
)
for i in range(self.n_kv_heads):
self.wk[i].weight.data.copy_(
attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
)
self.wv[i].weight.data.copy_(
attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
)
self.wo = attention_mha.wo

causal_mask = torch.tril(
torch.ones(
self.max_seq_len,
self.max_seq_len,
dtype=torch.bool,
device="cpu",
)
)
self.register_buffer("mask", causal_mask, persistent=False)

def forward(
self,
x: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
):
# QKV
q = [wq(x) for wq in self.wq]
k = [wk(x) for wk in self.wk]
v = [wv(x) for wv in self.wv]
for i in range(len(q)):
q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
for i in range(len(k)):
k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)

output = self.SDPA(input_pos, q, k, v, self.mask)
return self.wo(output)


def replace_attention_to_attention_sha(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, Attention):
setattr(
module,
name,
AttentionSHA(child),
)
else:
replace_attention_to_attention_sha(child)
return module
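
For reference, a minimal usage sketch of the pass (not part of this diff): it recurses through named_children and swaps every Attention module for an AttentionSHA built from the same weights. The ModelArgs values below are assumptions chosen only so a module can be constructed.

from executorch.examples.models.llama.llama_transformer import ModelArgs

# Hedged sketch: field names follow llama_transformer.ModelArgs; values are illustrative.
_args = ModelArgs(
    dim=64, n_heads=4, n_kv_heads=2, max_seq_len=32, max_batch_size=1, use_kv_cache=True
)
_block = torch.nn.Module()
_block.attention = Attention(_args, layer_id=0)  # multi-head attention with KV cache
_block = replace_attention_to_attention_sha(_block)
assert isinstance(_block.attention, AttentionSHA)

In the export flow itself, the function is simply appended to the source-transform list when --use_qnn_sha is set, as shown in export_llama_lib.py above.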
