Move the quantization API to OSS (#3997)
Summary:
Pull Request resolved: #3997

The quantization API (`quantize_pt2`) belongs in OSS with the quantizer. Move it there.
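For context, a minimal usage sketch of the moved API, mirroring the updated `export_model` in `backends/cadence/aot/export_example.py` below. The model definition and example inputs are illustrative placeholders; `quantize_pt2` and `export_to_cadence` come from this diff.

```python
import torch

from executorch.backends.cadence.aot.compiler import export_to_cadence, quantize_pt2


# Hypothetical toy model used only for illustration.
class MyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x)


model = MyModel()
example_inputs = (torch.randn(1, 16),)

# PTQ: export, prepare, calibrate, convert, and fuse in one call.
quantized_model = quantize_pt2(model, example_inputs)

# Lower to edge IR with Cadence-specific passes, then to an ExecuTorch program.
cadence_prog_manager = export_to_cadence(quantized_model, example_inputs)
exec_prog = cadence_prog_manager.to_executorch()
```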

Reviewed By: dulinriley

Differential Revision: D58680953

fbshipit-source-id: 3dca4264c2098a41f96a8bb1d6526cb7022391e4
mcremon-meta authored and facebook-github-bot committed Jun 20, 2024
1 parent 802a0b2 commit d11e8e1
Showing 3 changed files with 51 additions and 27 deletions.
2 changes: 2 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -31,6 +31,8 @@ python_library(
":passes",
":utils",
"//caffe2:torch",
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/exir:lib",
],
)
44 changes: 40 additions & 4 deletions backends/cadence/aot/compiler.py
@@ -4,8 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Any, Tuple

import torch

@@ -16,18 +17,53 @@
ReplaceScalarTensorWithFullPass,
ReplaceSqueezeAndUnsqueezeWithViewPass,
)
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.utils import model_is_quantized
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.pt2e.export_utils import model_is_exported
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export
from torch.export.exported_program import ExportedProgram


def quantize_pt2(
model: torch.nn.Module,
inputs: tuple[object, ...],
) -> torch.fx.GraphModule:
"""
Instantiate the CadenceQuantizer (PTQ), prepare, convert and fuse the model.
Returns a GraphModule with the quantized model.
"""
# Quantizer
quantizer = CadenceQuantizer()

# Export with dynamo
model_exp = capture_pre_autograd_graph(model, inputs)

# Prepare
prepared_model = prepare_pt2e(model_exp, quantizer)

# Calibrate
prepared_model(*inputs)

# Convert
converted_model = convert_pt2e(prepared_model)

# Get patterns and apply fusion of dq -> op -> q to qop
# pyre-fixme[16]: Pyre doesn't get that CadenceQuantizer has a patterns attribute
patterns = [q.pattern for q in quantizer.quantizers]
QuantFusion(patterns)(converted_model)

return converted_model


# Export the model and lower it to an ExportedProgram (in aten IR)
def export_program(
model: torch.nn.Module,
inputs: Tuple[Any, ...],
inputs: tuple[object, ...],
) -> ExportedProgram:
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

@@ -57,7 +93,7 @@ def export_program(
# Export the model and lower it to an EdgeProgramManager (in edge IR).
def export_to_edge(
model: torch.nn.Module,
inputs: Tuple[Any, ...],
inputs: tuple[object, ...],
dump_graphs: bool = False,
) -> EdgeProgramManager:
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
@@ -89,7 +125,7 @@ def export_to_edge(
# apply passes specific to Cadence DSP execution.
def export_to_cadence(
model: torch.nn.Module,
inputs: Tuple[Any, ...],
inputs: tuple[object, ...],
dump_graphs: bool = False,
) -> EdgeProgramManager:
edge_program_manager = export_to_edge(model, inputs)
32 changes: 9 additions & 23 deletions backends/cadence/aot/export_example.py
@@ -13,13 +13,13 @@
import os
from typing import Any, Tuple

from executorch.backends.cadence.aot.compiler import export_to_cadence, export_to_edge
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.compiler import (
export_to_cadence,
export_to_edge,
quantize_pt2,
)
from executorch.exir import ExecutorchProgramManager
from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from .utils import print_ops_info

@@ -49,28 +49,14 @@ def export_model(
example_inputs: Tuple[Any, ...],
file_name: str = "CadenceDemoModel",
):
# Quantizer
quantizer = CadenceQuantizer()

# Export
model_exp = capture_pre_autograd_graph(model, example_inputs)

# Prepare
prepared_model = prepare_pt2e(model_exp, quantizer)
prepared_model(*example_inputs)

# Convert
converted_model = convert_pt2e(prepared_model)

# pyre-fixme[16]: Pyre doesn't get that CadenceQuantizer has a patterns attribute
patterns = [q.pattern for q in quantizer.quantizers]
QuantFusion(patterns)(converted_model)
# Quantize the model
quantized_model = quantize_pt2(model, example_inputs)

# Get edge program
edge_prog_manager = export_to_edge(converted_model, example_inputs)
edge_prog_manager = export_to_edge(quantized_model, example_inputs)

# Get edge program after Cadence specific passes
cadence_prog_manager = export_to_cadence(converted_model, example_inputs)
cadence_prog_manager = export_to_cadence(quantized_model, example_inputs)

exec_prog = cadence_prog_manager.to_executorch()

