
Commit 087a27c

Arm backend: Add support for 5D tensors (#11143)
Updates memory-format handling to consider 5D tensors. Adds a guard so that nodes with rank > 5 are not partitioned, since the current memory-format pass does not handle such ranks well.

Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
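For orientation, a minimal sketch (not part of the commit) of what the new 5D handling means: rank-5 tensors are treated as (N, N', C, H, W), and the pass permutes them to a channels-last layout with NNHWC_order = (0, 1, 3, 4, 2); NNHWC_inverse_order = (0, 1, 4, 2, 3) undoes that permutation. The tensor shape below is illustrative only.

import torch

# Illustrative only: the dim orders added by this commit, applied with plain
# torch.permute to show the intended (N)NCHW <-> (N)NHWC round trip.
NNHWC_order = (0, 1, 3, 4, 2)
NNHWC_inverse_order = (0, 1, 4, 2, 3)

x = torch.randn(2, 3, 8, 16, 16)         # (N, N', C, H, W)
channels_last = x.permute(NNHWC_order)   # -> (N, N', H, W, C), shape (2, 3, 16, 16, 8)
restored = channels_last.permute(NNHWC_inverse_order)

assert channels_last.shape == (2, 3, 16, 16, 8)
assert torch.equal(restored, x)          # the inverse order restores the original layout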
1 parent aaf0ecc commit 087a27c

6 files changed: +114, -46 lines

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 58 additions & 25 deletions
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -36,7 +35,7 @@
 def _transpose_impl(*args, **kwargs):
     # Validate length of dim_order array
     dim = args[1]
-    assert len(dim) <= 4
+    assert len(dim) in (4, 5)
     # Pass-through in edge-IR
     return args[0]
 
@@ -45,13 +44,15 @@ class AnnotateChannelsLastDimOrder(ExportPass):
     """
     Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
     that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
-    when a transition between 3D and 4D tensors happen.
+    when a transition between 3D and 4D/5D tensors happen.
     The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
     """
 
     NHWC_order = (0, 2, 3, 1)
     NHWC_inverse_order = (0, 3, 1, 2)
     HWCM_order = (2, 3, 0, 1)
+    NNHWC_order = (0, 1, 3, 4, 2)
+    NNHWC_inverse_order = (0, 1, 4, 2, 3)
 
     def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         """
@@ -81,8 +82,12 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
 
     @staticmethod
     def memory_format_differs(shape):
-        """Returns true if the shape will have a different memory layout in NCHW and NHWC format"""
-        if len(shape) >= 4:
+        """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
+        if len(shape) >= 5:
+            C = shape[2]
+            H = shape[3]
+            W = shape[4]
+        elif len(shape) == 4:
             C = shape[1]
             H = shape[2]
             W = shape[3]
@@ -98,14 +103,24 @@ def memory_format_differs(shape):
     @staticmethod
     def is_channel_reshape(input_shape, output_shape):
         """Returns true if the reshape changes the channel dimension"""
-        if not len(input_shape) == len(output_shape) == 4:
+        if not (
+            (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
+            or (len(input_shape) == 4 and len(output_shape) == 5)
+            or (len(input_shape) == 5 and len(output_shape) == 4)
+        ):
             return False
 
-        C_old = input_shape[1]
-        C_new = output_shape[1]
+        C_old = input_shape[-3]
+        C_new = output_shape[-3]
 
-        N_new = output_shape[0]
-        N_old = input_shape[0]
+        N_new = (
+            output_shape[0]
+            if len(output_shape) == 4
+            else output_shape[0] * output_shape[1]
+        )
+        N_old = (
+            input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
+        )
 
         return (N_old != N_new) or (C_old != C_new)
 
@@ -119,7 +134,11 @@ def insert_input_transpose(node, input_node, graph_module):
             torch.ops.passthrough_to_tosa._transpose.default,
             args=(
                 input_node,
-                list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
+                list(
+                    AnnotateChannelsLastDimOrder.NNHWC_inverse_order
+                    if len(get_first_fake_tensor(input_node).size()) == 5
+                    else AnnotateChannelsLastDimOrder.NHWC_inverse_order
+                ),
             ),
             quantize=quantize,
             q_params=q_params,
@@ -137,15 +156,28 @@ def insert_output_transpose(node, graph_module):
         permute_node = create_node(
             graph_module.graph,
             torch.ops.passthrough_to_tosa._transpose.default,
-            args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
+            args=(
+                node,
+                list(
+                    AnnotateChannelsLastDimOrder.NNHWC_order
+                    if len(get_first_fake_tensor(node).size()) == 5
+                    else AnnotateChannelsLastDimOrder.NHWC_order
+                ),
+            ),
         )
         permute_node.meta["tosa_dim_order"] = (
-            AnnotateChannelsLastDimOrder.NHWC_order
+            AnnotateChannelsLastDimOrder.NNHWC_order
+            if len(get_first_fake_tensor(node).size()) == 5
+            else AnnotateChannelsLastDimOrder.NHWC_order
+        )
+        permute_node.meta["val"] = get_first_fake_tensor(node).permute(
+            AnnotateChannelsLastDimOrder.NNHWC_order
+            if len(get_first_fake_tensor(node).size()) == 5
+            else AnnotateChannelsLastDimOrder.NHWC_order
        )
-        permute_node.meta["val"] = node.meta["val"].permute(
-            AnnotateChannelsLastDimOrder.NHWC_order
+        node.meta["tosa_dim_order"] = tuple(
+            range(len(get_first_fake_tensor(node).size()))
         )
-        node.meta["tosa_dim_order"] = (0, 1, 2, 3)
         users = [user for user in node.users if user != permute_node]
         for user in users:
             user.replace_input_with(node, permute_node)
@@ -159,8 +191,8 @@ def insert_output_transpose(node, graph_module):
     def _insert_view_transpose(
         input_shape, output_shape, node, input_node, graph_module
     ):
-        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
-        nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
+        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
+        nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
         channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
             output_shape, input_shape
         )
@@ -178,11 +210,11 @@ def _insert_view_transpose(
 
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
-        Transposes are needed for operators transforming the input to a different rank, as 4D-tensors are assumed to be in NHWC-format, whereas all other are in NCHW format.
+        Transposes are needed for operators transforming the input to a different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-format, whereas all other are in (N)NCHW format.
         This is relevant for the following cases:
-        - view: <4D -> 4D
-        - view: 4D -> <4D
-        Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leadning to one extra input and output transpose for this case.
+        - view: <4D -> >=4D
+        - view: >=4D -> <4D
+        Additionally, a 4D/5D->4D/5D view operation acting on the channel dimension currently needs to be performed in (N)NCHW format, leadning to one extra input and output transpose for this case.
 
         Transposes can be avoided for shapes where there is no difference in actual memory, e.g for
         - H == W == 1
@@ -212,12 +244,13 @@ def call(self, graph_module: torch.fx.GraphModule):
                 # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                 # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                 dim_order = self.HWCM_order
+            elif node_data.dim() == 5:
+                dim_order = self.NNHWC_order  # type: ignore[assignment]
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
             node.meta["tosa_dim_order"] = dim_order
-        # Take care of cases when:
-        # 4D (NHWC) -> >4D (NCH)
-        # 3D (NCH) -> 4D (NHWC)
+        # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
+        # See insert_tosa_transposes for insertion conditions.
         self.insert_tosa_transposes(graph_module)
         graph_module.recompile()
         graph_module = super().call(graph_module).graph_module
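To make the widened reshape condition above concrete, the same rule is copied below into a standalone function with a few example shapes; the free-standing function and the shapes are illustrative, not part of the commit.

# Standalone copy of the is_channel_reshape rule from the diff above, for illustration.
def is_channel_reshape(input_shape, output_shape):
    # Only 4D/5D <-> 4D/5D reshapes are candidates; other ranks are handled by
    # the rank-transition logic in _insert_view_transpose instead.
    if not (
        (len(input_shape) == len(output_shape) and len(output_shape) in (4, 5))
        or (len(input_shape) == 4 and len(output_shape) == 5)
        or (len(input_shape) == 5 and len(output_shape) == 4)
    ):
        return False
    C_old, C_new = input_shape[-3], output_shape[-3]  # channels sit at dim -3 in (N)NCHW
    N_old = input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
    N_new = output_shape[0] if len(output_shape) == 4 else output_shape[0] * output_shape[1]
    return (N_old != N_new) or (C_old != C_new)

print(is_channel_reshape((2, 3, 8, 8), (1, 2, 3, 8, 8)))  # False: batch split, channels untouched
print(is_channel_reshape((2, 3, 8, 8), (2, 6, 4, 8)))     # True: channel dimension changes
print(is_channel_reshape((2, 3, 8, 8), (2, 3, 64)))       # False: output rank < 4, not a candidate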

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 51 additions & 0 deletions
@@ -28,6 +28,8 @@
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.utils import WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops
+
+from torch._subclasses.fake_tensor import FakeTensor
 from torch.export.graph_signature import InputKind
 from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -116,6 +118,7 @@ def tosa_support_factory(
     negative_checks: list[OperatorSupportBase] = [
         CheckInt64Inputs(exported_program, reporter),
         CheckFloat64Inputs(exported_program, reporter),
+        RankCheck(reporter, max_rank=5),
         *[
             reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
             for check in (additional_checks if additional_checks else [])
@@ -474,3 +477,51 @@ def is_node_supported(
         )
         return False
     return True
+
+
+class RankCheck(OperatorSupportBase):
+    """Makes sure that nodes with input or output tensors with rank > max_rank are not partitioned"""
+
+    def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int):
+        self.reporter = reporter
+        self.max_rank = max_rank
+        super().__init__()
+
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+        input_nodes = node.all_input_nodes
+        # check if any input node has an unsupported rank
+        for input_node in input_nodes:
+            input_node_shape = get_first_fake_tensor(input_node).shape
+            if len(input_node_shape) > self.max_rank:
+                self.reporter.report_reject(
+                    node,
+                    f"{node.name} has input_node {input_node.name} with shape {input_node_shape}, "
+                    f"rank {len(input_node_shape)} which is unsupported. "
+                    f"Max supported rank is {self.max_rank}.",
+                )
+                return False
+
+        meta_val = node.meta["val"]
+        if isinstance(
+            meta_val, (Sequence, torch.fx.immutable_collections.immutable_list)
+        ):
+            for val in meta_val:
+                if isinstance(val, FakeTensor):
+                    if len(val.shape) > self.max_rank:
+                        self.reporter.report_reject(
+                            node,
+                            f"{node.name} has a shape {val.shape}, rank {len(val.shape)} which is unsupported."
+                            f"Max supported rank is {self.max_rank}.",
+                        )
+                        return False
+        elif isinstance(meta_val, FakeTensor):
+            if len(meta_val.shape) > self.max_rank:
+                self.reporter.report_reject(
+                    node,
+                    f"{node.name} has shape {meta_val.shape}, rank={len(meta_val.shape)} which is unsupported."
+                    f"Max supported rank is {self.max_rank}.",
+                )
+                return False
+        return True
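As a rough, commit-independent illustration of what RankCheck guards against, the snippet below inspects the FakeTensor "val" metadata that torch.export attaches to graph nodes and flags anything above rank 5; the toy module and the printout are made up for the example and do not exercise the Arm partitioner itself.

import torch
from torch.export import export

# Toy module whose intermediate tensor reaches rank 6 (illustrative only).
class RankSix(torch.nn.Module):
    def forward(self, x):
        return (x.unsqueeze(0) + 1.0).sum()

ep = export(RankSix(), (torch.randn(1, 2, 3, 4, 5),))
MAX_RANK = 5  # mirrors RankCheck(reporter, max_rank=5) in the diff above

for node in ep.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor) and val.dim() > MAX_RANK:
        # A partitioner check like RankCheck would report and reject such a node.
        print(f"{node.name}: rank {val.dim()} > {MAX_RANK}, would not be partitioned")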

backends/arm/test/models/test_conformer.py

Lines changed: 1 addition & 9 deletions
@@ -57,14 +57,6 @@ def test_conformer_tosa_MI():
         exir_op=[],
         use_to_edge_transform_and_lower=True,
     )
-    pipeline.change_args(
-        "run_method_and_compare_outputs",
-        get_test_inputs(
-            TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
-        ),
-        rtol=1.0,
-        atol=5.0,
-    )
     pipeline.run()
 
 
@@ -83,7 +75,7 @@ def test_conformer_tosa_BI():
             TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
         ),
         rtol=1.0,
-        atol=5.0,
+        atol=3.0,
     )
     pipeline.run()

backends/arm/test/models/test_deit_tiny_arm.py

Lines changed: 1 addition & 3 deletions
@@ -41,8 +41,6 @@ def test_deit_tiny_tosa_MI():
         aten_op=[],
         exir_op=[],
         use_to_edge_transform_and_lower=True,
-        atol=6.5,  # This needs to go down: MLETORCH-940
-        qtol=1,
     )
     pipeline.run()
 
@@ -54,7 +52,7 @@ def test_deit_tiny_tosa_BI():
         aten_op=[],
         exir_op=[],
         use_to_edge_transform_and_lower=True,
-        atol=3.0,  # This needs to go down: MLETORCH-956
+        atol=2.5,  # This needs to go down: MLETORCH-956
         qtol=1,
     )
     pipeline.run()

backends/arm/test/models/test_llama.py

Lines changed: 2 additions & 7 deletions
@@ -109,17 +109,11 @@ def test_llama_tosa_MI():
         exir_op=[],
         use_to_edge_transform_and_lower=True,
     )
-    pipeline.change_args(
-        "run_method_and_compare_outputs",
-        atol=4.3,
-        rtol=1.1,  # TODO: MLETORCH-825 decrease tolerance
-    )
     pipeline.run()
 
 
-@pytest.mark.xfail(reason="KeyError: scalar_tensor_1 (MLETORCH-907)")
 def test_llama_tosa_BI():
-    llama_model, llama_inputs, llama_meta = TestLlama.prepare_model()
+    llama_model, llama_inputs, llama_meta = TestLlama().prepare_model()
 
     if llama_model is None or llama_inputs is None:
         pytest.skip("Missing model and/or input files")
@@ -136,5 +130,6 @@ def test_llama_tosa_BI():
         "run_method_and_compare_outputs",
         atol=9.9,
         rtol=1.5,  # TODO: Tolerance needs to be updated after MLETORCH-907
+        inputs=llama_inputs,
     )
     pipeline.run()

backends/arm/test/models/test_nn_modules.py

Lines changed: 1 addition & 2 deletions
@@ -56,7 +56,6 @@
 @parametrize(
     "test_data",
     test_parameters,
-    xfails={"Transformer": "Output 0 does not match reference output."},
 )
 def test_nn_Modules_MI(test_data):
     module, inputs = test_data
@@ -81,7 +80,7 @@ def test_nn_Modules_MI(test_data):
     xfails={
         "GRU": "RuntimeError: Node aten_linear_default with op <EdgeOpOverload: aten.linear[...]> was not decomposed or delegated.",
         "PReLU": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
-        "Transformer": "RuntimeError: Expected out tensor to have dtype signed char, but got float",
+        "Transformer": "AssertionError: Output 0 does not match reference output.",
     },
 )
 def test_nn_Modules_BI(test_data):
