Skip to content

Commit 73880e6

Browse files
Erik-Lundell authored and
kirklandsign committed
Arm backend: Clean up shift support (#9573)
- Handle lshift.Tensor and rshift.Tensor
- Convert *.Scalar to *.Tensor
- Test Scalar and Tensor cases with multiple dtypes
- Move cast logic from node visitor to pass

Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent b7a557e commit 73880e6

15 files changed

+430
-185
lines changed

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from . import arm_pass_utils # noqa
88
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
10-
from .cast_int64_pass import CastInt64ToInt32Pass # noqa
10+
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
11+
from .cast_to_int32_pass import CastToInt32Pass # noqa
1112
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
1213
from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass # noqa
1314
from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from executorch.backends.arm._passes import (
1111
AnnotateChannelsLastDimOrder,
1212
AnnotateDecomposedMatmulPass,
13-
CastInt64ToInt32Pass,
13+
CastInt64BuffersToInt32Pass,
14+
CastToInt32Pass,
1415
ComputeConstantOpsAOT,
1516
Conv1dUnsqueezePass,
1617
ConvertAnyDefaultDimDimsPass,
@@ -80,6 +81,8 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8081
self.add_pass(ConvertToClampPass())
8182
self.add_pass(ConvertMinMaxPass())
8283
self.add_pass(ConvertAnyDefaultDimDimsPass())
84+
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
85+
self.add_pass(CastToInt32Pass())
8386

8487
self.add_pass(ReplaceScalarWithTensorArgPass())
8588
self.add_pass(AnnotateDecomposedMatmulPass())
@@ -94,7 +97,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
9497
self.add_pass(SizeAdjustConv2DPass())
9598
self.add_pass(ConvertExpandCopyToRepeatPass())
9699
self.add_pass(UnsqueezeBeforeRepeatPass())
97-
self.add_pass(CastInt64ToInt32Pass(exported_program))
100+
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
98101
self.add_pass(KeepDimsFalseToSqueezePass())
99102
self.add_pass(Conv1dUnsqueezePass(exported_program))
100103
self.add_pass(DecomposeSelectPass())
@@ -141,7 +144,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
141144
self.add_pass(SizeAdjustConv2DPass())
142145
self.add_pass(ConvertExpandCopyToRepeatPass())
143146
self.add_pass(UnsqueezeBeforeRepeatPass())
144-
self.add_pass(CastInt64ToInt32Pass(exported_program))
147+
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
145148
self.add_pass(KeepDimsFalseToSqueezePass())
146149
self.add_pass(Conv1dUnsqueezePass(exported_program))
147150
self.add_pass(DecomposeSelectPass())

backends/arm/_passes/cast_int64_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
logger.setLevel(logging.WARNING)
1616

1717

18-
class CastInt64ToInt32Pass(ExportPass):
18+
class CastInt64BuffersToInt32Pass(ExportPass):
1919
"""
2020
Cast int64 buffers to int32 if the int64 data is in int32 range.
2121
"""
2222

2323
def __init__(self, exported_program: torch.export.ExportedProgram):
24-
super(CastInt64ToInt32Pass, self).__init__()
24+
super(CastInt64BuffersToInt32Pass, self).__init__()
2525
self.exported_program = exported_program
2626

2727
def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
backends/arm/_passes/cast_to_int32_pass.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
12+
class CastToInt32Pass(ExportPass):
    """Run the targeted shift operators on int32 operands.

    Every operand of a targeted operator whose dtype is not int32 is copied to
    int32 before the operator is emitted. If at least one operand had to be
    copied, the operator's result is copied back to the dtype of the first
    original operand. Operators outside ``targeted_ops`` are forwarded
    unchanged.
    """

    targeted_ops = {
        exir_ops.edge.aten.bitwise_left_shift.Tensor,
        exir_ops.edge.aten.bitwise_right_shift.Tensor,
    }

    def _copy_as(self, value, dtype, meta):
        """Emit a ``_to_dim_order_copy`` of ``value`` with the given dtype."""
        return super().call_operator(
            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
            (value,),
            {"dtype": dtype},
            meta,
        )

    def call_operator(self, op, args, kwargs, meta):
        # Anything that is not a targeted shift passes straight through.
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta)

        # One flag per operand: True where an int32 copy has to be inserted.
        needs_copy = [operand.data.dtype != torch.int32 for operand in args]
        int32_operands = tuple(
            self._copy_as(operand, torch.int32, meta) if copy else operand
            for operand, copy in zip(args, needs_copy)
        )

        # The shift itself is re-emitted with empty kwargs.
        result = super().call_operator(op, int32_operands, {}, meta)

        if not any(needs_copy):
            return result

        # Restore the dtype of the first original operand on the result.
        return self._copy_as(result, args[0].data.dtype, meta)

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def __init__(self, exported_program):
4545
exir_ops.edge.aten.sub.Tensor,
4646
exir_ops.edge.aten.mul.Tensor,
4747
exir_ops.edge.aten.div.Tensor,
48+
exir_ops.edge.aten.bitwise_right_shift.Tensor,
49+
exir_ops.edge.aten.bitwise_left_shift.Tensor,
4850
]
4951

5052
def _match_op_rank(self, graph_module, node, arg, max_rank):

backends/arm/operator_support/right_shift_support.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222

2323
@register_tosa_support_check
2424
class RightShiftSupported(SupportedTOSAOperatorCheck):
25-
targets = [exir_ops.edge.aten.__rshift__.Scalar]
25+
targets = [
26+
exir_ops.edge.aten.bitwise_right_shift.Tensor,
27+
exir_ops.edge.aten.__rshift__.Scalar,
28+
]
2629

2730
tosa_specs = [
2831
TosaSpecification.create_from_string("TOSA-0.80+BI"),

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ def is_node_supported(
205205
exir_ops.edge.aten.amin.default,
206206
exir_ops.edge.aten.eye.default,
207207
exir_ops.edge.aten.linspace.default,
208+
exir_ops.edge.aten.bitwise_left_shift.Tensor,
209+
exir_ops.edge.aten.__lshift__.Scalar,
208210
torch.ops.aten.scalar_tensor.default,
209211
]
210212

backends/arm/operators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
op_reciprocal,
3636
op_repeat,
3737
op_rescale,
38-
op_rshift,
38+
op_rshift_tensor,
3939
op_rsqrt,
4040
op_sigmoid,
4141
op_slice,

backends/arm/operators/op_rshift.py

Lines changed: 0 additions & 100 deletions
This file was deleted.
backends/arm/operators/op_rshift_tensor.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import serializer.tosa_serializer as ts # type: ignore
11+
import torch
12+
from executorch.backends.arm.operators.node_visitor import (
13+
NodeVisitor,
14+
register_node_visitor,
15+
)
16+
from executorch.backends.arm.tosa_mapping import TosaArg
17+
from executorch.backends.arm.tosa_specification import Tosa_0_80
18+
from serializer.tosa_serializer import TosaOp
19+
20+
21+
@register_node_visitor
22+
class RshiftVisitor(NodeVisitor):
23+
target = "aten.bitwise_right_shift.Tensor"
24+
25+
def define_node(
26+
self,
27+
node: torch.fx.Node,
28+
tosa_graph: ts.TosaSerializer,
29+
inputs: List[TosaArg],
30+
output: TosaArg,
31+
) -> None:
32+
33+
attr = ts.TosaSerializerAttribute()
34+
round = False
35+
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
36+
# U55 only supports INT32 and round == True
37+
# TODO MLETORCH-525 Emulate round == False with different decomposition
38+
round = True
39+
attr.ArithmeticRightShiftAttribute(round=round)
40+
41+
tosa_graph.addOperator(
42+
TosaOp.Op().ARITHMETIC_RIGHT_SHIFT,
43+
[inputs[0].name, inputs[1].name],
44+
[output.name],
45+
attr,
46+
)

backends/arm/operators/ops_binary.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,6 @@ def define_node(
5252
binary_operator_factory("aten.logical_and.default", TosaOp.Op().LOGICAL_AND)
5353
binary_operator_factory("aten.logical_xor.default", TosaOp.Op().LOGICAL_XOR)
5454
binary_operator_factory("aten.logical_or.default", TosaOp.Op().LOGICAL_OR)
55+
binary_operator_factory(
56+
"aten.bitwise_left_shift.Tensor", TosaOp.Op().LOGICAL_LEFT_SHIFT
57+
)

0 commit comments

Comments
 (0)