Skip to content

Commit 11d6f1a

Browse files
Arm backend: Add numerically stable (log)softmax decomposition
- Only use the old (numerically unstable) decomposition for Ethos-U55 compile specs, since amax isn't supported in that case.
- Add support for negative indices in amax/amin.
- Refactor unit tests.

Change-Id: I7ed43b8d6b95625f59ce9e71d55a21763fc51358
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
Co-authored-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent ddf0d9e commit 11d6f1a

File tree

7 files changed

+299
-290
lines changed

7 files changed

+299
-290
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@
4141
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
4242
DecomposeSelectPass,
4343
)
44-
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
45-
DecomposeSoftmaxesPass,
44+
from executorch.backends.arm._passes.decompose_softmax_pass import DecomposeSoftmaxPass
45+
from executorch.backends.arm._passes.decompose_softmax_unstable_pass import (
46+
DecomposeSoftmaxUnstablePass,
4647
)
4748
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
4849
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
@@ -78,7 +79,7 @@
7879
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
7980
UnsqueezeScalarPlaceholdersPass,
8081
)
81-
from executorch.backends.arm.tosa_specification import TosaSpecification
82+
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
8283
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
8384

8485
from executorch.backends.transforms.replace_scalar_with_tensor import (
@@ -151,7 +152,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
151152
self.add_pass(DecomposeMeanDimPass())
152153
self.add_pass(ConvertMeanDimToAveragePoolPass())
153154
self.add_pass(DecomposeDivPass())
154-
self.add_pass(DecomposeSoftmaxesPass())
155+
self.add_pass(DecomposeSoftmaxPass())
155156
self.add_pass(ConvertFullLikeToFullPass())
156157
self.add_pass(ConvertToClampPass())
157158
self.add_pass(ConvertMinMaxPass())
@@ -199,6 +200,12 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
199200
self.add_pass(DecomposeVarPass())
200201
self.add_pass(DecomposeMeanDimPass())
201202
self.add_pass(DecomposeDivPass())
202-
self.add_pass(DecomposeSoftmaxesPass())
203+
204+
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
205+
# Numerically stable softmax uses amax which is not supported on Ethos-U55
206+
self.add_pass(DecomposeSoftmaxUnstablePass())
207+
else:
208+
self.add_pass(DecomposeSoftmaxPass())
209+
203210
self.add_pass(ConvertMinMaxPass())
204211
return self._transform(graph_module)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
from executorch.exir.pass_base import ExportPass
9+
10+
# For BI case
# ATen-level (log)softmax overloads that this pass matches.
# NOTE(review): "BI" presumably refers to the quantized TOSA profile, where the
# pass runs before lowering to the edge dialect -- confirm.
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
# For MI case
# Edge-dialect (log)softmax ops that this pass matches.
edge_softmax = (
    exir_ops.edge.aten._softmax.default,
    exir_ops.edge.aten._log_softmax.default,
)
# The log variants from both tuples above; membership here selects the
# log_softmax flavour of the decomposition.
log_softmax = (torch.ops.aten.log_softmax.int, exir_ops.edge.aten._log_softmax.default)
18+
19+
20+
def _get_logsoftmax_ops(op) -> tuple:
    """
    Return the primitive ops needed to decompose a (log)softmax op.

    The ops are taken from the edge dialect when ``op`` is an edge op and from
    ``torch.ops.aten`` when ``op`` is an aten op, so that the decomposition
    stays in the same dialect as the op being replaced.

    Returns:
        The 7-tuple
        ``(log_op, sub_op, amax_op, exp_op, sum_op, reciprocal_op, mul_op)``.

    Raises:
        RuntimeError: If ``op`` is neither a known edge nor a known aten
            (log)softmax op.
    """
    # Pick the op namespace once; the set of primitive ops is the same for
    # both dialects.
    if op in edge_softmax:
        aten = exir_ops.edge.aten
    elif op in torch_softmax:
        aten = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get logsoftmax decomposition ops for op {op}")

    return (
        aten.log.default,
        aten.sub.Tensor,
        aten.amax.default,
        aten.exp.default,
        aten.sum.dim_IntList,
        aten.reciprocal.default,
        aten.mul.Tensor,
    )
46+
47+
48+
class DecomposeSoftmaxPass(ExportPass):
    """
    This pass decomposes log_softmax or softmax into more primitive ops, using
    a numerically stable formulation.

    The maximum along ``dim`` is subtracted before exponentiation, so every
    argument to exp() is <= 0 and exp() cannot overflow.

    Example (softmax):
        %op1 = amax(x, dim, keepdim=True)
        %op2 = sub(x, %op1)
        %op3 = exp(%op2)
        %op4 = sum(%op3, dim, keepdim=True)
        %op5 = reciprocal(%op4)
        %op6 = mul(%op3, %op5)

    Example (log_softmax):
        %op1 .. %op4 as above
        %op5 = log(%op4)
        %op6 = sub(%op2, %op5)
    """

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace a (log)softmax call with its primitive-op decomposition.

        Any op that is not one of the matched (log)softmax ops is forwarded
        unchanged to the base class.
        """
        if op not in torch_softmax + edge_softmax:
            return super().call_operator(op, args, kwargs, meta)
        log_op, sub_op, max_op, exp_op, sum_op, reciprocal_op, mul_op = (
            _get_logsoftmax_ops(op)
        )
        _input = args[0]
        # amax and sum.dim_IntList take a list of dims; softmax takes one int.
        dim = [args[1]]

        # Shared prefix: shifted input, its exponential and the normaliser.
        op1 = super().call_operator(max_op, (_input, dim, True), {}, meta)
        op2 = super().call_operator(sub_op, (_input, op1), {}, meta)
        op3 = super().call_operator(exp_op, (op2,), {}, meta)
        op4 = super().call_operator(sum_op, (op3, dim, True), {}, meta)

        if op in log_softmax:
            # log_softmax(x) = (x - max) - log(sum(exp(x - max))).
            # Taking log(softmax(x)) instead yields -inf as soon as a softmax
            # entry underflows to 0. Here the log argument is always >= 1,
            # because the maximum element contributes exp(0) = 1 to the sum.
            op5 = super().call_operator(log_op, (op4,), {}, meta)
            return super().call_operator(sub_op, (op2, op5), {}, meta)

        op5 = super().call_operator(reciprocal_op, (op4,), {}, meta)
        return super().call_operator(mul_op, (op3, op5), {}, meta)

backends/arm/_passes/decompose_softmaxes_pass.py renamed to backends/arm/_passes/decompose_softmax_unstable_pass.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -46,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple:
4645
raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
4746

4847

49-
class DecomposeSoftmaxesPass(ExportPass):
48+
class DecomposeSoftmaxUnstablePass(ExportPass):
5049
"""
5150
This pass decomposes log softmax or softmax into more primitive ops.
5251

backends/arm/operators/op_amax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List
66

77
import serializer.tosa_serializer as ts
8+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
89
from executorch.backends.arm.operators.node_visitor import (
910
NodeVisitor,
1011
register_node_visitor,
@@ -31,6 +32,12 @@ def define_node(
3132

3233
input = inputs[0]
3334
dim = inputs[1].number
35+
36+
if dim < 0:
37+
tensor = get_first_fake_tensor(node)
38+
rank = len(tensor.size())
39+
dim = rank + dim
40+
3441
keep_dims = inputs[2].number
3542
if not keep_dims:
3643
raise RuntimeError(

backends/arm/operators/op_amin.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List
66

77
import serializer.tosa_serializer as ts
8+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
89
from executorch.backends.arm.operators.node_visitor import (
910
NodeVisitor,
1011
register_node_visitor,
@@ -31,6 +32,12 @@ def define_node(
3132

3233
input = inputs[0]
3334
dim = inputs[1].number
35+
36+
if dim < 0:
37+
tensor = get_first_fake_tensor(node)
38+
rank = len(tensor.size())
39+
dim = rank + dim
40+
3441
keep_dims = inputs[2].number
3542
if not keep_dims:
3643
raise RuntimeError(

0 commit comments

Comments
 (0)