From 238850b4dd7da6729846061928d4c6a316dddab3 Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Wed, 10 Jul 2024 10:43:36 -0700
Subject: [PATCH] Add support for quantized bmm (#4047)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4047

The current quantizer only captures "fake" bmm from matmuls with specific shapes. Add support for `torch.bmm` as well.

Use a decomposition for SDPA to make sure LLaMa bmms get quantized.

Differential Revision: D58959269

Reviewed By: zonglinpengmeta, hsharma35
---
 backends/cadence/aot/TARGETS                  |  1 +
 backends/cadence/aot/compiler.py              | 12 +++++++++---
 backends/cadence/aot/quantizer/fusion_pass.py |  5 ++++-
 backends/cadence/aot/quantizer/patterns.py    | 20 ++++++++++++++++++++
 backends/cadence/aot/quantizer/quantizer.py   |  2 ++
 5 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 8f5235b3d8a..8e674acc4fb 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -28,6 +28,7 @@ python_library(
         "compiler.py",
     ],
     deps = [
+        "fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
         ":passes",
         ":utils",
         "//caffe2:torch",
diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index c51cad98be2..ff893f4e45c 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -18,9 +18,13 @@
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
-from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
+from executorch.backends.cadence.aot.quantizer.quantizer import (
+    CadenceGenericQuantizer,
+    CadenceQuantizer,
+)
 from executorch.backends.cadence.aot.utils import model_is_quantized
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
+from pyre_extensions import assert_is_instance
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -53,8 +57,10 @@ def quantize_pt2(
     converted_model = convert_pt2e(prepared_model)
 
     # Get patterns and apply fusion of dq -> op -> q to qop
-    # pyre-fixme[16]: Pyre doesn't get that CadenceQuantizer has a patterns attribute
-    patterns = [q.pattern for q in quantizer.quantizers]
+    patterns = [
+        assert_is_instance(q, CadenceGenericQuantizer).pattern
+        for q in quantizer.quantizers
+    ]
     QuantFusion(patterns)(converted_model)
 
     return converted_model
diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 0a1927e7252..803379b3bdc 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -11,6 +11,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormFunctionalPattern,
@@ -396,7 +397,9 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     other_inputs,
                     quant_node,
                 )
-            elif isinstance(pattern, MatmulPattern):
+            elif isinstance(pattern, BmmPattern) or isinstance(
+                pattern, MatmulPattern
+            ):
                 args, kwargs = get_args_and_kwargs_matmul(
                     inputs_inputs,
                     dequants_inputs,
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index e403ac9d2ac..381bbbf6f13 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -95,6 +95,26 @@ def replacement_op(self):
         return torch.ops.cadence.quantized_linear
 
 
+class BmmPattern(QuantizationPattern):
+    def partition_types(self) -> List[Callable[..., torch.Tensor]]:
+        return [torch.bmm]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        bmm_node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(bmm_node, 0), (bmm_node, 1)],
+            weights=[],
+            biases=[],
+            output=[(bmm_node,)],
+        )
+
+    def replacement_op(self):
+        return torch.ops.cadence.quantized_matmul.default
+
+
 class Conv1dPattern(QuantizationPattern):
     def partition_types(self) -> List[Type[torch.nn.Module]]:
         return [torch.nn.Conv1d]
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index 79e6fb28149..f5275629625 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -9,6 +9,7 @@
 import torch
 from executorch.backends.cadence.aot.quantizer.patterns import (
     AddmmPattern,
+    BmmPattern,
     Conv1dPattern,
     Conv2dPattern,
     LayerNormFunctionalPattern,
@@ -133,6 +134,7 @@ def __init__(self):
         super().__init__(
             [
                 CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
+                CadenceGenericQuantizer(BmmPattern(), static_qconfig),
                 CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
                 CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
                 CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
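
Usage sketch (not part of the commit): a minimal example of the kind of graph the new `BmmPattern` is intended to capture, namely a module that calls `torch.bmm` directly, run through the Cadence quantization flow. `BmmModule`, the input shapes, and the `quantize_pt2(model, example_inputs)` call signature are illustrative assumptions here; see backends/cadence/aot/compiler.py for the exact API.

    import torch

    from executorch.backends.cadence.aot.compiler import quantize_pt2


    class BmmModule(torch.nn.Module):
        # Hypothetical module for illustration: a bare batched matmul.
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            # (B, M, K) @ (B, K, N) -> (B, M, N); with this change the
            # dq -> bmm -> q chain should be fused into
            # torch.ops.cadence.quantized_matmul by QuantFusion.
            return torch.bmm(x, y)


    # Assumed signature: quantize_pt2(model, example_inputs).
    example_inputs = (torch.randn(2, 4, 8), torch.randn(2, 8, 16))
    quantized_model = quantize_pt2(BmmModule(), example_inputs)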