
Qualcomm AI Engine Direct - Set llama io as quantized tensor #5383


Merged · 2 commits · Oct 28, 2024
33 changes: 33 additions & 0 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -12,6 +12,7 @@
QuantizationConfig,
)
from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
@@ -144,3 +145,35 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
annotate_matmul(node, quantization_config_16a8w)


def get_custom_quant_ios_dtype(
cache_shape: torch.Size,
node: torch.fx.Node,
kv_dtype=torch.uint8,
sharding_dtype=torch.uint16,
):
"""
This function is specific to llama inputs and outputs
Contributor:

I'm trying to understand why it is specific to llama inputs/outputs. Is it because of the sharding of the model? For example, the output of the first shard doesn't need a dequant node and the input of the second shard doesn't need a quant node?

Collaborator:

Let me try to make it clear.
This function cleans up the redundant Q/DQ nodes, such as those around the KV I/O and the intermediate tensors between shards, as you mentioned.
In the original flow, we quantize the KV input and dequantize the KV output on every inference.
In fact, we don't need to do this: we can output the quantized KV directly and feed it back into the model for the next inference (a small sketch of this flow follows this file's diff).

"""
if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
return kv_dtype

# Tag the index_put node before the copy node, because copy is a skipped node in QNN
if (
exir_ops.edge.aten.index_put.default == node.target
and node.meta["val"].shape == cache_shape
):
return kv_dtype

# Tag sharding io
if exir_ops.edge.llama.fallback.default in [
u.target for u in list(node.users.keys())
] + [node.target]:
return sharding_dtype

# Tag index ops as quantized tensors; they are introduced by sharding
if exir_ops.edge.aten.index.Tensor in [
u.target for u in list(node.users.keys())
] + [node.target]:
return sharding_dtype
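
To make the explanation above concrete, here is a minimal sketch of the two flows. It is not code from this PR: `model_step`, `quantize`/`dequantize`, and the quantization parameters are illustrative stand-ins rather than ExecuTorch or QNN APIs. The only point is that tagging the KV cache as quantized I/O removes the per-inference quantize/dequantize pair at the graph boundary.

```python
# Minimal sketch, assuming an affine-quantized uint8 KV cache.
# `model_step`, `scale`, and `zero_point` are illustrative placeholders.
import torch

scale, zero_point = 0.05, 128  # assumed quantization parameters


def quantize(x: torch.Tensor) -> torch.Tensor:
    return torch.clamp(torch.round(x / scale) + zero_point, 0, 255).to(torch.uint8)


def dequantize(q: torch.Tensor) -> torch.Tensor:
    return (q.to(torch.float32) - zero_point) * scale


# Original flow: the KV cache lives in float outside the graph, so every
# inference re-quantizes it at the graph input and dequantizes the updated
# cache again at the graph output.
def step_with_redundant_qdq(model_step, token, kv_fp):
    kv_u8 = quantize(kv_fp)                      # Q node at the graph input
    logits, new_kv_u8 = model_step(token, kv_u8)
    return logits, dequantize(new_kv_u8)         # DQ node at the graph output


# With the KV I/O tagged as quantized: the cache stays uint8 between
# inferences and the boundary Q/DQ pair disappears.
def step_with_quantized_io(model_step, token, kv_u8):
    logits, new_kv_u8 = model_step(token, kv_u8)
    return logits, new_kv_u8
```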
10 changes: 10 additions & 0 deletions backends/qualcomm/utils/utils.py
@@ -71,6 +71,7 @@
QCOM_PASS_EXPAND_BROADCAST_SHAPE,
QCOM_PASS_SKIP_ADVANCED_REQUANT,
QCOM_QNN_COMPILE_SPEC,
QCOM_QUANTIZED_IO,
)

from executorch.exir import ExirExportedProgram
@@ -876,3 +877,12 @@ def get_soc_to_chipset_map():
"SM8475": QcomChipset.SM8475,
"SM8450": QcomChipset.SM8450,
}


def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
"""
Tag I/O nodes that consume or produce quantized tensors, so there is no need to insert Q/DQ nodes in qnn_preprocess
"""
for node in gm.graph.nodes:
if dtype := get_quant_io_dtype_fn(node):
node.meta[QCOM_QUANTIZED_IO] = dtype
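
As a usage note, the callback passed to `tag_quant_io` simply maps a graph node to a dtype, or `None` to leave it untagged. The sketch below is a hypothetical, simplified version of the wiring shown in `export_llama_lib.py` later in this PR; `tag_kv_cache_as_quant_io` and its `exported_program` argument are assumed names, and only the KV-cache placeholder case from `get_custom_quant_ios_dtype` is reproduced.

```python
import torch

from executorch.backends.qualcomm.utils.utils import tag_quant_io


def tag_kv_cache_as_quant_io(exported_program):
    """Tag KV-cache placeholders so they stay uint8 at the graph I/O boundary."""

    def kv_cache_io_dtype(node: torch.fx.Node):
        # Name pattern mirrors the placeholder check in get_custom_quant_ios_dtype.
        if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
            return torch.uint8
        return None  # every other node stays untagged

    tag_quant_io(exported_program.graph_module, kv_cache_io_dtype)
```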
18 changes: 17 additions & 1 deletion examples/models/llama/export_llama_lib.py
@@ -643,7 +643,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
)
)
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
from executorch.backends.qualcomm.utils.utils import _transform
from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io

# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
_transform(builder_exported_to_edge.edge_manager.exported_program())
@@ -656,6 +656,22 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
shares=args.num_sharding,
)

from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
get_custom_quant_ios_dtype,
)

tag_quant_io(
builder_exported_to_edge.edge_manager.exported_program().graph_module,
partial(
get_custom_quant_ios_dtype,
builder_exported_to_edge.model.layers[
0
].attention.kv_cache.past_k_caches.shape,
),
)

logging.info("Lowering model using following partitioner(s): ")
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")