Migrate pt2e qualcomm #11049

Open · wants to merge 9 commits into main
3 changes: 0 additions & 3 deletions .lintrunner.toml
@@ -391,15 +391,12 @@ exclude_patterns = [
"backends/vulkan/quantizer/**",
"backends/vulkan/test/**",
"backends/cadence/aot/quantizer/**",
"backends/qualcomm/quantizer/**",
"examples/qualcomm/**",
"backends/xnnpack/quantizer/**",
"backends/xnnpack/test/**",
"exir/tests/test_passes.py",
"extension/llm/export/builder.py",
"extension/llm/export/quantizer_lib.py",
"exir/tests/test_memory_planning.py",
"backends/transforms/duplicate_dynamic_quant_chain.py",
"exir/backend/test/demos/test_xnnpack_qnnpack.py",
]

4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -131,8 +131,8 @@ def get_to_edge_transform_passes(
from executorch.backends.qualcomm._passes import utils
from executorch.exir.dialects._ops import ops as exir_ops

utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
utils.q_ops.add(exir_ops.edge.torchao.quantize_affine.default)
utils.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default)

passes_job = (
passes_job if passes_job is not None else get_capture_program_passes()
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/node_visitor.py
@@ -254,8 +254,8 @@ def get_quant_encoding_conf(
)
# TODO: refactor this when target could be correctly detected
per_block_encoding = {
exir_ops.edge.pt2e_quant.quantize_affine.default,
exir_ops.edge.pt2e_quant.dequantize_affine.default,
exir_ops.edge.torchao.quantize_affine.default,
exir_ops.edge.torchao.dequantize_affine.default,
}
if quant_attrs[QCOM_ENCODING] in per_block_encoding:
return self.make_qnn_per_block_config(node, quant_attrs)
4 changes: 2 additions & 2 deletions backends/qualcomm/partition/utils.py
@@ -57,7 +57,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
torch.ops.aten.upsample_bicubic2d.vec,
# This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
torch.ops.aten.unbind.int,
torch.ops.pt2e_quant.quantize_affine.default,
torch.ops.pt2e_quant.dequantize_affine.default,
torch.ops.torchao.quantize_affine.default,
torch.ops.torchao.dequantize_affine.default,
]
return do_not_decompose
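A note on the two torchao entries added above: the `torch.ops.torchao.*` overloads only exist once torchao has been imported and has registered its custom affine ops, which is presumably why `import torchao` is also added to `backends/qualcomm/tests/utils.py` further down. A minimal sketch, assuming a torchao build that ships `quantize_affine`/`dequantize_affine` as custom ops:

```python
import torch
import torchao  # noqa: F401  # importing torchao registers the torchao.* custom ops

# These overloads resolve only after the import above; they are the same entries
# that get_skip_decomp_table() keeps out of decomposition.
skip_ops = [
    torch.ops.torchao.quantize_affine.default,
    torch.ops.torchao.dequantize_affine.default,
]
print(skip_ops)
```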
55 changes: 26 additions & 29 deletions backends/qualcomm/quantizer/annotators.py
@@ -12,20 +12,17 @@
from torch._ops import OpOverload

from torch._subclasses import FakeTensor
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
from torch.fx import Node

from torch.ao.quantization.observer import FixedQParamsObserver
from torch.ao.quantization.quantizer import (
from torchao.quantization.pt2e import FixedQParamsFakeQuantize, FixedQParamsObserver
from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
DerivedQuantizationSpec,
QuantizationAnnotation,
QuantizationSpec,
SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.fx import Node

from .qconfig import (
get_16a16w_qnn_ptq_config,
@@ -618,19 +615,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> No
return

# TODO current only support 16a16w
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act_node,
quantization_config.input_activation,
)

_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
quantization_config.input_activation,
)
nodes_to_mark_annotated = [node]
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)


@@ -819,25 +816,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) ->
if _is_annotated([node]):
return

_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act_node,
quantization_config.input_activation,
)
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
quantization_config.weight,
)
nodes_to_mark_annotated = [node, weight_node]
if bias_node:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
bias_node,
quantization_config.bias,
)
nodes_to_mark_annotated.append(bias_node)
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)


@@ -1002,12 +999,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
if _is_annotated([node]):
return

_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act_node,
quantization_config.input_activation,
)
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
quantization_config.weight,
@@ -1018,9 +1015,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
bias_config = quantization_config.bias(node)
else:
bias_config = quantization_config.bias
_annotate_input_qspec_map(node, bias_node, bias_config)
annotate_input_qspec_map(node, bias_node, bias_config)
nodes_to_mark_annotated.append(bias_node)
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)

# We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack.
@@ -1038,29 +1035,29 @@ def annotate_batch_and_instance_norm(
return

annotated_args = [act]
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act,
quantization_config.input_activation,
)
# QNN requires uint8 instead of int8 in 'weight' config
if weight is not None:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight,
quantization_config.input_activation,
)
annotated_args.append(weight)

if bias is not None:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
bias,
quantization_config.bias,
)
annotated_args.append(bias)

_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated([node, *annotated_args])


@@ -1070,7 +1067,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> Non
return

if _is_float_tensor(node):
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated([node])


@@ -1086,32 +1083,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
return
input_act_qspec = quantization_config.input_activation

_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act_node,
input_act_qspec,
)
if input_act_qspec.dtype == torch.int32:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
get_16a16w_qnn_ptq_config().weight,
)
else:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
input_act_qspec,
)
nodes_to_mark_annotated = [node, weight_node]
if bias_node:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
bias_node,
quantization_config.bias,
)
nodes_to_mark_annotated.append(bias_node)
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)


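For reviewers skimming the many `_annotate_*` to `annotate_*` renames above: the annotation pattern itself is unchanged, only the helpers are now the public ones exported by torchao. A minimal sketch of an annotator written against the new imports (the `_is_annotated`/`_mark_nodes_as_annotated` bookkeeping used throughout this file is omitted for brevity, and `quantization_config` follows this module's `QuantizationConfig`):

```python
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
    annotate_input_qspec_map,
    annotate_output_qspec,
)


def annotate_unary_op(node: Node, quantization_config) -> None:
    act_node = node.args[0]
    # Map the input edge to the activation qspec and tag the node's output qspec.
    annotate_input_qspec_map(node, act_node, quantization_config.input_activation)
    annotate_output_qspec(node, quantization_config.output_activation)
```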
6 changes: 3 additions & 3 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -17,13 +17,13 @@
QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
from torch.ao.quantization.quantizer import (
from torch.fx import Node
from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
from torchao.quantization.pt2e.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
SharedQuantizationSpec,
)
from torch.fx import Node


def annotate_mimi_decoder(gm: torch.fx.GraphModule):
@@ -7,8 +7,8 @@
from typing import Tuple

import torch
from torch.ao.quantization.observer import MappingType, PerBlock
from torch.ao.quantization.pt2e._affine_quantization import (
from torchao.quantization.pt2e import MappingType, PerBlock
from torchao.quantization.pt2e._affine_quantization import (
_get_reduction_params,
AffineQuantizedMinMaxObserver,
choose_qparams_affine_with_min_max,
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase
from torchao.quantization.pt2e import UniformQuantizationObserverBase


# TODO move to torch/ao/quantization/observer.py.
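The base class these observer files extend now also comes from torchao. As a rough, hypothetical sketch of what subclassing it involves (method names follow the torch.ao observer contract that torchao re-exports; this is not code from the PR):

```python
import torch
from torchao.quantization.pt2e import UniformQuantizationObserverBase


class SimpleMinMaxParamObserver(UniformQuantizationObserverBase):
    """Hypothetical observer: tracks a running min/max and derives qparams."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))

    def forward(self, x):
        x_min, x_max = torch.aminmax(x.detach())
        self.min_val = torch.minimum(self.min_val, x_min)
        self.max_val = torch.maximum(self.max_val, x_max)
        return x

    def calculate_qparams(self):
        # Scale/zero-point math is inherited from the uniform-quantization base.
        return self._calculate_qparams(self.min_val, self.max_val)
```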
11 changes: 6 additions & 5 deletions backends/qualcomm/quantizer/qconfig.py
@@ -7,18 +7,19 @@
PerBlockParamObserver,
)
from torch import Tensor
from torch.ao.quantization.fake_quantize import (
from torch.fx import Node
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
)
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
QuantizationSpec,
)


@dataclass(eq=True)
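Since qconfig.py now pulls `QuantizationSpec` from torchao, a hedged illustration of what one of these specs looks like when built against a torchao observer; the dtype/range values here are illustrative rather than taken from this file:

```python
import torch
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# Illustrative 8-bit per-tensor activation spec backed by a torchao observer.
act_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=MinMaxObserver.with_args(eps=2**-12),
)
```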
7 changes: 3 additions & 4 deletions backends/qualcomm/quantizer/quantizer.py
@@ -12,8 +12,9 @@
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

from torch._ops import OpOverload
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule
from torchao.quantization.pt2e import UniformQuantizationObserverBase
from torchao.quantization.pt2e.quantizer import Quantizer

from .annotators import OP_ANNOTATOR

@@ -130,9 +131,7 @@ class ModuleQConfig:
is_qat: bool = False
is_conv_per_channel: bool = False
is_linear_per_channel: bool = False
act_observer: Optional[
torch.ao.quantization.observer.UniformQuantizationObserverBase
] = None
act_observer: Optional[UniformQuantizationObserverBase] = None

def __post_init__(self):
if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
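For context on how the relocated `Quantizer` base class gets exercised, a hedged end-to-end PTQ sketch; the `QnnQuantizer` import path is this backend's existing quantizer, while `MyModel` and the input shape are placeholders:

```python
import torch
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

model = MyModel().eval()                     # placeholder module
example_inputs = (torch.randn(1, 3, 224, 224),)

quantizer = QnnQuantizer()                   # default QNN quantization config

# Export to an FX graph, insert observers, calibrate, then fold qparams in.
exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)                    # calibration run(s)
quantized = convert_pt2e(prepared)
```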
7 changes: 4 additions & 3 deletions backends/qualcomm/tests/utils.py
@@ -14,6 +14,7 @@

import numpy as np
import torch
import torchao
from executorch import exir
from executorch.backends.qualcomm._passes.utils import dq_ops
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
@@ -537,8 +538,8 @@ def get_qdq_module(
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.pt2e_quant.quantize_affine.default,
torch.ops.pt2e_quant.dequantize_affine.default,
torch.ops.torchao.quantize_affine.default,
torch.ops.torchao.dequantize_affine.default,
}
if not bypass_check:
self.assertTrue(nodes.intersection(q_and_dq))
@@ -569,7 +570,7 @@ def get_prepared_qat_module(
quantizer.set_submodule_qconfig_list(submodule_qconfig_list)

prepared = prepare_qat_pt2e(m, quantizer)
return torch.ao.quantization.move_exported_model_to_train(prepared)
return torchao.quantization.pt2e.move_exported_model_to_train(prepared)

def get_converted_sgd_trained_module(
self,
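The QAT helper above now flips the exported graph into train mode through torchao rather than `torch.ao`. A condensed sketch of that flow, assuming a `module`, `inputs`, and configured `quantizer` as in `get_prepared_qat_module`:

```python
import torch
import torchao
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

exported = torch.export.export(module, inputs).module()
prepared = prepare_qat_pt2e(exported, quantizer)

# Exported graphs do not support .train()/.eval(); use the torchao helpers instead.
prepared = torchao.quantization.pt2e.move_exported_model_to_train(prepared)
# ... run a few fine-tuning steps so fake-quant parameters settle ...
quantized = convert_pt2e(prepared)
```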
10 changes: 4 additions & 6 deletions backends/transforms/duplicate_dynamic_quant_chain.py
@@ -9,14 +9,12 @@

import torch

from torch.ao.quantization.pt2e.utils import (
_filter_sym_size_users,
_is_valid_annotation,
)

from torch.fx.node import map_arg
from torch.fx.passes.infra.pass_base import PassBase, PassResult

from torchao.quantization.pt2e.quantizer import is_valid_annotation
from torchao.quantization.pt2e.utils import _filter_sym_size_users


logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
@@ -129,7 +127,7 @@ def _maybe_duplicate_dynamic_quantize_chain(
dq_node_users = list(dq_node.users.copy())
for user in dq_node_users:
annotation = user.meta.get("quantization_annotation", None)
if not _is_valid_annotation(annotation):
if not is_valid_annotation(annotation):
return
with gm.graph.inserting_after(dq_node):
new_node = gm.graph.node_copy(dq_node)
2 changes: 1 addition & 1 deletion examples/qualcomm/oss_scripts/llama/llama.py
@@ -81,7 +81,7 @@
from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

from torch.ao.quantization.observer import MinMaxObserver
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

sys.setrecursionlimit(4096)
2 changes: 1 addition & 1 deletion examples/qualcomm/oss_scripts/moshi/mimi.py
@@ -37,7 +37,7 @@
from huggingface_hub import hf_hub_download
from moshi.models import loaders

from torch.ao.quantization.observer import MinMaxObserver
from torchao.quantization.pt2e import MinMaxObserver


def seed_all(seed):