Arm backend: Add support for 5D tensors #11143

Merged 1 commit on May 27, 2025
83 changes: 58 additions & 25 deletions backends/arm/_passes/annotate_channels_last_dim_order_pass.py
@@ -1,5 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -36,7 +35,7 @@
def _transpose_impl(*args, **kwargs):
# Validate length of dim_order array
dim = args[1]
assert len(dim) <= 4
assert len(dim) in (4, 5)
# Pass-through in edge-IR
return args[0]

@@ -45,13 +44,15 @@ class AnnotateChannelsLastDimOrder(ExportPass):
"""
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
when a transition between 3D and 4D tensors happen.
when a transition between 3D and 4D/5D tensors happens.
The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
"""

NHWC_order = (0, 2, 3, 1)
NHWC_inverse_order = (0, 3, 1, 2)
HWCM_order = (2, 3, 0, 1)
NNHWC_order = (0, 1, 3, 4, 2)
NNHWC_inverse_order = (0, 1, 4, 2, 3)
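
For reference, a quick sketch (editor's illustration, not part of this diff) of how these dim-order tuples act on shapes; `permute_shape` is a hypothetical helper:

```python
# Illustrative only: how the dim-order constants above permute a shape.
def permute_shape(shape, dim_order):
    return tuple(shape[d] for d in dim_order)

nchw = (2, 3, 8, 8)        # 4D: N, C, H, W
nnchw = (2, 4, 3, 8, 8)    # 5D: the extra leading dim acts as a second batch dim

assert permute_shape(nchw, (0, 2, 3, 1)) == (2, 8, 8, 3)         # NHWC_order
assert permute_shape(nnchw, (0, 1, 3, 4, 2)) == (2, 4, 8, 8, 3)  # NNHWC_order
# The inverse orders undo the permutation.
assert permute_shape((2, 4, 8, 8, 3), (0, 1, 4, 2, 3)) == nnchw  # NNHWC_inverse_order
```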

def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
"""
@@ -81,8 +82,12 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):

@staticmethod
def memory_format_differs(shape):
"""Returns true if the shape will have a different memory layout in NCHW and NHWC format"""
if len(shape) >= 4:
"""Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
if len(shape) >= 5:
C = shape[2]
H = shape[3]
W = shape[4]
elif len(shape) == 4:
C = shape[1]
H = shape[2]
W = shape[3]
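
A small example (editor's sketch, not from the PR) of the property this helper tests: when, for instance, H == W == 1, the (N)NCHW and (N)NHWC layouts occupy memory in the same order and no transpose is needed; otherwise the layouts differ.

```python
import torch

x = torch.arange(6).reshape(1, 6, 1, 1)          # NCHW with H == W == 1
nhwc = x.permute(0, 2, 3, 1).contiguous()
assert torch.equal(x.flatten(), nhwc.flatten())  # same memory order -> transpose avoidable

y = torch.arange(24).reshape(1, 2, 3, 4)         # C, H, W all > 1
nhwc_y = y.permute(0, 2, 3, 1).contiguous()
assert not torch.equal(y.flatten(), nhwc_y.flatten())  # layouts genuinely differ
```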
@@ -98,14 +103,24 @@ def memory_format_differs(shape):
@staticmethod
def is_channel_reshape(input_shape, output_shape):
"""Returns true if the reshape changes the channel dimension"""
if not len(input_shape) == len(output_shape) == 4:
if not (
(len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
or (len(input_shape) == 4 and len(output_shape) == 5)
or (len(input_shape) == 5 and len(output_shape) == 4)
):
return False

C_old = input_shape[1]
C_new = output_shape[1]
C_old = input_shape[-3]
C_new = output_shape[-3]

N_new = output_shape[0]
N_old = input_shape[0]
N_new = (
output_shape[0]
if len(output_shape) == 4
else output_shape[0] * output_shape[1]
)
N_old = (
input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
)

return (N_old != N_new) or (C_old != C_new)
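
Restating the check above as a self-contained sketch (hypothetical standalone functions) may make the 4D/5D handling easier to follow: a 5D shape folds its two leading dims into one effective batch, and the channel dim is always third-from-last.

```python
def _batch(shape):
    # 5D shapes fold the two leading dims into one effective batch dim.
    return shape[0] if len(shape) == 4 else shape[0] * shape[1]

def is_channel_reshape(input_shape, output_shape):
    if (len(input_shape), len(output_shape)) not in ((4, 4), (5, 5), (4, 5), (5, 4)):
        return False
    return (_batch(input_shape) != _batch(output_shape)
            or input_shape[-3] != output_shape[-3])

assert is_channel_reshape((2, 3, 8, 8), (2, 6, 8, 4))         # channel count changes
assert not is_channel_reshape((2, 4, 3, 8, 8), (8, 3, 8, 8))  # 5D -> 4D, batch and C preserved
```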

@@ -119,7 +134,11 @@ def insert_input_transpose(node, input_node, graph_module):
torch.ops.passthrough_to_tosa._transpose.default,
args=(
input_node,
list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
list(
AnnotateChannelsLastDimOrder.NNHWC_inverse_order
if len(get_first_fake_tensor(input_node).size()) == 5
else AnnotateChannelsLastDimOrder.NHWC_inverse_order
),
),
quantize=quantize,
q_params=q_params,
@@ -137,15 +156,28 @@ def insert_output_transpose(node, graph_module):
permute_node = create_node(
graph_module.graph,
torch.ops.passthrough_to_tosa._transpose.default,
args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
args=(
node,
list(
AnnotateChannelsLastDimOrder.NNHWC_order
if len(get_first_fake_tensor(node).size()) == 5
else AnnotateChannelsLastDimOrder.NHWC_order
),
),
)
permute_node.meta["tosa_dim_order"] = (
AnnotateChannelsLastDimOrder.NHWC_order
AnnotateChannelsLastDimOrder.NNHWC_order
if len(get_first_fake_tensor(node).size()) == 5
else AnnotateChannelsLastDimOrder.NHWC_order
)
permute_node.meta["val"] = get_first_fake_tensor(node).permute(
AnnotateChannelsLastDimOrder.NNHWC_order
if len(get_first_fake_tensor(node).size()) == 5
else AnnotateChannelsLastDimOrder.NHWC_order
)
permute_node.meta["val"] = node.meta["val"].permute(
AnnotateChannelsLastDimOrder.NHWC_order
node.meta["tosa_dim_order"] = tuple(
range(len(get_first_fake_tensor(node).size()))
)
node.meta["tosa_dim_order"] = (0, 1, 2, 3)
users = [user for user in node.users if user != permute_node]
for user in users:
user.replace_input_with(node, permute_node)
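
Both helpers pick the transpose order from the tensor's rank. A condensed sketch of that selection (illustrative helper name, not in the pass):

```python
def channels_last_order(rank):
    if rank == 5:
        return (0, 1, 3, 4, 2)   # NNHWC_order
    if rank == 4:
        return (0, 2, 3, 1)      # NHWC_order
    return tuple(range(rank))    # lower ranks keep their natural order
```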
@@ -159,8 +191,8 @@ def insert_output_transpose(node, graph_module):
def _insert_view_transpose(
input_shape, output_shape, node, input_node, graph_module
):
nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
output_shape, input_shape
)
@@ -178,11 +210,11 @@ def _insert_view_transpose(

def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
"""
Transposes are needed for operators transforming the input to a different rank, as 4D-tensors are assumed to be in NHWC-format, whereas all other are in NCHW format.
Transposes are needed for operators transforming the input to a different rank, as 4D and 5D tensors are assumed to be in (N)NHWC format, whereas all others are in (N)NCHW format.
This is relevant for the following cases:
- view: <4D -> 4D
- view: 4D -> <4D
Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leadning to one extra input and output transpose for this case.
- view: <4D -> >=4D
- view: >=4D -> <4D
Additionally, a 4D/5D->4D/5D view operation acting on the channel dimension currently needs to be performed in (N)NCHW format, leading to one extra input and output transpose for this case.

Transposes can be avoided for shapes where there is no difference in actual memory, e.g for
- H == W == 1
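
As a concrete case of the rank-changing views described above (editor's example, not from the PR): a view from 3D to 5D crosses the rank-4 boundary, so the pass inserts a passthrough_to_tosa._transpose after the view to move between (N)NCHW and (N)NHWC.

```python
import torch

x = torch.randn(2, 3, 16)      # 3D, kept in NCH order
y = x.view(2, 1, 3, 4, 4)      # 5D, annotated with the NNHWC tosa_dim_order
# The inserted transpose converts the 5D result into NNHWC layout so that
# downstream TOSA operators see a channels-last tensor.
```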
@@ -212,12 +244,13 @@ def call(self, graph_module: torch.fx.GraphModule):
# The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
# dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
dim_order = self.HWCM_order
elif node_data.dim() == 5:
dim_order = self.NNHWC_order # type: ignore[assignment]
else:
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]
node.meta["tosa_dim_order"] = dim_order
# Take care of cases when:
# 4D (NHWC) -> >4D (NCH)
# 3D (NCH) -> 4D (NHWC)
# Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
# See insert_tosa_transposes for insertion conditions.
self.insert_tosa_transposes(graph_module)
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
51 changes: 51 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -28,6 +28,8 @@
from executorch.exir import ExportedProgram
from executorch.exir.backend.utils import WhyNoPartitionReporter
from executorch.exir.dialects._ops import ops as exir_ops

from torch._subclasses.fake_tensor import FakeTensor
from torch.export.graph_signature import InputKind
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -116,6 +118,7 @@ def tosa_support_factory(
negative_checks: list[OperatorSupportBase] = [
CheckInt64Inputs(exported_program, reporter),
CheckFloat64Inputs(exported_program, reporter),
RankCheck(reporter, max_rank=5),
*[
reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
for check in (additional_checks if additional_checks else [])
@@ -474,3 +477,51 @@ def is_node_supported(
)
return False
return True


class RankCheck(OperatorSupportBase):
"""Makes sure that nodes with input or output tensors with rank > max_rank are not partitioned"""

def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int):
self.reporter = reporter
self.max_rank = max_rank
super().__init__()

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
input_nodes = node.all_input_nodes
# check if any input node has an unsupported rank
for input_node in input_nodes:
input_node_shape = get_first_fake_tensor(input_node).shape
if len(input_node_shape) > self.max_rank:
self.reporter.report_reject(
node,
f"{node.name} has input_node {input_node.name} with shape {input_node_shape}, "
f"rank {len(input_node_shape)} which is unsupported. "
f"Max supported rank is {self.max_rank}.",
)
return False

meta_val = node.meta["val"]
if isinstance(
meta_val, (Sequence, torch.fx.immutable_collections.immutable_list)
):
for val in meta_val:
if isinstance(val, FakeTensor):
if len(val.shape) > self.max_rank:
self.reporter.report_reject(
node,
f"{node.name} has a shape {val.shape}, rank {len(val.shape)} which is unsupported."
f"Max supported rank is {self.max_rank}.",
)
return False
elif isinstance(meta_val, FakeTensor):
if len(meta_val.shape) > self.max_rank:
self.reporter.report_reject(
node,
f"{node.name} has shape {meta_val.shape}, rank={len(meta_val.shape)} which is unsupported."
f"Max supported rank is {self.max_rank}.",
)
return False
return True
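
A minimal sketch (hypothetical helper, not part of RankCheck) of the rank rule this class enforces on every node's input and output fake tensors:

```python
def rank_ok(shapes, max_rank=5):
    return all(len(shape) <= max_rank for shape in shapes)

assert rank_ok([(2, 4, 3, 8, 8)])          # 5D is now supported
assert not rank_ok([(1, 2, 4, 3, 8, 8)])   # 6D nodes are rejected and left un-partitioned
```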
10 changes: 1 addition & 9 deletions backends/arm/test/models/test_conformer.py
@@ -57,14 +57,6 @@ def test_conformer_tosa_MI():
exir_op=[],
use_to_edge_transform_and_lower=True,
)
pipeline.change_args(
"run_method_and_compare_outputs",
get_test_inputs(
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
),
rtol=1.0,
atol=5.0,
)
pipeline.run()


@@ -83,7 +75,7 @@ def test_conformer_tosa_BI():
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
),
rtol=1.0,
atol=5.0,
atol=3.0,
)
pipeline.run()

4 changes: 1 addition & 3 deletions backends/arm/test/models/test_deit_tiny_arm.py
@@ -41,8 +41,6 @@ def test_deit_tiny_tosa_MI():
aten_op=[],
exir_op=[],
use_to_edge_transform_and_lower=True,
atol=6.5, # This needs to go down: MLETORCH-940
qtol=1,
)
pipeline.run()

@@ -54,7 +52,7 @@ def test_deit_tiny_tosa_BI():
aten_op=[],
exir_op=[],
use_to_edge_transform_and_lower=True,
atol=3.0, # This needs to go down: MLETORCH-956
atol=2.5, # This needs to go down: MLETORCH-956
qtol=1,
)
pipeline.run()
9 changes: 2 additions & 7 deletions backends/arm/test/models/test_llama.py
@@ -109,17 +109,11 @@ def test_llama_tosa_MI():
exir_op=[],
use_to_edge_transform_and_lower=True,
)
pipeline.change_args(
"run_method_and_compare_outputs",
atol=4.3,
rtol=1.1, # TODO: MLETORCH-825 decrease tolerance
)
pipeline.run()


@pytest.mark.xfail(reason="KeyError: scalar_tensor_1 (MLETORCH-907)")
def test_llama_tosa_BI():
llama_model, llama_inputs, llama_meta = TestLlama.prepare_model()
llama_model, llama_inputs, llama_meta = TestLlama().prepare_model()

if llama_model is None or llama_inputs is None:
pytest.skip("Missing model and/or input files")
@@ -136,5 +130,6 @@ def test_llama_tosa_BI():
"run_method_and_compare_outputs",
atol=9.9,
rtol=1.5, # TODO: Tolerance needs to be updated after MLETORCH-907
inputs=llama_inputs,
)
pipeline.run()
3 changes: 1 addition & 2 deletions backends/arm/test/models/test_nn_modules.py
@@ -56,7 +56,6 @@
@parametrize(
"test_data",
test_parameters,
xfails={"Transformer": "Output 0 does not match reference output."},
)
def test_nn_Modules_MI(test_data):
module, inputs = test_data
@@ -81,7 +80,7 @@ def test_nn_Modules_MI(test_data):
xfails={
"GRU": "RuntimeError: Node aten_linear_default with op <EdgeOpOverload: aten.linear[...]> was not decomposed or delegated.",
"PReLU": "RuntimeError: mul(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.",
"Transformer": "RuntimeError: Expected out tensor to have dtype signed char, but got float",
"Transformer": "AssertionError: Output 0 does not match reference output.",
},
)
def test_nn_Modules_BI(test_data):