
Commit b2f73a3

Authored by chunit-quic (Joey Tsai) and Joey Tsai
Qualcomm AI Engine Direct - Set llama io as quantized tensor (#5383)
* Qualcomm AI Engine Direct - Add llama io to be quantized
  - Add a general function to tag io nodes that obtain/generate quantized tensors
  - Add a quantized-io tagging function to llama2.py
* [Fix lint]

Co-authored-by: Joey Tsai <chunit@qti.qualcomm.com>
1 parent 16b633b commit b2f73a3

File tree

3 files changed: +60 additions, -1 deletion

  backends/qualcomm/quantizer/custom_annotation.py
  backends/qualcomm/utils/utils.py
  examples/models/llama/export_llama_lib.py

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 33 additions & 0 deletions
@@ -12,6 +12,7 @@
     QuantizationConfig,
 )
 from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
+from executorch.exir.dialects._ops import ops as exir_ops
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
@@ -144,3 +145,35 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
+
+
+def get_custom_quant_ios_dtype(
+    cache_shape: torch.Size,
+    node: torch.fx.Node,
+    kv_dtype=torch.uint8,
+    sharding_dtype=torch.uint16,
+):
+    """
+    This function is specific for llama inputs and outputs
+    """
+    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
+        return kv_dtype
+
+    # Tag index put node before copy node, because copy is a skipped node in qnn
+    if (
+        exir_ops.edge.aten.index_put.default == node.target
+        and node.meta["val"].shape == cache_shape
+    ):
+        return kv_dtype
+
+    # Tag sharding io
+    if exir_ops.edge.llama.fallback.default in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype
+
+    # Tag index op as quantized tensors. It is caused by sharding
+    if exir_ops.edge.aten.index.Tensor in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype
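
get_custom_quant_ios_dtype takes the KV-cache shape as its first argument, so a caller typically binds that shape up front and hands the resulting single-argument callback to tag_quant_io. A minimal sketch of that binding, assuming an example cache shape (the real shape comes from the model's kv_cache, as in the export_llama_lib.py change below):

# Sketch only: bind cache_shape so the callback has the (node) -> dtype form
# that tag_quant_io expects. The shape below is an assumed example value, not
# taken from any real model.
from functools import partial

import torch

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    get_custom_quant_ios_dtype,
)

example_cache_shape = torch.Size([1, 8, 128, 64])  # assumption for illustration
quant_io_fn = partial(get_custom_quant_ios_dtype, example_cache_shape)
# quant_io_fn(node) returns torch.uint8 for KV-cache io, torch.uint16 for
# sharding-related io, and None for everything else.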

backends/qualcomm/utils/utils.py

Lines changed: 10 additions & 0 deletions
@@ -71,6 +71,7 @@
     QCOM_PASS_EXPAND_BROADCAST_SHAPE,
     QCOM_PASS_SKIP_ADVANCED_REQUANT,
     QCOM_QNN_COMPILE_SPEC,
+    QCOM_QUANTIZED_IO,
 )

 from executorch.exir import ExirExportedProgram
@@ -876,3 +877,12 @@ def get_soc_to_chipset_map():
         "SM8475": QcomChipset.SM8475,
         "SM8450": QcomChipset.SM8450,
     }
+
+
+def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
+    """
+    Tag io nodes which get/output quantized tensor. No need to insert q/dq in qnn_preprocess
+    """
+    for node in gm.graph.nodes:
+        if dtype := get_quant_io_dtype_fn(node):
+            node.meta[QCOM_QUANTIZED_IO] = dtype
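
tag_quant_io is a plain walk over the FX graph: any node for which the callback returns a dtype gets that dtype stamped into node.meta under QCOM_QUANTIZED_IO, so qnn_preprocess does not need to insert q/dq ops for those ios. A minimal, torch-only sketch of the same pattern, where the "quantized_io" string stands in for the QCOM_QUANTIZED_IO key and the placeholder rule is hypothetical:

# Minimal sketch of the tagging pattern, assuming only torch is installed.
import torch
import torch.fx


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(TinyModel())


def dtype_for_node(node: torch.fx.Node):
    # Hypothetical rule: tag every graph input as uint8; return None otherwise.
    if node.op == "placeholder":
        return torch.uint8
    return None


for node in gm.graph.nodes:
    if dtype := dtype_for_node(node):
        node.meta["quantized_io"] = dtype  # stand-in for QCOM_QUANTIZED_IO

print([(n.name, n.meta.get("quantized_io")) for n in gm.graph.nodes])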

examples/models/llama/export_llama_lib.py

Lines changed: 17 additions & 1 deletion
@@ -650,7 +650,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
             )
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
-        from executorch.backends.qualcomm.utils.utils import _transform
+        from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io

         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
@@ -663,6 +663,22 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
                 shares=args.num_sharding,
             )

+        from functools import partial
+
+        from executorch.backends.qualcomm.quantizer.custom_annotation import (
+            get_custom_quant_ios_dtype,
+        )
+
+        tag_quant_io(
+            builder_exported_to_edge.edge_manager.exported_program().graph_module,
+            partial(
+                get_custom_quant_ios_dtype,
+                builder_exported_to_edge.model.layers[
+                    0
+                ].attention.kv_cache.past_k_caches.shape,
+            ),
+        )
+
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
