66import numpy as np
77import torch
88
9- from executorch .backends .nxp .backend .edge_helper import input_tensor , input_tensor_safe
9+ from executorch .backends .nxp .backend .edge_helper import (
10+ input_tensor ,
11+ input_tensor_safe ,
12+ node_is_effectively_static_tensor ,
13+ )
1014from executorch .backends .nxp .backend .ir .converter .conversion import (
1115 aten_translator ,
1216 common ,
1317)
14- from executorch .backends .nxp .backend .ir .converter .conversion .common import (
15- OpsList ,
16- try_get_input ,
17- )
18+ from executorch .backends .nxp .backend .ir .converter .conversion .common import try_get_input
1819from executorch .backends .nxp .backend .ir .converter .node_converter import (
1920 NodeConverter ,
2021 Target ,
2122)
23+ from executorch .backends .nxp .backend .ir .converter .node_converters .shared import (
24+ conv_utils ,
25+ )
26+ from executorch .backends .nxp .backend .ir .converter .node_converters .shared .conv_utils import (
27+ ConvConversionResult ,
28+ ConvParameters ,
29+ )
2230from executorch .backends .nxp .backend .ir .converter .quantization_utils import (
2331 set_quantization_parameters_to_tensor ,
2432)
33+ from executorch .backends .nxp .backend .ir .converter .tensor_utils import tensor_has_data
2534from executorch .backends .nxp .backend .ir .lib .tflite .TensorType import TensorType
2635from executorch .backends .nxp .backend .ir .tflite_generator import tflite_model
2736from executorch .backends .nxp .backend .ir .tflite_generator .builtin_options import (
2837 conv_2d_options ,
38+ depthwise_conv_2d_options ,
2939)
3040from torch .fx import Node
3141from torch .nn import Parameter
@@ -48,7 +58,29 @@ def _is_supported_in_IR(
4858 if output_padding != [0 , 0 ]:
4959 return False
5060
51- if groups != 1 :
61+ if groups == 1 :
62+ # Regular (pointwise) convolution.
63+ pass
64+
65+ elif conv_utils .group_conv_convertible_as_depthwise (
66+ node , groups
67+ ) and node_is_effectively_static_tensor (node .args [1 ], parameters_mapping ):
68+ # Depthwise convolution.
69+ # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted weights. In case
70+ # the weights are dynamic, a Transpose operator would have to be added, which is not supported on Neutron.
71+ pass
72+
73+ elif conv_utils .group_conv_convertible_into_multiple_convolutions (node , groups ):
74+ # Group Separable convolution.
75+ # Not supported natively by the eIQ Neutron.
76+ # In practice it can be computed by splitting the Group Separable Convolution into multiple Pointwise
77+ # Convolutions, which uses the Split and Concat operations. The Concat operation in Neutron Converter
78+ # SDK 25.03 requires the number of channels to be a multiple of the number of MAC units in the eIQ Neutron.
79+ # For this reason Group Separable Convolution is not delegated by default at this moment.
80+ return False
81+
82+ else :
83+ # All conversion options related to the `group` attribute have been checked and none of them can be used.
5284 return False
5385
5486 if input_tensor_safe (node , 2 ) is None :
@@ -57,71 +89,152 @@ def _is_supported_in_IR(
5789 if weight_tensor .dtype not in [torch .float32 , torch .int8 , torch .uint8 ]:
5890 return False
5991
60- return True
61-
62- def _convert_2d_conv (
63- self , stride , padding , dilation , t_op : tflite_model .Operator
64- ) -> list [tflite_model .Operator ]:
65- ops = OpsList (middle_op = t_op )
66- t_op .builtin_options = conv_2d_options .Conv2D ()
67- common .assign_2d_strides (t_op .builtin_options , stride )
68- common .assign_2d_dilations (t_op .builtin_options , dilation )
69- t_op .builtin_options .padding , explicit_padding = (
70- aten_translator .convert_padding (padding )
71- )
92+ if node .args [0 ].meta ["val" ].shape [0 ] != 1 :
93+ # Only batch size 1 is supported on neutron.
94+ return False
7295
73- if explicit_padding is not None :
74- # Need to prepend a 'Pad' operator, which adds 0s. But these will be included in the computation!
75- ops .add_pre (
76- self .builder .create_pad_operator_before (t_op , 0 , explicit_padding )
77- )
96+ return True
7897
79- input_tensor : tflite_model . Tensor = t_op . tmp_inputs [ 0 ]
80- weight_tensor : tflite_model . Tensor = t_op . tmp_inputs [ 1 ]
81- output_tensor : tflite_model . Tensor = t_op . tmp_outputs [ 0 ]
98+ Stride = Padding = Dilation = OutPadding = list [ int ]
99+ Transposed = bool
100+ Groups = int
82101
83- if (bias_tensor := try_get_input (t_op , 2 )) is None :
102+ @staticmethod
103+ def _get_convolution_arguments (
104+ conv_node : Node ,
105+ ) -> (Stride , Padding , Dilation , Transposed , OutPadding , Groups ):
106+ # The arguments of the conv are:
107+ # [x, w, b, stride, padding, dilation, transposed, output padding, groups]
108+ # https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291
109+ _ , _ , _ , stride , padding , dilation , transposed , out_padding , groups = (
110+ conv_node .args
111+ )
112+ return stride , padding , dilation , transposed , out_padding , groups
113+
114+ # noinspection PyPep8Naming
115+ def _convert_unpadded_2D (
116+ self , t_op : tflite_model .Operator , conv_params : ConvParameters
117+ ) -> conv_utils .ConvConversionResult :
118+ """Convert the `aten.convolution` into TFLite. The `padding` and `builtin_options` must be converted by the
119+ caller.
120+ """
121+ common .assign_2d_strides (t_op .builtin_options , conv_params .stride )
122+ common .assign_2d_dilations (t_op .builtin_options , conv_params .dilation )
123+
124+ x : tflite_model .Tensor = t_op .tmp_inputs [0 ]
125+ w : tflite_model .Tensor = t_op .tmp_inputs [1 ]
126+ y : tflite_model .Tensor = t_op .tmp_outputs [0 ]
127+
128+ if (b := try_get_input (t_op , 2 )) is None :
84129 # Operator has no bias. Convolution aten op can omit it, TFLite can't.
85- output_channels = weight_tensor .shape .vector [0 ]
130+ output_channels = w .shape .vector [0 ]
86131
87- if weight_tensor .type == TensorType .FLOAT32 :
132+ if w .type == TensorType .FLOAT32 :
88133 bias_type = np .dtype (np .float32 )
89- elif weight_tensor .type in [TensorType .INT8 , TensorType .UINT8 ]:
134+ elif w .type in [TensorType .INT8 , TensorType .UINT8 ]:
90135 bias_type = np .dtype (np .int32 )
91136 else :
92137 # Should never happen.
93138 raise NotImplementedError (
94- f"Convolution node with unsupported weight type: { weight_tensor .type } "
139+ f"Convolution node with unsupported weight type: { w .type } "
95140 )
96141
97- bias_tensor = self .builder .create_zeros_tensor (
142+ b = self .builder .create_zeros_tensor (
98143 [output_channels ], "zero_bias" , bias_type , True
99144 )
100145
101146 # Compute scale and zero point for bias tensor
102- input_scale = np .array (input_tensor .quantization .scale .vector )
103- weight_scale = np .array (weight_tensor .quantization .scale .vector )
147+ input_scale = np .array (x .quantization .scale .vector )
148+ weight_scale = np .array (w .quantization .scale .vector )
104149 bias_scale = input_scale * weight_scale
105150 bias_zero_point = np .zeros (weight_scale .shape , dtype = np .int64 )
106151
107152 set_quantization_parameters_to_tensor (
108- bias_tensor , bias_scale , bias_zero_point , quantized_dimension = 0
153+ b , bias_scale , bias_zero_point , quantized_dimension = 0
109154 )
110155
111156 # Assign the operator its TFLite inputs and outputs
112- t_op .tmp_inputs = [input_tensor , weight_tensor , bias_tensor ]
113- t_op .tmp_outputs = [output_tensor ]
157+ t_op .tmp_inputs = [x , w , b ]
158+ t_op .tmp_outputs = [y ]
159+
160+ conversion_result = ConvConversionResult (x , w , b , y )
161+ conversion_result .ops_list .middle_op = t_op
162+
163+ return conversion_result
164+
165+ def _convert_2d_conv (
166+ self , t_op : tflite_model .Operator , conv_params : ConvParameters
167+ ) -> list [tflite_model .Operator ]:
168+ if conv_utils .group_conv_convertible_as_depthwise (
169+ t_op , conv_params .groups
170+ ): # Convert to `DepthwiseConv2D`.
171+ t_op .builtin_options = depthwise_conv_2d_options .DepthwiseConv2D ()
172+
173+ conversion_result = self ._convert_unpadded_2D (t_op , conv_params )
174+ t_op .builtin_options .padding , explicit_padding = (
175+ aten_translator .convert_padding (conv_params .padding )
176+ )
177+ if explicit_padding is not None :
178+ # Need to prepend a 'Pad' operator, which adds 0s.
179+ conversion_result .ops_list .add_pre (
180+ self .builder .create_pad_operator_before (t_op , 0 , explicit_padding )
181+ )
182+
183+ # DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels]
184+ perm = [3 , 1 , 2 , 0 ]
185+ weight_tensor = conversion_result .conv_weight_tensor
186+ if tensor_has_data (weight_tensor ):
187+ # Transpose cloned tensor statically
188+ t_op .tmp_inputs [1 ] = self .builder .create_transposed_tensor (
189+ weight_tensor , perm
190+ )
191+ else :
192+ raise NotImplementedError ("Dynamic Depthwise Conv weights." )
193+
194+ elif conv_utils .group_conv_convertible_into_multiple_convolutions (
195+ t_op , conv_params .groups
196+ ):
197+ # Note: by default the Group Separable Convolution is rejected by the Neutron Partitioner, see the
198+ # ConvolutionConverter._is_supported_in_IR()
199+ t_op .builtin_options = conv_2d_options .Conv2D ()
200+
201+ return conv_utils .create_separated_convolutions_based_on_group (
202+ t_op ,
203+ conv_params ,
204+ self .builder ,
205+ self ._convert_unpadded_2D ,
206+ conv_utils .conv_op_factory ,
207+ )
208+
209+ else :
210+ # Convert to regular `Conv2D`.
211+ t_op .builtin_options = conv_2d_options .Conv2D ()
212+ conversion_result = self ._convert_unpadded_2D (t_op , conv_params )
213+ t_op .builtin_options .padding , explicit_padding = (
214+ aten_translator .convert_padding (conv_params .padding )
215+ )
216+ if explicit_padding is not None :
217+ # Need to prepend a 'Pad' operator, which adds 0s.
218+ conversion_result .ops_list .add_pre (
219+ self .builder .create_pad_operator_before (t_op , 0 , explicit_padding )
220+ )
114221
115- return ops .flatten ()
222+ return conversion_result . ops_list .flatten ()
116223
117224 def convert (self , node : Node ):
118225 self .assert_convertible (node )
119226
120- stride = node .args [3 ]
121- padding = node .args [4 ]
122- dilation = node .args [5 ]
227+ stride , padding , dilation , _ , _ , groups = self ._get_convolution_arguments (node )
123228
124229 t_op = self ._create_tflite_op_with_io_tensors (node )
125- ops_to_add = self ._convert_2d_conv (stride , padding , dilation , t_op )
230+ conv_params = ConvParameters (stride , padding , dilation , groups )
231+
232+ rank = t_op .tmp_inputs [1 ].shape .len ()
233+ if rank == 4 : # Conv2D
234+ ops_to_add = self ._convert_2d_conv (t_op , conv_params )
235+ else :
236+ raise NotImplementedError (
237+ f"{ rank - 2 } D convolution is not supported."
238+ ) # Should never get here.
126239
127240 self .builder .append_operators (ops_to_add )
0 commit comments