Split quantize_pt2 to allow calling the same APIs in testing and regular flows (#4505)

Summary:
Pull Request resolved: #4505

Splitting `quantize_pt2` into two steps: `convert_pt2` and `fuse_pt2`. Convert returns the converted model after `convert_pt2e`, which allows getting reference outputs for testing. Fuse returns the final fused graph. Both calls should always use the same quantizer. Note that we will probably split the convert step again to allow calibration in a follow-up diff.

`quantize_pt2` is still the one-liner API, for anything that doesn't require converted reference outputs (so mostly for e2e testing).

The main benefit is that we can now use the same API everywhere, and things like decomposing SDPA and any other ATen IR passes that need to run before quantization can be done in one place (in `convert_pt2`).
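The convert/fuse split described above can be sketched self-contained. The function names and signatures mirror this diff, but the bodies here are stubs (the real `convert_pt2` exports and runs `prepare_pt2e`/`convert_pt2e`, and `fuse_pt2` applies `QuantFusion` to a `torch.fx.GraphModule`); the stubs only illustrate the control flow and the same-quantizer requirement.

```python
# Minimal stand-in sketch of the convert/fuse split. Names mirror the
# diff; the dict-based "model" and the quantizer-identity check are
# illustrative stubs, not the real implementation.
from typing import Any, Optional


class CadenceQuantizer:
    """Stub standing in for the real CadenceQuantizer."""


def convert_pt2(model: Any, inputs: tuple, quantizer: CadenceQuantizer) -> dict:
    # Real flow: export, prepare_pt2e, calibrate, convert_pt2e.
    # The converted model is what tests use for reference outputs.
    return {"model": model, "quantizer_id": id(quantizer), "fused": False}


def fuse_pt2(converted: dict, quantizer: CadenceQuantizer) -> dict:
    # Must be called with the same quantizer used in convert_pt2;
    # otherwise the fusion patterns would not match the conversion.
    assert converted["quantizer_id"] == id(quantizer), "use the same quantizer"
    converted["fused"] = True
    return converted


def quantize_pt2(
    model: Any,
    inputs: tuple,
    quantizer: Optional[CadenceQuantizer] = None,
) -> dict:
    # One-liner API: instantiate a default quantizer if none is given,
    # then run the same two steps with that single quantizer.
    if not quantizer:
        quantizer = CadenceQuantizer()
    converted = convert_pt2(model, inputs, quantizer)
    return fuse_pt2(converted, quantizer)
```

In a unit test, calling `convert_pt2` alone yields the converted model for reference numerics before `fuse_pt2` produces the final graph; `quantize_pt2` remains the end-to-end one-liner.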

Reviewed By: dulinriley

Differential Revision: D60544102

fbshipit-source-id: 7866d26c6ed05cb8a8bf02eb7920a7adbac5f03a
mcremon-meta authored and facebook-github-bot committed Aug 1, 2024
1 parent 301a017 commit aa56e8c
Showing 1 changed file with 58 additions and 7 deletions.
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -7,6 +7,7 @@
 # pyre-strict
 
 import logging
+from typing import Optional
 
 import torch
 
@@ -36,16 +37,24 @@
 from torch.export.exported_program import ExportedProgram
 
 
-def quantize_pt2(
+# Note: this is not meant as a primary API since it can create inconsistencies
+# if the quantizer here is different from the quantizer used to convert. It is
+# however useful for unit tests to separate the converted model from the fused
+# model, to be able to get reference numerics.
+# If this does not apply, please use quantize_and_fuse_pt2 instead.
+def convert_pt2(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
 ) -> torch.fx.GraphModule:
     """
-    Instantiate the CadenceQuantizer (PTQ), prepare, convert and fuse the model.
-    Returns a GraphModule with the quantized model.
+    Prepare and convert a model using the given quantizer.
+    The quantizer must be supplied and be the same as the one used to
+    fuse the model later, if applicable. If you do not expect that behavior,
+    please use quantize_and_fuse_pt2 instead, which will instantiate a
+    default quantizer for you if needed.
+    Returns a GraphModule with the converted model.
     """
-    # Quantizer
-    quantizer = CadenceQuantizer()
 
     # Export with dynamo
     model_exp = capture_pre_autograd_graph(model, inputs)
@@ -62,12 +71,54 @@ def quantize_pt2(
     # Convert
     converted_model = convert_pt2e(prepared_model)
 
+    return converted_model
+
+
+# Note: this is not meant as a primary API since it can create inconsistencies
+# if the quantizer here is different from the quantizer used to convert. It is
+# however useful for unit tests to separate the converted model from the fused
+# model, to be able to get reference numerics.
+# If this does not apply, please use quantize_and_fuse_pt2 instead.
+def fuse_pt2(
+    converted_graph_module: torch.fx.GraphModule,
+    quantizer: CadenceQuantizer,
+) -> torch.fx.GraphModule:
+    """
+    Fuse a converted graph module using the given quantizer.
+    The quantizer must be the same as the one used to convert the model.
+    If you do not expect that behavior, please use quantize_and_fuse_pt2 instead,
+    which will instantiate a default quantizer for you if needed.
+    Returns a GraphModule with the fused model.
+    """
     # Get patterns and apply fusion of dq -> op -> q to qop
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
-    QuantFusion(patterns)(converted_model)
+    QuantFusion(patterns)(converted_graph_module)
 
-    return converted_model
+    return converted_graph_module
+
+
+# Note: this is the one-liner API to quantize and fuse a model.
+def quantize_pt2(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: Optional[CadenceQuantizer] = None,
+) -> torch.fx.GraphModule:
+    """
+    Prepare, convert and fuse the model using the given quantizer.
+    Returns a GraphModule with the quantized model.
+    """
+    # Quantizer
+    if not quantizer:
+        quantizer = CadenceQuantizer()
+
+    # Get converted graph module
+    converted_gm = convert_pt2(model, inputs, quantizer)
+
+    # Get fused model
+    fused_gm = fuse_pt2(converted_gm, quantizer)
+
+    return fused_gm
 
 
 # Export the model and lower it to an ExportedProgram (in aten IR)
