From aa56e8ceb97d3bd0261f724bd668085cd0500af7 Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Thu, 1 Aug 2024 14:20:02 -0700
Subject: [PATCH] Split quantize_pt2 to allow calling the same APIs in testing and regular flows (#4505)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4505

Splitting `quantize_pt2` into two steps: `convert_pt2` and `fuse_pt2`.

Convert will return the converted model after `convert_pt2e`, which allows
getting reference outputs for testing. Fuse will return the final fused graph.
Those calls should always use the same quantizer. Note that we will probably
split the convert step again to allow calibration in a follow-up diff.

`quantize_pt2` is still the one-liner API, for anything that doesn't require
converted reference outputs (so mostly for e2e testing).

The main benefit is that we can use the same API everywhere now, and things
like decomposing SDPA and any other ATen IR passes that need to run before
quantization can be done in one location (in `convert_pt2`).

Reviewed By: dulinriley

Differential Revision: D60544102

fbshipit-source-id: 7866d26c6ed05cb8a8bf02eb7920a7adbac5f03a
---
 backends/cadence/aot/compiler.py | 65 ++++++++++++++++++++++++++++----
 1 file changed, 58 insertions(+), 7 deletions(-)

diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 39511ae917..509e254b55 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -7,6 +7,7 @@
 # pyre-strict
 
 import logging
+from typing import Optional
 
 import torch
 
@@ -36,16 +37,24 @@
 from torch.export.exported_program import ExportedProgram
 
 
-def quantize_pt2(
+# Note: this is not meant as a primary API since it can create inconsistencies
+# if the quantizer here is different from the quantizer later used to fuse. It
+# is however useful for unit tests to separate the converted model from the
+# fused model, to be able to get reference numerics.
+# If this does not apply, please use quantize_pt2 instead.
+def convert_pt2(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer,
 ) -> torch.fx.GraphModule:
     """
-    Instantiate the CadenceQuantizer (PTQ), prepare, convert and fuse the model.
-    Returns a GraphModule with the quantized model.
+    Prepare and convert a model using the given quantizer.
+    The quantizer must be supplied, and must be the same one later used to
+    fuse the model, if applicable. If you do not expect that behavior,
+    please use quantize_pt2 instead, which will instantiate a default
+    quantizer for you if needed.
+    Returns a GraphModule with the converted model.
     """
-    # Quantizer
-    quantizer = CadenceQuantizer()
 
     # Export with dynamo
     model_exp = capture_pre_autograd_graph(model, inputs)
@@ -62,12 +71,54 @@ def quantize_pt2(
     # Convert
     converted_model = convert_pt2e(prepared_model)
 
+    return converted_model
+
+
+# Note: this is not meant as a primary API since it can create inconsistencies
+# if the quantizer here is different from the quantizer used to convert. It is
+# however useful for unit tests to separate the converted model from the fused
+# model, to be able to get reference numerics.
+# If this does not apply, please use quantize_pt2 instead.
+def fuse_pt2(
+    converted_graph_module: torch.fx.GraphModule,
+    quantizer: CadenceQuantizer,
+) -> torch.fx.GraphModule:
+    """
+    Fuse a converted graph module using the given quantizer.
+    The quantizer must be the same as the one used to convert the model.
+    If you do not expect that behavior, please use quantize_pt2 instead,
+    which will instantiate a default quantizer for you if needed.
+    Returns a GraphModule with the fused model.
+    """
     # Get patterns and apply fusion of dq -> op -> q to qop
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
-    QuantFusion(patterns)(converted_model)
+    QuantFusion(patterns)(converted_graph_module)
 
-    return converted_model
+    return converted_graph_module
+
+
+# Note: this is the one-liner API to quantize and fuse a model.
+def quantize_pt2(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: Optional[CadenceQuantizer] = None,
+) -> torch.fx.GraphModule:
+    """
+    Prepare, convert and fuse the model using the given quantizer.
+    Returns a GraphModule with the quantized model.
+    """
+    # Quantizer
+    if not quantizer:
+        quantizer = CadenceQuantizer()
+
+    # Get converted graph module
+    converted_gm = convert_pt2(model, inputs, quantizer)
+
+    # Get fused model
+    fused_gm = fuse_pt2(converted_gm, quantizer)
+
+    return fused_gm
 
 
 # Export the model and lower it to an ExportedProgram (in aten IR)
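
Usage sketch (illustrative only, not part of the patch): the snippet below
shows the two-step flow next to the one-liner. The example model, inputs, and
the `CadenceQuantizer` import path are assumptions for demonstration; only
`convert_pt2`, `fuse_pt2` and `quantize_pt2` come from this diff.

    import torch

    # Assumed import locations: compiler.py is the file touched by this diff;
    # the CadenceQuantizer path is a guess based on the Cadence backend layout.
    from executorch.backends.cadence.aot.compiler import (
        convert_pt2,
        fuse_pt2,
        quantize_pt2,
    )
    from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

    model = torch.nn.Linear(16, 16)  # hypothetical example model
    inputs = (torch.randn(1, 16),)   # hypothetical example inputs

    # Two-step flow (testing): both calls must share the same quantizer.
    quantizer = CadenceQuantizer()
    converted_gm = convert_pt2(model, inputs, quantizer)
    ref_outputs = converted_gm(*inputs)  # reference numerics before fusion
    fused_gm = fuse_pt2(converted_gm, quantizer)

    # One-liner flow (regular): a default CadenceQuantizer is created internally.
    fused_gm2 = quantize_pt2(model, inputs)

Since `fuse_pt2` runs `QuantFusion` on the converted module it receives and
returns that same module, reference outputs should be captured before fusing.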