From 48da61aa34b73ea8e2ee815a6a79eea817e361db Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Tue, 23 Jul 2024 16:33:48 -0700
Subject: [PATCH] Enable aten.relu_.default in the CadenceQuantizer (#4344)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4344

As titled. Some models use `torch.ops.aten.relu_.default` (the in-place ReLU) instead of `torch.ops.aten.relu.default`. Enable that op in the quantizer as well.

Reviewed By: zonglinpengmeta

Differential Revision: D60071019

fbshipit-source-id: efad4818f17ca1aef7445d4f8d651bd9f1c46444
---
 backends/cadence/aot/quantizer/fusion_pass.py |  8 ++++++--
 backends/cadence/aot/quantizer/patterns.py    | 18 ++++++++++++++++--
 backends/cadence/aot/quantizer/quantizer.py   |  6 ++++--
 3 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 5f0d847e84..4c43172a92 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -17,7 +17,8 @@
     LayerNormPattern,
     LinearPattern,
     MatmulPattern,
-    ReluPattern,
+    ReluPattern0,
+    ReluPattern1,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     create_zero_bias_int32,
@@ -36,6 +37,9 @@
 # pyre-ignore[33]: `_ModelInputsType` cannot alias to `Any`.
 ArgsType = Any
 
+# Use this part for patterns with multiple aten ops
+ReluPatterns = (ReluPattern0, ReluPattern1)
+
 
 # Helper function to get the args and kwargs for the linear replacement op
 def get_args_and_kwargs_linear(
@@ -411,7 +415,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                             bias_inputs,
                             quant_node,
                         )
-                    elif isinstance(pattern, ReluPattern):
+                    elif isinstance(pattern, ReluPatterns):
                         args, kwargs = get_args_and_kwargs_relu(
                             graph_module,
                             inputs_inputs,
diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py
index 943b9e473a..7043bae571 100644
--- a/backends/cadence/aot/quantizer/patterns.py
+++ b/backends/cadence/aot/quantizer/patterns.py
@@ -288,9 +288,11 @@ def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_matmul.default
 
 
-class ReluPattern(QuantizationPattern):
+# This is a base class for ReLU, since it can be used with two different aten ops
+class ReluBasePattern(QuantizationPattern):
+    @abstractmethod
     def partition_types(self) -> List[OpOverload]:
-        return [torch.ops.aten.relu.default]
+        pass
 
     def get_anchors(
         self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
@@ -308,3 +310,15 @@ def get_anchors(
 
     def replacement_op(self) -> OpOverload:
         return torch.ops.cadence.quantized_relu.default
+
+
+# Regular relu op
+class ReluPattern0(ReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu.default]
+
+
+# Alternate relu op
+class ReluPattern1(ReluBasePattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.relu_.default]
diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
index 5a2c101512..4cd3c6bfb4 100644
--- a/backends/cadence/aot/quantizer/quantizer.py
+++ b/backends/cadence/aot/quantizer/quantizer.py
@@ -18,7 +18,8 @@
     LinearPattern,
     MatmulPattern,
     QuantizationPattern,
-    ReluPattern,
+    ReluPattern0,
+    ReluPattern1,
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     find_sequential_partitions_aten,
@@ -159,6 +160,7 @@ def __init__(self) -> None:
             CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
             CadenceAtenQuantizer(LinearPattern(), static_qconfig),
             CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
-            CadenceAtenQuantizer(ReluPattern(), static_qconfig),
+            CadenceAtenQuantizer(ReluPattern0(), static_qconfig),
+            CadenceAtenQuantizer(ReluPattern1(), static_qconfig),
         ]
     )
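
For reviewers, a small illustrative sketch (not part of the patch): the module below is the kind of graph this change targets, since its in-place ReLU lowers to `torch.ops.aten.relu_.default`, which `ReluPattern1` now matches. The commented quantization flow is an assumption based on the standard PT2E API (`prepare_pt2e`/`convert_pt2e`) and the `CadenceQuantizer` entry point; the exact export helper may differ across ExecuTorch versions.

import torch


class TinyReluModel(torch.nn.Module):
    """Toy module whose ReLU traces to aten.relu_.default (the in-place op)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.relu_ mutates its input, so export captures aten.relu_.default
        # rather than aten.relu.default; with this patch the quantizer fuses it
        # into cadence.quantized_relu just like the out-of-place variant.
        return torch.relu_(self.linear(x))


# Hypothetical end-to-end flow (names assumed; verify against your ExecuTorch
# checkout before relying on them):
#
#   from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
#   from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#   model = TinyReluModel().eval()
#   example_inputs = (torch.randn(1, 16),)
#   exported = torch.export.export_for_training(model, example_inputs).module()
#   prepared = prepare_pt2e(exported, CadenceQuantizer())
#   prepared(*example_inputs)  # calibration
#   quantized = convert_pt2e(prepared)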