Move SDPA decomp pass from Qualcomm's directory to be shareable and call it for Cadence backends (#4258)

Summary:
Pull Request resolved: #4258

Move the pass to backends/transforms so that other backends can call it, and invoke it from the Cadence side so that we can quantize the bmm ops inside SDPA (sketched below).
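
For context on why this enables quantizing bmm: tracing an SDPA call through make_fx with a decomposition table replaces the fused op with primitive matmul/softmax nodes that a PT2E quantizer can annotate. A standalone sketch (illustrative only; the op list and shapes here are assumptions, not this diff's exact code):

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def attn(q, k, v):
    return torch.ops.aten.scaled_dot_product_attention.default(q, k, v)

q, k, v = (torch.randn(1, 2, 8, 4) for _ in range(3))
decomposed = make_fx(
    attn,
    # Assumed decomposition target; the actual op list is elided in the
    # decompose_sdpa.py hunk below.
    decomposition_table=get_decompositions(
        [torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default]
    ),
    tracing_mode="fake",
    _allow_non_fake_inputs=True,
)(q, k, v)
decomposed.graph.print_tabular()  # matmul/bmm, mul, softmax, etc. instead of fused SDPA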

Reviewed By: cccclai

Differential Revision: D59600486
mcremon-meta authored and facebook-github-bot committed Jul 15, 2024
1 parent 4b45264 commit d08ef5c
Showing 5 changed files with 38 additions and 7 deletions.
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
@@ -34,6 +34,7 @@ python_library(
"//caffe2:torch",
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/backends/transforms:decompose_sdpa",
"//executorch/exir:lib",
],
)
6 changes: 6 additions & 0 deletions backends/cadence/aot/compiler.py
@@ -23,6 +23,9 @@
     CadenceQuantizer,
 )
 from executorch.backends.cadence.aot.utils import model_is_quantized
+from executorch.backends.transforms.decompose_sdpa import (
+    DecomposeScaledDotProductAttention,
+)
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
 from pyre_extensions import assert_is_instance
 from torch._export import capture_pre_autograd_graph
@@ -47,6 +50,9 @@ def quantize_pt2(
     # Export with dynamo
     model_exp = capture_pre_autograd_graph(model, inputs)

+    # Decompose SDPA
+    DecomposeScaledDotProductAttention(False)(model_exp)
+
     # Prepare
     prepared_model = prepare_pt2e(model_exp, quantizer)

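Taken together, the modified quantize_pt2 flow looks roughly like this (condensed sketch, not the verbatim file; calibration and the convert step are simplified):

from executorch.backends.transforms.decompose_sdpa import (
    DecomposeScaledDotProductAttention,
)
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_sketch(model, inputs, quantizer):
    # Export with dynamo
    model_exp = capture_pre_autograd_graph(model, inputs)
    # Decompose SDPA before prepare_pt2e so the bmm ops inside it are
    # visible to the quantizer
    DecomposeScaledDotProductAttention(False)(model_exp)
    prepared_model = prepare_pt2e(model_exp, quantizer)
    prepared_model(*inputs)  # one-shot calibration
    return convert_pt2e(prepared_model)
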
6 changes: 3 additions & 3 deletions backends/qualcomm/quantizer/quantizer.py
@@ -7,16 +7,16 @@
 from typing import Callable, Dict, Optional, Sequence, Set

 import torch
-from executorch.backends.qualcomm.passes.decompose_scaled_dot_product_attention import (
-    DecomposeScaledDotProductAttention,
-)
 from executorch.backends.qualcomm.passes.decompose_silu import DecomposeSilu
 from executorch.backends.qualcomm.passes.recompose_pixel_unshuffle import (
     RecomposePixelUnshuffle,
 )
 from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange
 from executorch.backends.qualcomm.passes.remove_redundancy import RemoveRedundancy
 from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer
+from executorch.backends.transforms.decompose_sdpa import (
+    DecomposeScaledDotProductAttention,
+)

 from torch._ops import OpOverload
 from torch.ao.quantization.quantizer import Quantizer
13 changes: 13 additions & 0 deletions backends/transforms/TARGETS
@@ -29,6 +29,19 @@ runtime.python_library(
     ],
 )

+runtime.python_library(
+    name = "decompose_sdpa",
+    srcs = ["decompose_sdpa.py"],
+    visibility = [
+        "//executorch/backends/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+    ],
+)
+
 runtime.python_library(
     name = "fuse_batch_norm_with_conv",
     srcs = ["fuse_batch_norm_with_conv.py"],
19 changes: 15 additions & 4 deletions backends/qualcomm/passes/decompose_scaled_dot_product_attention.py → backends/transforms/decompose_sdpa.py
@@ -1,8 +1,11 @@
 # Copyright (c) Qualcomm Innovation Center, Inc.
 # All rights reserved
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+# pyre-strict
+
 import torch
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._decomp import get_decompositions
@@ -14,7 +17,15 @@ class DecomposeScaledDotProductAttention(ExportPass):
     Decompose from scaled_dot_product_attention to multiple nodes.
     """

-    def call(self, graph_module: torch.fx.GraphModule):
+    def __init__(self, allow_non_fake_inputs: bool = True) -> None:
+        super().__init__()
+        # With allow_non_fake_inputs=False, we don't get _unsafe_view ops
+        # in the graph, we allow disabling it here.
+        self._allow_non_fake_inputs = allow_non_fake_inputs
+
+    def call(
+        self, graph_module: torch.fx.GraphModule, allow_non_fake_inputs: bool = True
+    ) -> PassResult:
         graph = graph_module.graph
         for node in graph.nodes:
             if node.target == torch.ops.aten.scaled_dot_product_attention.default:
@@ -29,7 +40,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                     ]
                 ),
                 tracing_mode="fake",
-                _allow_non_fake_inputs=True,
+                _allow_non_fake_inputs=allow_non_fake_inputs,
             )(*input_tensors)
             with graph.inserting_before(node):
                 name_to_input_tensor_map = {}
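Both backends now share one pass. A minimal usage sketch (the module path is from this diff; graph_module is assumed to be a captured torch.fx.GraphModule containing aten.scaled_dot_product_attention nodes):

from executorch.backends.transforms.decompose_sdpa import (
    DecomposeScaledDotProductAttention,
)

# Qualcomm-style invocation: defaults keep allow_non_fake_inputs=True.
DecomposeScaledDotProductAttention()(graph_module)

# Cadence-style invocation (see compiler.py above): passing False is meant
# to avoid _unsafe_view ops in the decomposed graph, keeping the bmm ops
# quantizable.
DecomposeScaledDotProductAttention(False)(graph_module)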
