
Commit b252af9

metascroy authored and facebook-github-bot committed
Migrate ExecuTorch's use of pt2e from torch.ao to torchao (#10294)
Summary: Most code related to PT2E quantization is migrating from torch.ao.quantization to torchao.quantization.pt2e. torchao.quantization.pt2e contains an exact copy of the PT2E code in torch.ao.quantization. The torchao pin in ExecuTorch has already been bumped to pick up these changes.

Pull Request resolved: #10294
Reviewed By: SS-JIA
Differential Revision: D74694311
Pulled By: metascroy
1 parent 9aaea31 commit b252af9
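
The change is mechanical: the same symbols are imported from torchao's copy of the PT2E code instead of torch.ao. A minimal before/after sketch, using import paths that appear in the diffs below:

# Before (torch.ao paths, now linted against):
# from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
# from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

# After (torchao paths):
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer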

File tree: 81 files changed, +324 / -267 lines changed


.lintrunner.toml

Lines changed: 28 additions & 0 deletions
@@ -378,3 +378,31 @@ command = [
     '--',
     '@{{PATHSFILE}}',
 ]
+
+[[linter]]
+code = "TORCH_AO_IMPORT"
+include_patterns = ["**/*.py"]
+exclude_patterns = [
+    "third-party/**",
+]
+
+command = [
+    "python3",
+    "-m",
+    "lintrunner_adapters",
+    "run",
+    "grep_linter",
+    "--pattern=\\bfrom torch\\.ao\\.quantization\\.(?:quantizer|observer|quantize_pt2e|pt2e)(?:\\.[A-Za-z0-9_]+)*\\b",
+    "--linter-name=TorchAOImport",
+    "--error-name=Prohibited torch.ao.quantization import",
+    """--error-description=\
+    Imports from torch.ao.quantization are not allowed. \
+    Please import from torchao.quantization.pt2e instead.\n \
+    * torchao.quantization.pt2e (includes all the utils, including observers, fake quants etc.) \n \
+    * torchao.quantization.pt2e.quantizer (quantizer related objects and utils) \n \
+    * torchao.quantization.pt2e.quantize_pt2e (prepare_pt2e, prepare_qat_pt2e, convert_pt2e) \n\n \
+    If you need something from torch.ao.quantization, you can add your file to an exclude_patterns for TORCH_AO_IMPORT in .lintrunner.toml. \
+    """,
+    "--",
+    "@{{PATHSFILE}}",
+]
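
Note that the grep pattern above only fires when a submodule of torch.ao.quantization (quantizer, observer, quantize_pt2e, or pt2e) is named in the import. A few illustrative lines, not taken from this commit, showing what it does and does not flag:

from torch.ao.quantization.quantizer import Quantizer          # flagged
from torch.ao.quantization.observer import MinMaxObserver      # flagged
from torch.ao.quantization.quantize_pt2e import prepare_pt2e   # flagged
from torch.ao.quantization import ObserverOrFakeQuantize       # not matched by this pattern
from torchao.quantization.pt2e.quantizer import Quantizer      # the recommended replacement

Files that genuinely still need torch.ao.quantization can be listed under exclude_patterns for TORCH_AO_IMPORT, as the error description says.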

.mypy.ini

Lines changed: 3 additions & 0 deletions
@@ -97,3 +97,6 @@ ignore_missing_imports = True
 
 [mypy-zstd]
 ignore_missing_imports = True
+
+[mypy-torchao.*]
+follow_untyped_imports = True

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 2 additions & 2 deletions
@@ -15,12 +15,12 @@
 )
 
 from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
-from torch.ao.quantization.quantize_pt2e import (
+from torch.export import export_for_training
+from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
     prepare_qat_pt2e,
 )
-from torch.export import export_for_training
 
 
 class TestCoreMLQuantizer:
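
For context, the quantizer tests touched here all follow the standard PT2E flow; only the import locations change. A rough sketch of that flow under the new paths (the quantizer argument is backend-specific, e.g. the CoreMLQuantizer above, and its construction is elided):

import torch
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_with(model: torch.nn.Module, example_inputs: tuple, quantizer):
    # Export to an FX graph, let the quantizer annotate it, calibrate,
    # then fold observers into quantize/dequantize ops.
    exported = export_for_training(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration pass
    return convert_pt2e(prepared)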

backends/arm/quantizer/arm_quantizer.py

Lines changed: 11 additions & 12 deletions
@@ -30,25 +30,24 @@
     is_vgf,
 )  # usort: skip
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from torch.ao.quantization.fake_quantize import (
+from torch.fx import GraphModule, Node
+from torchao.quantization.pt2e import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
-)
-from torch.ao.quantization.observer import (
     HistogramObserver,
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     MovingAveragePerChannelMinMaxObserver,
+    ObserverOrFakeQuantizeConstructor,
     PerChannelMinMaxObserver,
     PlaceholderObserver,
 )
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
-from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    QuantizationSpec,
+    Quantizer,
 )
-from torch.fx import GraphModule, Node
 
 __all__ = [
     "TOSAQuantizer",
@@ -97,7 +96,7 @@ def get_symmetric_quantization_config(
     weight_qscheme = (
         torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
     )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
         MinMaxObserver
     )
     if is_qat:
@@ -337,14 +336,14 @@ def _annotate_io(
         if is_annotated(node):
             continue
         if node.op == "placeholder" and len(node.users) > 0:
-            _annotate_output_qspec(
+            annotate_output_qspec(
                 node,
                 quantization_config.get_output_act_qspec(),
             )
             mark_node_as_annotated(node)
         if node.op == "output":
             parent = node.all_input_nodes[0]
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 node, parent, quantization_config.get_input_act_qspec()
             )
             mark_node_as_annotated(node)
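
Besides the module move, this file shows the few helpers that lose their private prefix when imported from torchao; the same renames recur in the annotator and MediaTek diffs below:

# torch.ao.quantization.qconfig._ObserverOrFakeQuantizeConstructor
#     -> torchao.quantization.pt2e.ObserverOrFakeQuantizeConstructor
# torch.ao.quantization.quantizer.utils._annotate_input_qspec_map
#     -> torchao.quantization.pt2e.quantizer.annotate_input_qspec_map
# torch.ao.quantization.quantizer.utils._annotate_output_qspec
#     -> torchao.quantization.pt2e.quantizer.annotate_output_qspec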

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 2 additions & 2 deletions
@@ -15,10 +15,10 @@
 
 import torch
 from torch._subclasses import FakeTensor
-
-from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import GraphModule, Node
 
+from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
+
 
 def is_annotated(node: Node) -> bool:
     """Given a node return whether the node is annotated."""

backends/arm/quantizer/quantization_annotator.py

Lines changed: 8 additions & 7 deletions
@@ -12,12 +12,13 @@
 import torch.fx
 from executorch.backends.arm.quantizer import QuantizationConfig
 from executorch.backends.arm.tosa_utils import get_node_debug_info
-from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
 from torch.fx import Node
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)
 
 from .arm_quantizer_utils import (
     is_annotated,
@@ -118,7 +119,7 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
         strict=True,
     ):
         assert isinstance(n_arg, Node)
-        _annotate_input_qspec_map(node, n_arg, qspec)
+        annotate_input_qspec_map(node, n_arg, qspec)
         if quant_property.mark_annotated:
             mark_node_as_annotated(n_arg)  # type: ignore[attr-defined]
 
@@ -129,7 +130,7 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
     assert not quant_property.optional
     assert quant_property.index == 0, "Only one output annotation supported currently"
 
-    _annotate_output_qspec(node, quant_property.qspec)
+    annotate_output_qspec(node, quant_property.qspec)
 
 
 def _match_pattern(

backends/arm/quantizer/quantization_config.py

Lines changed: 2 additions & 2 deletions
@@ -9,9 +9,9 @@
 from dataclasses import dataclass
 
 import torch
-from torch.ao.quantization import ObserverOrFakeQuantize
+from torchao.quantization.pt2e import ObserverOrFakeQuantize
 
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
     QuantizationSpec,

backends/arm/test/ops/test_add.py

Lines changed: 2 additions & 2 deletions
@@ -19,8 +19,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e.observer import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
 aten_op = "torch.ops.aten.add.Tensor"
 exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"

backends/arm/test/ops/test_sigmoid_16bit.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e.observer import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
 
 def _get_16_bit_quant_config():

backends/arm/test/ops/test_sigmoid_32bit.py

Lines changed: 2 additions & 2 deletions
@@ -14,8 +14,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e.observer import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
 
 def _get_16_bit_quant_config():

backends/cadence/aot/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -37,9 +37,9 @@
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge_with_preserved_ops
 from torch._inductor.decomposition import remove_decompositions
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 from torch.export.exported_program import ExportedProgram
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 from .passes import get_cadence_passes
 
backends/cadence/aot/quantizer/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ python_library(
     deps = [
         ":utils",
         "//caffe2:torch",
+        "//pytorch/ao:torchao",
     ],
 )
 
backends/cadence/aot/quantizer/patterns.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 
 from torch import fx
 from torch._ops import OpOverload
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     SharedQuantizationSpec,
 )

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 3 additions & 3 deletions
@@ -38,9 +38,9 @@
 
 from torch import fx
 
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
-from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
+from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer
+from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer
 
 
 act_qspec_asym8s = QuantizationSpec(
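
ComposableQuantizer keeps its composition API under the new path; a minimal sketch, assuming you already have concrete Quantizer instances to combine:

from torchao.quantization.pt2e.quantizer import Quantizer
from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer

def compose(quantizers: list[Quantizer]) -> ComposableQuantizer:
    # Each quantizer's annotations are applied in list order.
    return ComposableQuantizer(quantizers)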

backends/cadence/aot/quantizer/utils.py

Lines changed: 1 addition & 1 deletion
@@ -14,13 +14,13 @@
 import torch
 from torch import fx
 from torch._ops import OpOverload
-from torch.ao.quantization import ObserverOrFakeQuantize
 
 from torch.fx import GraphModule
 from torch.fx.passes.utils.source_matcher_utils import (
     check_subgraphs_connected,
     SourcePartition,
 )
+from torchao.quantization.pt2e import ObserverOrFakeQuantize
 
 
 def quantize_tensor_multiplier(

backends/cortex_m/test/test_replace_quant_nodes.py

Lines changed: 5 additions & 5 deletions
@@ -16,15 +16,15 @@
     ReplaceQuantNodesPass,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
-from torch.ao.quantization.quantizer.quantizer import (
+from torch.export import export, export_for_training
+from torch.fx import GraphModule
+from torchao.quantization.pt2e.observer import HistogramObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     QuantizationSpec,
     Quantizer,
 )
-from torch.export import export, export_for_training
-from torch.fx import GraphModule
 
 
 @dataclass(eq=True, frozen=True)

backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dim_order_utils import get_dim_order
 from executorch.exir.pass_base import ExportPass, PassResult
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torchao.quantization.pt2e import find_sequential_partitions
 
 
 class PermuteMemoryFormatsPass(ExportPass):
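
find_sequential_partitions is now re-exported at the top of torchao.quantization.pt2e rather than from a graph_utils submodule. A usage sketch, assuming an existing FX GraphModule and an illustrative Conv2d followed by ReLU pattern:

import torch
from torch.fx import GraphModule
from torchao.quantization.pt2e import find_sequential_partitions

def conv_relu_partitions(gm: GraphModule):
    # Returns groups of source partitions that occur back-to-back in the
    # graph, here Conv2d immediately followed by ReLU.
    return find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU])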

backends/example/example_operators/utils.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
+from torchao.quantization.pt2e.quantizer.quantizer import QuantizationAnnotation
 
 
 def _nodes_are_annotated(node_list):

backends/example/example_partitioner.py

Lines changed: 1 addition & 1 deletion
@@ -19,9 +19,9 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.export import ExportedProgram
 from torch.fx.passes.operator_support import OperatorSupportBase
+from torchao.quantization.pt2e import find_sequential_partitions
 
 
 @final

backends/example/example_quantizer.py

Lines changed: 6 additions & 3 deletions
@@ -11,9 +11,12 @@
 from executorch.backends.example.example_operators.ops import module_to_annotator
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import OperatorConfig
 from torch import fx
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
-from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
+from torchao.quantization.pt2e import (
+    find_sequential_partitions,
+    HistogramObserver,
+    MinMaxObserver,
+)
+from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer
 
 
 def get_uint8_tensor_spec(observer_or_fake_quant_ctr):

backends/example/test_example_delegate.py

Lines changed: 2 additions & 2 deletions
@@ -17,10 +17,10 @@
     DuplicateDequantNodePass,
 )
 from executorch.exir.delegate import executorch_call_delegate
-
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.export import export
 
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
 from torchvision.models.quantization import mobilenet_v2
 
 
backends/mediatek/quantizer/annotator.py

Lines changed: 9 additions & 9 deletions
@@ -10,18 +10,18 @@
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensor
 
-from torch.ao.quantization.quantizer import QuantizationAnnotation
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-
 from torch.export import export_for_training
 from torch.fx import Graph, Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
     SubgraphMatcherWithNameNodeMap,
 )
 
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec as _annotate_output_qspec,
+    QuantizationAnnotation,
+)
+
 from .qconfig import QuantizationConfig
 
 
@@ -108,7 +108,7 @@ def _annotate_fused_activation_pattern(
         torch.ops.aten.linear.default,
     ]:
         weight_node = producer_node.args[1]
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             producer_node,
             weight_node,
             quant_config.weight,
@@ -201,7 +201,7 @@ def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
         return
 
     weight_node = node.args[1]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quant_config.weight,
@@ -260,5 +260,5 @@ def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
         return
 
     wgt_node = node.args[0]
-    _annotate_input_qspec_map(node, wgt_node, quant_config.activation)
+    annotate_input_qspec_map(node, wgt_node, quant_config.activation)
     _mark_as_annotated([node])
