
Migrate ExecuTorch's use of pt2e from torch.ao to torchao #10294


Closed
wants to merge 1 commit
28 changes: 28 additions & 0 deletions .lintrunner.toml
@@ -378,3 +378,31 @@ command = [
'--',
'@{{PATHSFILE}}',
]

[[linter]]
code = "TORCH_AO_IMPORT"
include_patterns = ["**/*.py"]
exclude_patterns = [
"third-party/**",
]

command = [
"python3",
"-m",
"lintrunner_adapters",
"run",
"grep_linter",
"--pattern=\\bfrom torch\\.ao\\.quantization\\.(?:quantizer|observer|quantize_pt2e|pt2e)(?:\\.[A-Za-z0-9_]+)*\\b",
"--linter-name=TorchAOImport",
"--error-name=Prohibited torch.ao.quantization import",
"""--error-description=\
Imports from torch.ao.quantization are not allowed. \
Please import from torchao.quantization.pt2e instead.\n \
* torchao.quantization.pt2e (includes all the utils, including observers, fake quants etc.) \n \
* torchao.quantization.pt2e.quantizer (quantizer related objects and utils) \n \
* torchao.quantization.pt2e.quantize_pt2e (prepare_pt2e, prepare_qat_pt2e, convert_pt2e) \n\n \
If you need something from torch.ao.quantization, you can add your file to the exclude_patterns of TORCH_AO_IMPORT in .lintrunner.toml. \
""",
"--",
"@{{PATHSFILE}}",
]
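For context, TORCH_AO_IMPORT is a plain grep rule: any Python file matching the pattern above fails lint with the message given in --error-description. A minimal illustration of what it rejects and accepts (a hypothetical file, not part of this diff):

# Rejected: matches the prohibited torch.ao.quantization pattern.
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Accepted: the torchao path the error message recommends.
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e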
3 changes: 3 additions & 0 deletions .mypy.ini
@@ -97,3 +97,6 @@ ignore_missing_imports = True

[mypy-zstd]
ignore_missing_imports = True

[mypy-torchao.*]
follow_untyped_imports = True
4 changes: 2 additions & 2 deletions backends/apple/coreml/test/test_coreml_quantizer.py
@@ -15,12 +15,12 @@
)

from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from torch.ao.quantization.quantize_pt2e import (
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torch.export import export_for_training


class TestCoreMLQuantizer:
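The migration leaves call sites untouched; only the import paths move from torch.ao to torchao. For reference, a minimal sketch of the PT2E flow these tests exercise, assuming an already-configured quantizer instance (the helper below is illustrative, not from this diff):

import torch
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize(model: torch.nn.Module, example_inputs: tuple, quantizer):
    # Export to an ATen-level graph, then insert observers/fake-quants.
    graph_module = export_for_training(model, example_inputs).module()
    prepared = prepare_pt2e(graph_module, quantizer)
    prepared(*example_inputs)  # calibration pass over sample data
    return convert_pt2e(prepared)  # fold observations into quantized ops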
23 changes: 11 additions & 12 deletions backends/arm/quantizer/arm_quantizer.py
@@ -30,25 +30,24 @@
is_vgf,
) # usort: skip
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.ao.quantization.fake_quantize import (
from torch.fx import GraphModule, Node
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
ObserverOrFakeQuantizeConstructor,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
QuantizationSpec,
Quantizer,
)
from torch.fx import GraphModule, Node

__all__ = [
"TOSAQuantizer",
@@ -97,7 +96,7 @@ def get_symmetric_quantization_config(
weight_qscheme = (
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
MinMaxObserver
)
if is_qat:
@@ -337,14 +336,14 @@ def _annotate_io(
if is_annotated(node):
continue
if node.op == "placeholder" and len(node.users) > 0:
_annotate_output_qspec(
annotate_output_qspec(
node,
quantization_config.get_output_act_qspec(),
)
mark_node_as_annotated(node)
if node.op == "output":
parent = node.all_input_nodes[0]
_annotate_input_qspec_map(
annotate_input_qspec_map(
node, parent, quantization_config.get_input_act_qspec()
)
mark_node_as_annotated(node)
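Note the rename: the underscore-private helpers from torch.ao.quantization.quantizer.utils are public API in torchao. Both simply record a QuantizationAnnotation in the node's metadata for prepare_pt2e to consume later; roughly (a simplified sketch, not the torchao source):

from torch.fx import Node
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation

def annotate_output_qspec_sketch(node: Node, qspec) -> None:
    # Attach (or extend) the annotation that prepare_pt2e reads later.
    annotation = node.meta.setdefault(
        "quantization_annotation", QuantizationAnnotation()
    )
    annotation.output_qspec = qspec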
4 changes: 2 additions & 2 deletions backends/arm/quantizer/arm_quantizer_utils.py
@@ -15,10 +15,10 @@

import torch
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import GraphModule, Node

from torchao.quantization.pt2e.quantizer import QuantizationAnnotation


def is_annotated(node: Node) -> bool:
"""Given a node return whether the node is annotated."""
15 changes: 8 additions & 7 deletions backends/arm/quantizer/quantization_annotator.py
@@ -12,12 +12,13 @@
import torch.fx
from executorch.backends.arm.quantizer import QuantizationConfig
from executorch.backends.arm.tosa_utils import get_node_debug_info
from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
QuantizationSpecBase,
SharedQuantizationSpec,
)

from .arm_quantizer_utils import (
is_annotated,
@@ -118,7 +119,7 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
strict=True,
):
assert isinstance(n_arg, Node)
_annotate_input_qspec_map(node, n_arg, qspec)
annotate_input_qspec_map(node, n_arg, qspec)
if quant_property.mark_annotated:
mark_node_as_annotated(n_arg) # type: ignore[attr-defined]

@@ -129,7 +130,7 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
assert not quant_property.optional
assert quant_property.index == 0, "Only one output annotation supported currently"

_annotate_output_qspec(node, quant_property.qspec)
annotate_output_qspec(node, quant_property.qspec)


def _match_pattern(
4 changes: 2 additions & 2 deletions backends/arm/quantizer/quantization_config.py
@@ -9,9 +9,9 @@
from dataclasses import dataclass

import torch
from torch.ao.quantization import ObserverOrFakeQuantize
from torchao.quantization.pt2e import ObserverOrFakeQuantize

from torch.ao.quantization.quantizer import (
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
FixedQParamsQuantizationSpec,
QuantizationSpec,
4 changes: 2 additions & 2 deletions backends/arm/test/ops/test_add.py
@@ -19,8 +19,8 @@
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import QuantizationSpec
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

aten_op = "torch.ops.aten.add.Tensor"
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
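HistogramObserver and QuantizationSpec keep the same constructor surface under torchao; only the import path changes. A hedged example of building an activation spec like the ones these tests pass to Quantize (the field values are illustrative):

import torch
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)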
4 changes: 2 additions & 2 deletions backends/arm/test/ops/test_sigmoid_16bit.py
@@ -18,8 +18,8 @@
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import QuantizationSpec
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec


def _get_16_bit_quant_config():
4 changes: 2 additions & 2 deletions backends/arm/test/ops/test_sigmoid_32bit.py
@@ -14,8 +14,8 @@
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import QuantizationSpec
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec


def _get_16_bit_quant_config():
2 changes: 1 addition & 1 deletion backends/cadence/aot/compiler.py
@@ -37,9 +37,9 @@
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge_with_preserved_ops
from torch._inductor.decomposition import remove_decompositions
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export.exported_program import ExportedProgram
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from .passes import get_cadence_passes

1 change: 1 addition & 0 deletions backends/cadence/aot/quantizer/TARGETS
@@ -21,6 +21,7 @@ python_library(
deps = [
":utils",
"//caffe2:torch",
"//pytorch/ao:torchao",
],
)

2 changes: 1 addition & 1 deletion backends/cadence/aot/quantizer/patterns.py
@@ -15,7 +15,7 @@

from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
SharedQuantizationSpec,
)
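DerivedQuantizationSpec, imported here, computes quantization parameters from other nodes' observers instead of observing its own tensor — the usual PT2E idiom for bias, whose scale is conventionally the product of the activation and weight scales. A sketch under that assumption (the node arguments are placeholders, not names from this diff):

import torch
from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec

def derive_bias_qparams(observers):
    act_obs, weight_obs = observers
    act_scale, _ = act_obs.calculate_qparams()
    weight_scale, _ = weight_obs.calculate_qparams()
    scale = act_scale * weight_scale
    return scale, torch.zeros_like(scale, dtype=torch.int64)  # zero_point = 0

def bias_qspec(linear_node, act_node, weight_node):
    return DerivedQuantizationSpec(
        derived_from=[(act_node, linear_node), (weight_node, linear_node)],
        derive_qparams_fn=derive_bias_qparams,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_affine,
    )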
6 changes: 3 additions & 3 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -38,9 +38,9 @@

from torch import fx

from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver
from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer
from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer


act_qspec_asym8s = QuantizationSpec(
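ComposableQuantizer, now under torchao's composable_quantizer module, runs each child quantizer's annotate pass in list order so a single graph can mix per-pattern quantizers — which is how the Cadence quantizer stitches its per-op quantizers together. A small usage sketch (the child quantizers are placeholders):

from torchao.quantization.pt2e.quantizer.composable_quantizer import (
    ComposableQuantizer,
)

# conv_quantizer / linear_quantizer: pre-built Quantizer instances.
# Each child annotates only the patterns it recognizes; annotations are
# applied in list order.
quantizer = ComposableQuantizer([conv_quantizer, linear_quantizer])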
2 changes: 1 addition & 1 deletion backends/cadence/aot/quantizer/utils.py
@@ -14,13 +14,13 @@
import torch
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization import ObserverOrFakeQuantize

from torch.fx import GraphModule
from torch.fx.passes.utils.source_matcher_utils import (
check_subgraphs_connected,
SourcePartition,
)
from torchao.quantization.pt2e import ObserverOrFakeQuantize


def quantize_tensor_multiplier(
10 changes: 5 additions & 5 deletions backends/cortex_m/test/test_replace_quant_nodes.py
@@ -16,15 +16,15 @@
ReplaceQuantNodesPass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.quantizer import (
from torch.export import export, export_for_training
from torch.fx import GraphModule
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
Quantizer,
)
from torch.export import export, export_for_training
from torch.fx import GraphModule


@dataclass(eq=True, frozen=True)
@@ -11,7 +11,7 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import get_dim_order
from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torchao.quantization.pt2e import find_sequential_partitions


class PermuteMemoryFormatsPass(ExportPass):
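find_sequential_partitions moves from torch.ao.quantization.pt2e.graph_utils to the top-level torchao.quantization.pt2e namespace. It scans a GraphModule for chains of source partitions where each feeds the next; a hedged usage sketch:

import torch
from torch.fx import GraphModule
from torchao.quantization.pt2e import find_sequential_partitions

def find_conv_relu(graph_module: GraphModule):
    # Each match is a sequence of SourcePartitions, one per requested type,
    # with the output of one flowing into the input of the next.
    return find_sequential_partitions(
        graph_module, [torch.nn.Conv2d, torch.nn.ReLU]
    )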
2 changes: 1 addition & 1 deletion backends/example/example_operators/utils.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
from torchao.quantization.pt2e.quantizer.quantizer import QuantizationAnnotation
Contributor comment: single quantizer
def _nodes_are_annotated(node_list):
2 changes: 1 addition & 1 deletion backends/example/example_partitioner.py
@@ -19,9 +19,9 @@
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import OperatorSupportBase
from torchao.quantization.pt2e import find_sequential_partitions


@final
9 changes: 6 additions & 3 deletions backends/example/example_quantizer.py
@@ -11,9 +11,12 @@
from executorch.backends.example.example_operators.ops import module_to_annotator
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import OperatorConfig
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torchao.quantization.pt2e import (
find_sequential_partitions,
HistogramObserver,
MinMaxObserver,
)
from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer


def get_uint8_tensor_spec(observer_or_fake_quant_ctr):
4 changes: 2 additions & 2 deletions backends/example/test_example_delegate.py
@@ -17,10 +17,10 @@
DuplicateDequantNodePass,
)
from executorch.exir.delegate import executorch_call_delegate

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from torchvision.models.quantization import mobilenet_v2


18 changes: 9 additions & 9 deletions backends/mediatek/quantizer/annotator.py
@@ -10,18 +10,18 @@
from torch._ops import OpOverload
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)

from torch.export import export_for_training
from torch.fx import Graph, Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
SubgraphMatcherWithNameNodeMap,
)

from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec as _annotate_output_qspec,
QuantizationAnnotation,
)

from .qconfig import QuantizationConfig


@@ -108,7 +108,7 @@ def _annotate_fused_activation_pattern(
torch.ops.aten.linear.default,
]:
weight_node = producer_node.args[1]
_annotate_input_qspec_map(
annotate_input_qspec_map(
producer_node,
weight_node,
quant_config.weight,
@@ -201,7 +201,7 @@ def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
return

weight_node = node.args[1]
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
quant_config.weight,
@@ -260,5 +260,5 @@ def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
return

wgt_node = node.args[0]
_annotate_input_qspec_map(node, wgt_node, quant_config.activation)
annotate_input_qspec_map(node, wgt_node, quant_config.activation)
_mark_as_annotated([node])