Migrate pt2e arm #11053

Open · wants to merge 8 commits into main.

This PR migrates the Arm backend's PT2E quantizer from the torch.ao.quantization modules to their torchao.quantization.pt2e equivalents, and drops the now-unneeded lint exclusions for the migrated paths.
Changes from all commits
.lintrunner.toml (2 changes: 0 additions & 2 deletions)

@@ -386,8 +386,6 @@ exclude_patterns = [
     "third-party/**",
     # TODO: remove exceptions as we migrate
     # backends
-    "backends/arm/quantizer/**",
-    "backends/arm/test/ops/**",
     "backends/vulkan/quantizer/**",
     "backends/vulkan/test/**",
     "backends/cadence/aot/quantizer/**",
backends/arm/quantizer/TARGETS (4 changes: 4 additions & 0 deletions)

@@ -6,6 +6,7 @@ python_library(
     srcs = ["quantization_config.py"],
     deps = [
         "//caffe2:torch",
+        "//pytorch/ao:torchao",
     ],
 )

@@ -18,6 +19,7 @@ python_library(
         ":quantization_annotator",
         "//caffe2:torch",
         "//executorch/exir:lib",
+        "//pytorch/ao:torchao",
     ],
 )

@@ -28,6 +30,7 @@ python_library(
         ":arm_quantizer_utils",
         ":quantization_config",
         "//caffe2:torch",
+        "//pytorch/ao:torchao",
     ],
 )

@@ -36,6 +39,7 @@ python_library(
     srcs = ["arm_quantizer_utils.py"],
    deps = [
         ":quantization_config",
+        "//pytorch/ao:torchao",
     ],
 )
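For reference, the new //pytorch/ao:torchao dep is what allows these targets to import from torchao rather than torch.ao; the imports it enables are exactly the ones appearing in the diffs below, e.g.:

    from torchao.quantization.pt2e import ObserverOrFakeQuantize
    from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer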
backends/arm/quantizer/arm_quantizer.py (25 changes: 13 additions & 12 deletions)

@@ -30,25 +30,26 @@
     is_vgf,
 )  # usort: skip
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from torch.ao.quantization.fake_quantize import (
+
+from torch.fx import GraphModule, Node
+from torchao.quantization.pt2e import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
-)
-from torch.ao.quantization.observer import (
     HistogramObserver,
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     MovingAveragePerChannelMinMaxObserver,
+    ObserverOrFakeQuantizeConstructor,
     PerChannelMinMaxObserver,
     PlaceholderObserver,
 )
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
-from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
+
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    QuantizationSpec,
+    Quantizer,
 )
-from torch.fx import GraphModule, Node

 __all__ = [
     "TOSAQuantizer",

@@ -97,7 +98,7 @@ def get_symmetric_quantization_config(
     weight_qscheme = (
         torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
     )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
         MinMaxObserver
     )
     if is_qat:

@@ -337,14 +338,14 @@ def _annotate_io(
         if is_annotated(node):
             continue
         if node.op == "placeholder" and len(node.users) > 0:
-            _annotate_output_qspec(
+            annotate_output_qspec(
                 node,
                 quantization_config.get_output_act_qspec(),
             )
             mark_node_as_annotated(node)
         if node.op == "output":
             parent = node.all_input_nodes[0]
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 node, parent, quantization_config.get_input_act_qspec()
             )
             mark_node_as_annotated(node)
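To illustrate the renamed public names, here is a minimal sketch of a symmetric int8 weight spec of the kind get_symmetric_quantization_config assembles, using only names imported above (the quant_min/quant_max and eps values are illustrative assumptions, not the file's exact defaults):

    import torch
    from torchao.quantization.pt2e import MinMaxObserver, ObserverOrFakeQuantizeConstructor
    from torchao.quantization.pt2e.quantizer import QuantizationSpec

    # Observer constructor, annotated with the now-public alias
    # (previously the private _ObserverOrFakeQuantizeConstructor).
    weight_observer_ctr: ObserverOrFakeQuantizeConstructor = MinMaxObserver

    # Symmetric per-tensor int8 weight spec.
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=weight_observer_ctr.with_args(eps=2**-12),
    )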
backends/arm/quantizer/arm_quantizer_utils.py (4 changes: 2 additions & 2 deletions)

@@ -15,10 +15,10 @@

 import torch
 from torch._subclasses import FakeTensor
-
-from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import GraphModule, Node
+
+from torchao.quantization.pt2e.quantizer import QuantizationAnnotation


 def is_annotated(node: Node) -> bool:
     """Given a node return whether the node is annotated."""
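For context, pt2e stores annotation state on FX node metadata. A minimal sketch of the check is_annotated performs (assuming the conventional "quantization_annotation" meta key and the QuantizationAnnotation._annotated flag; this is not the file's verbatim body):

    from torch.fx import Node
    from torchao.quantization.pt2e.quantizer import QuantizationAnnotation

    def is_annotated_sketch(node: Node) -> bool:
        # A node counts as annotated once a QuantizationAnnotation with
        # _annotated=True has been attached to its meta dict.
        annotation = node.meta.get("quantization_annotation", None)
        return isinstance(annotation, QuantizationAnnotation) and annotation._annotated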
backends/arm/quantizer/quantization_annotator.py (16 changes: 9 additions & 7 deletions)

@@ -13,12 +13,14 @@
 import torch.nn.functional as F
 from executorch.backends.arm.quantizer import QuantizationConfig
 from executorch.backends.arm.tosa_utils import get_node_debug_info
-from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
+
 from torch.fx import Node
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)

 from .arm_quantizer_utils import (
     is_annotated,

@@ -119,7 +121,7 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
         strict=True,
     ):
         assert isinstance(n_arg, Node)
-        _annotate_input_qspec_map(node, n_arg, qspec)
+        annotate_input_qspec_map(node, n_arg, qspec)
         if quant_property.mark_annotated:
             mark_node_as_annotated(n_arg)  # type: ignore[attr-defined]

@@ -130,7 +132,7 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
     assert not quant_property.optional
     assert quant_property.index == 0, "Only one output annotation supported currently"

-    _annotate_output_qspec(node, quant_property.qspec)
+    annotate_output_qspec(node, quant_property.qspec)


 def _match_pattern(
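As a usage sketch of the renamed helpers, here is a self-contained toy example (not code from this PR) showing the shared-qspec pattern the annotator applies to ops such as add:

    import torch
    from torch.fx import symbolic_trace
    from torchao.quantization.pt2e.quantizer import (
        annotate_input_qspec_map,
        annotate_output_qspec,
        SharedQuantizationSpec,
    )


    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y


    gm = symbolic_trace(Add())
    add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
    x_node, y_node = add_node.args

    # Tie the second input and the output to the first input's qparams.
    shared = SharedQuantizationSpec((x_node, add_node))
    annotate_input_qspec_map(add_node, y_node, shared)
    annotate_output_qspec(add_node, shared)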
backends/arm/quantizer/quantization_config.py (4 changes: 2 additions & 2 deletions)

@@ -9,9 +9,9 @@
 from dataclasses import dataclass

 import torch
-from torch.ao.quantization import ObserverOrFakeQuantize
+from torchao.quantization.pt2e import ObserverOrFakeQuantize

-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
     QuantizationSpec,
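To show one of the imported spec types in action, a minimal sketch (the uint8 sigmoid output parameters are a standard illustration, not code from this file):

    import torch
    from torchao.quantization.pt2e.quantizer import FixedQParamsQuantizationSpec

    # Sigmoid outputs lie in [0, 1], so scale 1/256 with zero_point 0
    # covers the range exactly in uint8.
    sigmoid_out_qspec = FixedQParamsQuantizationSpec(
        dtype=torch.uint8,
        scale=1.0 / 256.0,
        zero_point=0,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
    )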
backends/arm/test/ops/test_add.py (4 changes: 2 additions & 2 deletions)

@@ -19,8 +19,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec

 aten_op = "torch.ops.aten.add.Tensor"
 exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
backends/arm/test/ops/test_sigmoid_16bit.py (4 changes: 2 additions & 2 deletions)

@@ -18,8 +18,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec


 def _get_16_bit_quant_config():
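A sketch of the kind of 16-bit activation spec a helper like _get_16_bit_quant_config can build from these imports (values are illustrative assumptions, not the test's exact configuration):

    import torch
    from torchao.quantization.pt2e import HistogramObserver
    from torchao.quantization.pt2e.quantizer import QuantizationSpec

    # 16-bit symmetric activation spec backed by a histogram observer.
    int16_act_qspec = QuantizationSpec(
        dtype=torch.int16,
        quant_min=torch.iinfo(torch.int16).min,
        quant_max=torch.iinfo(torch.int16).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-16),
    )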
backends/arm/test/ops/test_sigmoid_32bit.py (4 changes: 2 additions & 2 deletions)

@@ -14,8 +14,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec


 def _get_16_bit_quant_config():