From d7c069f495e24d3919cc27ae25ceed8b042e2eed Mon Sep 17 00:00:00 2001
From: mcremon-meta <134334895+mcremon-meta@users.noreply.github.com>
Date: Thu, 22 Aug 2024 16:29:15 -0700
Subject: [PATCH] Fix SDPA decomp problem

Differential Revision: D61639074

Pull Request resolved: https://github.com/pytorch/executorch/pull/4851
---
 backends/cadence/aot/compiler.py | 18 +++++++++++++-----
 backends/cadence/aot/passes.py   | 26 ++++++++++++++++++++++++++
 backends/cadence/aot/utils.py    |  8 ++++++++
 3 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 509e254b55..405f8b5db4 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -18,12 +18,13 @@
     ReplaceLogicalNotBooleanWhereWithWherePass,
     ReplacePT2DequantWithCadenceDequantPass,
     ReplacePT2QuantWithCadenceQuantPass,
+    ReplaceSafeSoftmaxWithSoftmax,
     ReplaceScalarTensorWithFullPass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
-from executorch.backends.cadence.aot.utils import model_is_quantized
+from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
 )
@@ -57,13 +58,20 @@ def convert_pt2(
     """
 
     # Export with dynamo
-    model_exp = capture_pre_autograd_graph(model, inputs)
+    model_gm = capture_pre_autograd_graph(model, inputs)
 
-    # Decompose SDPA
-    DecomposeScaledDotProductAttention(False)(model_exp)
+    if model_gm_has_SDPA(model_gm):
+        # Decompose SDPA
+        DecomposeScaledDotProductAttention(False)(model_gm)
+
+        # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
+        # for details).
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
+        assert result is not None
+        model_gm = result.graph_module
 
     # Prepare
-    prepared_model = prepare_pt2e(model_exp, quantizer)
+    prepared_model = prepare_pt2e(model_gm, quantizer)
 
     # Calibrate
     prepared_model(*inputs)
diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py
index db419bfb5e..83ef43d151 100644
--- a/backends/cadence/aot/passes.py
+++ b/backends/cadence/aot/passes.py
@@ -266,3 +266,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         result = SpecPropPass()(graph_module)
         assert result is not None
         return result
+
+
+class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
+    """
+    Replace _safe_softmax with _softmax
+    """
+
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op != torch.ops.aten._safe_softmax.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Add False for the half_to_float argument of softmax
+        softmax_args = list(args) + [False]
+
+        return super().call_operator(
+            torch.ops.aten._softmax.default,
+            tuple(softmax_args),
+            kwargs,
+            meta,
+        )
diff --git a/backends/cadence/aot/utils.py b/backends/cadence/aot/utils.py
index f0c294260a..b710f7d4e5 100644
--- a/backends/cadence/aot/utils.py
+++ b/backends/cadence/aot/utils.py
@@ -177,3 +177,11 @@ def print_ops_info(
             tablefmt="outline",
         )
     )
+
+
+def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
+    for node in model_gm.graph.nodes:
+        if node.op == "call_function":
+            if node.target == torch.ops.aten.scaled_dot_product_attention.default:
+                return True
+    return False
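
A minimal sketch (not part of the patch) of the substitution the new
ReplaceSafeSoftmaxWithSoftmax pass performs, assuming a PyTorch build recent
enough to provide aten._safe_softmax: the pass maps _safe_softmax(x, dim) to
_softmax(x, dim, half_to_float=False).

    import torch

    x = torch.randn(2, 8)
    # The op emitted for softmax inside the SDPA decomposition
    safe = torch.ops.aten._safe_softmax.default(x, -1)
    # What the pass rewrites it to; False is the half_to_float argument
    plain = torch.ops.aten._softmax.default(x, -1, False)
    # Outside the fully-masked-row edge case the two ops agree numerically
    assert torch.allclose(safe, plain)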