
Commit 21f1745

Authored and committed by Joey Tsai
Qualcomm AI Engine Direct - Allow llama IO to be quantized
- Add a general function to tag IO nodes that obtain/generate quantized tensors
- Add an IO-quantizing function to llama2.py
1 parent ad0e5e8 commit 21f1745

File tree: 3 files changed (+61, -2 lines)


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 34 additions & 0 deletions
@@ -12,6 +12,8 @@
     QuantizationConfig,
 )
 from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.extension.llm.export.builder import LLMEdgeManager
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
@@ -144,3 +146,35 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
+
+
+def get_custom_quant_ios_dtype(
+    cache_shape: torch.Size,
+    node: torch.fx.Node,
+    kv_dtype=torch.uint8,
+    sharding_dtype=torch.uint16,
+):
+    """
+    This function is specific to llama inputs and outputs.
+    """
+    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
+        return kv_dtype
+
+    # Tag the index_put node that feeds the copy node, because copy is a node QNN skips.
+    if (
+        exir_ops.edge.aten.index_put.default == node.target
+        and node.meta["val"].shape == cache_shape
+    ):
+        return kv_dtype
+
+    # Tag sharding IO.
+    if exir_ops.edge.llama.fallback.default in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype
+
+    # Tag index ops as quantized tensors; they are introduced by sharding.
+    if exir_ops.edge.aten.index.Tensor in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype
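
As a quick illustration of the placeholder branch above, here is a minimal sketch. It assumes an executorch build with the Qualcomm AI Engine Direct backend importable; the toy graph, node name, and cache shape below are made up for the example and are not part of this commit.

import torch

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    get_custom_quant_ios_dtype,
)

# Build a tiny fx graph whose only input matches the KV-cache naming pattern
# the selector looks for.
graph = torch.fx.Graph()
kv_in = graph.placeholder("attention_sdpa_kv_cache_past_k_0")

cache_shape = torch.Size([1, 8, 1024, 64])  # hypothetical past-K cache shape

# The placeholder branch fires on the node name alone, so this returns the
# default kv_dtype (torch.uint8); nodes that match no branch return None.
assert get_custom_quant_ios_dtype(cache_shape, kv_in) == torch.uint8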

backends/qualcomm/utils/utils.py

Lines changed: 10 additions & 0 deletions
@@ -71,6 +71,7 @@
     QCOM_PASS_EXPAND_BROADCAST_SHAPE,
     QCOM_PASS_SKIP_ADVANCED_REQUANT,
     QCOM_QNN_COMPILE_SPEC,
+    QCOM_QUANTIZED_IO,
 )

 from executorch.exir import ExirExportedProgram
@@ -876,3 +877,12 @@ def get_soc_to_chipset_map():
         "SM8475": QcomChipset.SM8475,
         "SM8450": QcomChipset.SM8450,
     }
+
+
+def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
+    """
+    Tag IO nodes that consume or produce quantized tensors, so qnn_preprocess does not need to insert q/dq ops for them.
+    """
+    for node in gm.graph.nodes:
+        if dtype := get_quant_io_dtype_fn(node):
+            node.meta[QCOM_QUANTIZED_IO] = dtype
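
For context, a minimal usage sketch of tag_quant_io on a toy graph. It assumes an executorch build with the Qualcomm backend importable; the module and selector are illustrative, and the import location of QCOM_QUANTIZED_IO is inferred from the surrounding import block rather than stated in this commit.

import torch

from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
from executorch.backends.qualcomm.utils.utils import tag_quant_io


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


def tag_uint8_inputs(node: torch.fx.Node):
    # Illustrative selector: mark every graph input as quantized uint8 IO;
    # returning None leaves a node untagged.
    return torch.uint8 if node.op == "placeholder" else None


gm = torch.fx.symbolic_trace(Add())
tag_quant_io(gm, tag_uint8_inputs)

for node in gm.graph.nodes:
    # Placeholders print torch.uint8; other nodes print None.
    print(node.name, node.meta.get(QCOM_QUANTIZED_IO))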

examples/models/llama/export_llama_lib.py

Lines changed: 17 additions & 2 deletions
@@ -642,8 +642,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
             )
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
-        from executorch.backends.qualcomm.utils.utils import _transform
-
+        from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())

@@ -655,6 +654,22 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
                 shares=args.num_sharding,
             )

+        from functools import partial
+
+        from executorch.backends.qualcomm.quantizer.custom_annotation import (
+            get_custom_quant_ios_dtype,
+        )
+
+        tag_quant_io(
+            builder_exported_to_edge.edge_manager.exported_program().graph_module,
+            partial(
+                get_custom_quant_ios_dtype,
+                builder_exported_to_edge.model.layers[
+                    0
+                ].attention.kv_cache.past_k_caches.shape,
+            ),
+        )
+
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
