diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 8e674acc4f..bd4ec660a6 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -34,6 +34,7 @@ python_library( "//caffe2:torch", "//executorch/backends/cadence/aot/quantizer:fusion_pass", "//executorch/backends/cadence/aot/quantizer:quantizer", + "//executorch/backends/transforms:decompose_sdpa", "//executorch/exir:lib", ], ) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index ff893f4e45..76b900add9 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -23,6 +23,9 @@ CadenceQuantizer, ) from executorch.backends.cadence.aot.utils import model_is_quantized +from executorch.backends.transforms.decompose_sdpa import ( + DecomposeScaledDotProductAttention, +) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge from pyre_extensions import assert_is_instance from torch._export import capture_pre_autograd_graph @@ -47,6 +50,9 @@ def quantize_pt2( # Export with dynamo model_exp = capture_pre_autograd_graph(model, inputs) + # Decompose SDPA + DecomposeScaledDotProductAttention(False)(model_exp) + # Prepare prepared_model = prepare_pt2e(model_exp, quantizer) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index cc2ce008a7..91e31b62e4 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -7,9 +7,6 @@ from typing import Callable, Dict, Optional, Sequence, Set import torch -from executorch.backends.qualcomm.passes.decompose_scaled_dot_product_attention import ( - DecomposeScaledDotProductAttention, -) from executorch.backends.qualcomm.passes.decompose_silu import DecomposeSilu from executorch.backends.qualcomm.passes.recompose_pixel_unshuffle import ( RecomposePixelUnshuffle, @@ -17,6 +14,9 @@ from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange from executorch.backends.qualcomm.passes.remove_redundancy import RemoveRedundancy from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer +from executorch.backends.transforms.decompose_sdpa import ( + DecomposeScaledDotProductAttention, +) from torch._ops import OpOverload from torch.ao.quantization.quantizer import Quantizer diff --git a/backends/transforms/TARGETS b/backends/transforms/TARGETS index 50fbb76e77..88de8a84a6 100644 --- a/backends/transforms/TARGETS +++ b/backends/transforms/TARGETS @@ -29,6 +29,19 @@ runtime.python_library( ], ) +runtime.python_library( + name = "decompose_sdpa", + srcs = ["decompose_sdpa.py"], + visibility = [ + "//executorch/backends/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + ], +) + runtime.python_library( name = "fuse_batch_norm_with_conv", srcs = ["fuse_batch_norm_with_conv.py"], diff --git a/backends/qualcomm/passes/decompose_scaled_dot_product_attention.py b/backends/transforms/decompose_sdpa.py similarity index 86% rename from backends/qualcomm/passes/decompose_scaled_dot_product_attention.py rename to backends/transforms/decompose_sdpa.py index 9fd3c8fb46..6dbbf564f5 100644 --- a/backends/qualcomm/passes/decompose_scaled_dot_product_attention.py +++ b/backends/transforms/decompose_sdpa.py @@ -1,8 +1,11 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +# pyre-strict + import torch from executorch.exir.pass_base import ExportPass, PassResult from torch._decomp import get_decompositions @@ -14,7 +17,15 @@ class DecomposeScaledDotProductAttention(ExportPass): Decompose from scaled_dot_product_attention to multiple nodes. """ - def call(self, graph_module: torch.fx.GraphModule): + def __init__(self, allow_non_fake_inputs: bool = True) -> None: + super().__init__() + # With allow_non_fake_inputs=False, we don't get _unsafe_view ops + # in the graph, we allow disabling it here. + self._allow_non_fake_inputs = allow_non_fake_inputs + + def call( + self, graph_module: torch.fx.GraphModule, allow_non_fake_inputs: bool = True + ) -> PassResult: graph = graph_module.graph for node in graph.nodes: if node.target == torch.ops.aten.scaled_dot_product_attention.default: @@ -29,7 +40,7 @@ def call(self, graph_module: torch.fx.GraphModule): ] ), tracing_mode="fake", - _allow_non_fake_inputs=True, + _allow_non_fake_inputs=allow_non_fake_inputs, )(*input_tensors) with graph.inserting_before(node): name_to_input_tensor_map = {}