@@ -18,12 +18,13 @@
     ReplaceLogicalNotBooleanWhereWithWherePass,
     ReplacePT2DequantWithCadenceDequantPass,
     ReplacePT2QuantWithCadenceQuantPass,
+    ReplaceSafeSoftmaxWithSoftmax,
     ReplaceScalarTensorWithFullPass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
-from executorch.backends.cadence.aot.utils import model_is_quantized
+from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
 )
@@ -57,13 +58,20 @@ def convert_pt2(
     """

     # Export with dynamo
-    model_exp = capture_pre_autograd_graph(model, inputs)
+    model_gm = capture_pre_autograd_graph(model, inputs)

-    # Decompose SDPA
-    DecomposeScaledDotProductAttention(False)(model_exp)
+    if model_gm_has_SDPA(model_gm):
+        # Decompose SDPA
+        DecomposeScaledDotProductAttention(False)(model_gm)
+
+        # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
+        # for details).
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
+        assert result is not None
+        model_gm = result.graph_module

     # Prepare
-    prepared_model = prepare_pt2e(model_exp, quantizer)
+    prepared_model = prepare_pt2e(model_gm, quantizer)

     # Calibrate
     prepared_model(*inputs)
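For context, the new model_gm_has_SDPA gate only has to answer one question: does the exported graph contain a scaled_dot_product_attention call? Below is a minimal sketch of such a check, assuming an FX GraphModule and the aten SDPA op as the call target; the real helper lives in executorch.backends.cadence.aot.utils and may differ in detail.

# Illustrative sketch only: the actual model_gm_has_SDPA is defined in
# executorch.backends.cadence.aot.utils and may be implemented differently.
import torch
from torch.fx import GraphModule


def model_gm_has_SDPA(model_gm: GraphModule) -> bool:
    # Scan the exported FX graph for a scaled_dot_product_attention call.
    for node in model_gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.aten.scaled_dot_product_attention.default
        ):
            return True
    return False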
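The new ReplaceSafeSoftmaxWithSoftmax pass is conceptually a one-for-one op swap: per the linked PyTorch PR, aten._safe_softmax exists so the SDPA decomposition can handle fully-masked rows, and in this inference flow it can be lowered to the regular aten._softmax. The following is a sketch of that idea, assuming an exir ExportPass-style rewrite; the actual pass is defined in executorch.backends.cadence.aot.passes alongside the other Replace* passes imported above.

# Sketch of the op swap, assuming an exir ExportPass-style rewrite; the real
# ReplaceSafeSoftmaxWithSoftmax in executorch.backends.cadence.aot.passes may differ.
import torch
from executorch.exir.pass_base import ExportPass


class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        # Leave every other operator untouched.
        if op != torch.ops.aten._safe_softmax.default:
            return super().call_operator(op, args, kwargs, meta)
        # _safe_softmax(x, dim) -> _softmax(x, dim, half_to_float=False)
        return super().call_operator(
            torch.ops.aten._softmax.default,
            (args[0], args[1], False),
            kwargs,
            meta,
        )

Because an ExportPass invocation returns a PassResult, the calling convention in the diff above, asserting the result is not None and then reading result.graph_module, lines up with this structure.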