# Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

def _transpose_impl(*args, **kwargs):
    # Validate length of dim_order array
    dim = args[1]
-    assert len(dim) <= 4
+    assert len(dim) in (4, 5)
    # Pass-through in edge-IR
    return args[0]

@@ -45,13 +44,15 @@ class AnnotateChannelsLastDimOrder(ExportPass):
    """
    Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
-    when a transition between 3D and 4D tensors happen.
+    when a transition between 3D and 4D/5D tensors happens.
    The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
    """

    NHWC_order = (0, 2, 3, 1)
    NHWC_inverse_order = (0, 3, 1, 2)
    HWCM_order = (2, 3, 0, 1)
+    NNHWC_order = (0, 1, 3, 4, 2)
+    NNHWC_inverse_order = (0, 1, 4, 2, 3)

    def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
        """
@@ -81,8 +82,12 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):

    @staticmethod
    def memory_format_differs(shape):
-        """Returns true if the shape will have a different memory layout in NCHW and NHWC format"""
-        if len(shape) >= 4:
+        """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
+        if len(shape) >= 5:
+            C = shape[2]
+            H = shape[3]
+            W = shape[4]
+        elif len(shape) == 4:
            C = shape[1]
            H = shape[2]
            W = shape[3]
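As a quick illustration of the check above (a standalone sketch with arbitrary example shapes, not code from the pass): the NCHW and NHWC element orders coincide exactly when C == 1 or H == W == 1, which is why those shapes need no transpose.

import math

import torch


def layouts_coincide(shape):
    # Fill an NCHW tensor with its linear index, then compare the flat storage
    # order with the same elements read in NHWC order (permute (0, 2, 3, 1)).
    t = torch.arange(math.prod(shape)).reshape(shape)
    return torch.equal(t.flatten(), t.permute(0, 2, 3, 1).flatten())


print(layouts_coincide((2, 1, 8, 8)))  # True: single channel
print(layouts_coincide((2, 8, 1, 1)))  # True: H == W == 1
print(layouts_coincide((2, 3, 8, 8)))  # False: the layouts really differ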
@@ -98,14 +103,24 @@ def memory_format_differs(shape):
    @staticmethod
    def is_channel_reshape(input_shape, output_shape):
        """Returns true if the reshape changes the channel dimension"""
-        if not len(input_shape) == len(output_shape) == 4:
+        if not (
+            (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
+            or (len(input_shape) == 4 and len(output_shape) == 5)
+            or (len(input_shape) == 5 and len(output_shape) == 4)
+        ):
            return False

-        C_old = input_shape[1]
-        C_new = output_shape[1]
+        C_old = input_shape[-3]
+        C_new = output_shape[-3]

-        N_new = output_shape[0]
-        N_old = input_shape[0]
+        N_new = (
+            output_shape[0]
+            if len(output_shape) == 4
+            else output_shape[0] * output_shape[1]
+        )
+        N_old = (
+            input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
+        )

        return (N_old != N_new) or (C_old != C_new)

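The rank handling above can be restated as a tiny illustrative helper (example shapes only, not code from the pass): for 5D shapes the two leading dims act as one folded batch, and the channel always sits at index -3.

def batch_and_channel(shape):
    # Fold the two leading dims of a 5D (N0, N1, C, H, W) shape into one batch;
    # for a 4D (N, C, H, W) shape the batch is simply the first dim.
    n = shape[0] if len(shape) == 4 else shape[0] * shape[1]
    return n, shape[-3]


print(batch_and_channel((2, 3, 8, 8)))     # (2, 3)
print(batch_and_channel((2, 4, 3, 8, 8)))  # (8, 3)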
@@ -119,7 +134,11 @@ def insert_input_transpose(node, input_node, graph_module):
                torch.ops.passthrough_to_tosa._transpose.default,
                args=(
                    input_node,
-                    list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
+                    list(
+                        AnnotateChannelsLastDimOrder.NNHWC_inverse_order
+                        if len(get_first_fake_tensor(input_node).size()) == 5
+                        else AnnotateChannelsLastDimOrder.NHWC_inverse_order
+                    ),
                ),
                quantize=quantize,
                q_params=q_params,
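A short standalone check (constants copied from this diff, example tensor shape assumed) that NNHWC_inverse_order really undoes NNHWC_order, which is what makes it the right order for the input-side transpose:

import torch

NNHWC_order = (0, 1, 3, 4, 2)
NNHWC_inverse_order = (0, 1, 4, 2, 3)

# Composing the permutation with its inverse must give the identity order.
assert tuple(NNHWC_order[i] for i in NNHWC_inverse_order) == (0, 1, 2, 3, 4)

# Round-tripping a 5D tensor through both orders leaves it unchanged.
x = torch.randn(2, 4, 3, 8, 8)  # (N0, N1, C, H, W)
assert torch.equal(x.permute(NNHWC_order).permute(NNHWC_inverse_order), x)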
@@ -137,15 +156,28 @@ def insert_output_transpose(node, graph_module):
            permute_node = create_node(
                graph_module.graph,
                torch.ops.passthrough_to_tosa._transpose.default,
-                args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
+                args=(
+                    node,
+                    list(
+                        AnnotateChannelsLastDimOrder.NNHWC_order
+                        if len(get_first_fake_tensor(node).size()) == 5
+                        else AnnotateChannelsLastDimOrder.NHWC_order
+                    ),
+                ),
            )
            permute_node.meta["tosa_dim_order"] = (
-                AnnotateChannelsLastDimOrder.NHWC_order
+                AnnotateChannelsLastDimOrder.NNHWC_order
+                if len(get_first_fake_tensor(node).size()) == 5
+                else AnnotateChannelsLastDimOrder.NHWC_order
+            )
+            permute_node.meta["val"] = get_first_fake_tensor(node).permute(
+                AnnotateChannelsLastDimOrder.NNHWC_order
+                if len(get_first_fake_tensor(node).size()) == 5
+                else AnnotateChannelsLastDimOrder.NHWC_order
            )
-            permute_node.meta["val"] = node.meta["val"].permute(
-                AnnotateChannelsLastDimOrder.NHWC_order
+            node.meta["tosa_dim_order"] = tuple(
+                range(len(get_first_fake_tensor(node).size()))
            )
-            node.meta["tosa_dim_order"] = (0, 1, 2, 3)

            users = [user for user in node.users if user != permute_node]
            for user in users:
                user.replace_input_with(node, permute_node)
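For intuition on the added meta["val"] update, a hedged illustration (example shape, not the pass itself): permuting a rank-5 fake tensor by NNHWC_order turns an (N0, N1, C, H, W) shape into (N0, N1, H, W, C), mirroring what (0, 2, 3, 1) does for rank 4.

import torch

fake = torch.empty(2, 4, 3, 8, 8)      # (N0, N1, C, H, W)
nnhwc = fake.permute(0, 1, 3, 4, 2)    # NNHWC_order
print(tuple(nnhwc.shape))              # (2, 4, 8, 8, 3)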
@@ -159,8 +191,8 @@ def insert_output_transpose(node, graph_module):
    def _insert_view_transpose(
        input_shape, output_shape, node, input_node, graph_module
    ):
-        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
-        nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
+        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
+        nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
            output_shape, input_shape
        )
@@ -178,11 +210,11 @@ def _insert_view_transpose(

    def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
        """
-        Transposes are needed for operators transforming the input to a different rank, as 4D-tensors are assumed to be in NHWC-format, whereas all other are in NCHW format.
+        Transposes are needed for operators transforming the input to a different rank, as 4D and 5D tensors are assumed to be in (N)NHWC format, whereas all others are in (N)NCHW format.
        This is relevant for the following cases:
-        - view: <4D -> 4D
-        - view: 4D -> <4D
-        Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leadning to one extra input and output transpose for this case.
+        - view: <4D -> >=4D
+        - view: >=4D -> <4D
+        Additionally, a 4D/5D -> 4D/5D view operation acting on the channel dimension currently needs to be performed in (N)NCHW format, leading to one extra input and output transpose for this case.

        Transposes can be avoided for shapes where there is no difference in actual memory, e.g for
        - H == W == 1
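To make the first case concrete, a small sketch with arbitrary example shapes (not code from the pass): a <4D -> 4D view keeps the data in NCHW memory order, so a transpose is needed before the 4D result can be treated as NHWC.

import torch

x = torch.randn(2, 3, 64)           # 3D, kept in "NCH"-style order
y = x.view(2, 3, 8, 8)              # <4D -> 4D view, still NCHW in memory
y_nhwc = y.permute(0, 2, 3, 1)      # the inserted transpose makes it NHWC
print(tuple(y_nhwc.shape))          # (2, 8, 8, 3)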
@@ -212,12 +244,13 @@ def call(self, graph_module: torch.fx.GraphModule):
                    # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                    # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                    dim_order = self.HWCM_order
+            elif node_data.dim() == 5:
+                dim_order = self.NNHWC_order  # type: ignore[assignment]
            else:
                dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
            node.meta["tosa_dim_order"] = dim_order
-        # Take care of cases when:
-        # 4D (NHWC) -> >4D (NCH)
-        # 3D (NCH) -> 4D (NHWC)
+        # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
+        # See insert_tosa_transposes for insertion conditions.
        self.insert_tosa_transposes(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
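The selection logic in call() can be summarised as a standalone sketch (constants copied from the class; the depthwise-conv weight case using HWCM_order is omitted for brevity):

NHWC_order = (0, 2, 3, 1)
NNHWC_order = (0, 1, 3, 4, 2)


def pick_dim_order(rank):
    # 4D activations get the NHWC order, 5D get NNHWC, and everything else is
    # left in its natural (contiguous) dim order.
    if rank == 4:
        return NHWC_order
    if rank == 5:
        return NNHWC_order
    return tuple(range(rank))


print(pick_dim_order(3))  # (0, 1, 2)
print(pick_dim_order(4))  # (0, 2, 3, 1)
print(pick_dim_order(5))  # (0, 1, 3, 4, 2)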