Fix SDPA decomp problem
Differential Revision: D61639074

Pull Request resolved: #4851
mcremon-meta committed Aug 22, 2024
1 parent bf64819 commit d7c069f
Showing 3 changed files with 47 additions and 5 deletions.
18 changes: 13 additions & 5 deletions backends/cadence/aot/compiler.py
@@ -18,12 +18,13 @@
     ReplaceLogicalNotBooleanWhereWithWherePass,
     ReplacePT2DequantWithCadenceDequantPass,
     ReplacePT2QuantWithCadenceQuantPass,
+    ReplaceSafeSoftmaxWithSoftmax,
     ReplaceScalarTensorWithFullPass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
-from executorch.backends.cadence.aot.utils import model_is_quantized
+from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
 )
@@ -57,13 +58,20 @@ def convert_pt2(
     """

     # Export with dynamo
-    model_exp = capture_pre_autograd_graph(model, inputs)
+    model_gm = capture_pre_autograd_graph(model, inputs)

-    # Decompose SDPA
-    DecomposeScaledDotProductAttention(False)(model_exp)
+    if model_gm_has_SDPA(model_gm):
+        # Decompose SDPA
+        DecomposeScaledDotProductAttention(False)(model_gm)
+
+        # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
+        # for details).
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
+        assert result is not None
+        model_gm = result.graph_module

     # Prepare
-    prepared_model = prepare_pt2e(model_exp, quantizer)
+    prepared_model = prepare_pt2e(model_gm, quantizer)

     # Calibrate
     prepared_model(*inputs)
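For context, a minimal sketch (not part of the commit) of a module whose pre-autograd export keeps the fused SDPA op, so that model_gm_has_SDPA returns True and the new gated branch above runs. TinyAttention and the commented-out convert_pt2 call are hypothetical; the full convert_pt2 signature is not shown in this diff.

import torch
import torch.nn.functional as F

class TinyAttention(torch.nn.Module):
    # Toy module: its exported graph contains aten.scaled_dot_product_attention,
    # which the DecomposeScaledDotProductAttention pass then splits apart.
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

example_inputs = tuple(torch.randn(1, 2, 4, 8) for _ in range(3))
# Hypothetical invocation (signature assumed from the surrounding code):
# converted = convert_pt2(TinyAttention(), example_inputs, CadenceQuantizer())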
26 changes: 26 additions & 0 deletions backends/cadence/aot/passes.py
@@ -266,3 +266,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         result = SpecPropPass()(graph_module)
         assert result is not None
         return result
+
+
+class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
+    """
+    Replace _safe_softmax with _softmax
+    """
+
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op != torch.ops.aten._safe_softmax.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Add False for the half_to_float argument of softmax
+        softmax_args = list(args) + [False]
+
+        return super().call_operator(
+            torch.ops.aten._softmax.default,
+            tuple(softmax_args),
+            kwargs,
+            meta,
+        )
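As a quick sanity check (not part of the commit): the rewrite relies on _softmax(input, dim, half_to_float=False) matching _safe_softmax(input, dim) everywhere except fully-masked rows, where _safe_softmax yields zeros instead of NaNs. A hedged sketch, assuming a PyTorch build recent enough to expose aten._safe_softmax (introduced around the linked PR #133882):

import torch

x = torch.randn(2, 5)
# _safe_softmax takes (input, dim); _softmax additionally takes half_to_float,
# which the pass above hardcodes to False.
safe = torch.ops.aten._safe_softmax.default(x, -1)
plain = torch.ops.aten._softmax.default(x, -1, False)
assert torch.allclose(safe, plain)  # equal on finite inputs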
8 changes: 8 additions & 0 deletions backends/cadence/aot/utils.py
@@ -177,3 +177,11 @@ def print_ops_info(
             tablefmt="outline",
         )
     )
+
+
+def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
+    for node in model_gm.graph.nodes:
+        if node.op == "call_function":
+            if node.target == torch.ops.aten.scaled_dot_product_attention.default:
+                return True
+    return False
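A possible usage sketch (not part of the commit), mirroring how convert_pt2 obtains its graph module; the torch._export import path for capture_pre_autograd_graph is an assumption based on the PyTorch API of this era:

import torch
import torch.nn.functional as F
from torch._export import capture_pre_autograd_graph

class WithSDPA(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

example = tuple(torch.randn(1, 2, 4, 8) for _ in range(3))
# The pre-autograd graph still contains aten.scaled_dot_product_attention.default,
# so the helper should report True.
gm = capture_pre_autograd_graph(WithSDPA(), example)
print(model_gm_has_SDPA(gm))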
