
NXP backend: Add support for depthwise and separable convolution. #11215

Draft: wants to merge 2 commits into main
33 changes: 33 additions & 0 deletions backends/nxp/backend/edge_helper.py
@@ -5,6 +5,7 @@

import torch
from torch.fx import Node
from torch.nn import Parameter


def input_tensor(node: Node, input_index: int) -> torch.Tensor:
@@ -38,3 +39,35 @@ def input_tensor_safe(node: Node, input_index: int) -> torch.Tensor | None:
return None

return input_tensor(node, input_index)


def node_is_static_tensor(node: Node, parameters_mapping: dict[str, Parameter]) -> bool:
"""Return `True` if the given `node` has static data in the `parameters_mapping` dict.
:param node: Tensor node to check for data.
:param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
`state_dict` attribute of an edge program.
"""
    return node.name in parameters_mapping


def node_is_effectively_static_tensor(
node: Node, parameters_mapping: dict[str, Parameter]
) -> bool:
"""Return `True` if the given `node` has static data, or follows after a `Dequantize` node with a static input.
In the IR, the `node` will be turned into a static quantized tensor.
:param node: Tensor node to check for data.
:param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the
`state_dict` attribute of an edge program.
"""
if node_is_static_tensor(node, parameters_mapping):
return True

    def _is_dequantize(node_: Node) -> bool:
        # `call_function` nodes have a callable `target`; other nodes (e.g. placeholders)
        # have a string `target` with no `__name__` attribute.
        return getattr(node_.target, "__name__", None) in {
            "quantized_decomposed.dequantize_per_tensor.default",
            "quantized_decomposed.dequantize_per_channel.default",
        }

return _is_dequantize(node) and node_is_static_tensor(
node.args[0], parameters_mapping
)
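As a usage illustration only (not part of the diff; `conv_weights_are_static` and `edge_program` are hypothetical names): the mapping is built from the edge program's `state_dict`, and the weight argument of a convolution node is checked before conversion.

from torch.export import ExportedProgram
from torch.fx import Node

from executorch.backends.nxp.backend.edge_helper import (
    node_is_effectively_static_tensor,
)


def conv_weights_are_static(edge_program: ExportedProgram, conv_node: Node) -> bool:
    # Tensor name -> static data. In practice the node names may first have to be
    # resolved through the program's graph signature.
    parameters_mapping = dict(edge_program.state_dict)
    weight_node = conv_node.args[1]  # aten.convolution args: [x, w, b, ...]
    return node_is_effectively_static_tensor(weight_node, parameters_mapping)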
@@ -6,26 +6,36 @@
import numpy as np
import torch

from executorch.backends.nxp.backend.edge_helper import (
input_tensor,
input_tensor_safe,
node_is_effectively_static_tensor,
)
from executorch.backends.nxp.backend.ir.converter.conversion import (
aten_translator,
common,
)
from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input
from executorch.backends.nxp.backend.ir.converter.node_converter import (
NodeConverter,
Target,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.shared import (
conv_utils,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.shared.conv_utils import (
ConvConversionResult,
ConvParameters,
)
from executorch.backends.nxp.backend.ir.converter.quantization_utils import (
set_quantization_parameters_to_tensor,
)
from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data
from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
conv_2d_options,
depthwise_conv_2d_options,
)
from torch.fx import Node
from torch.nn import Parameter
@@ -48,7 +58,24 @@ def _is_supported_in_IR(
if output_padding != [0, 0]:
return False

if groups == 1:
# Regular convolution.
pass

elif conv_utils.group_conv_convertible_as_depthwise(
node, groups
) and node_is_effectively_static_tensor(node.args[1], parameters_mapping):
# Depthwise convolution.
            # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted weights. If the
            # weights were dynamic, a Transpose operator would have to be added, which is not supported on Neutron.
pass

elif conv_utils.group_conv_convertible_into_multiple_convolutions(node, groups):
# Separable convolution. Currently not supported.
return False

else:
# All conversion options related to the `group` attribute have been checked and none of them can be used.
return False

if input_tensor_safe(node, 2) is None:
@@ -57,71 +84,150 @@
if weight_tensor.dtype not in [torch.float32, torch.int8, torch.uint8]:
return False

        if node.args[0].meta["val"].shape[0] != 1:
            # Only batch size 1 is supported on Neutron.
            return False

        return True
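The `conv_utils` predicates used in the check above are not part of this diff. A rough sketch of the conditions they are expected to encode (assumed from standard grouped-convolution semantics; the `_sketch` names are hypothetical):

def group_conv_convertible_as_depthwise_sketch(input_channels: int, groups: int) -> bool:
    # Depthwise case: every input channel forms its own group.
    return groups == input_channels


def group_conv_convertible_into_multiple_convolutions_sketch(groups: int) -> bool:
    # General grouped conv: it can be split into `groups` regular convolutions
    # over channel slices, which is the separable path handled further below.
    return groups > 1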
Stride = Padding = Dilation = OutPadding = list[int]
Transposed = bool
Groups = int

@staticmethod
def _get_convolution_arguments(
conv_node: Node,
) -> (Stride, Padding, Dilation, Transposed, OutPadding, Groups):
# The arguments of the conv are:
# [x, w, b, stride, padding, dilation, transposed, output padding, groups]
# https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291
_, _, _, stride, padding, dilation, transposed, out_padding, groups = (
conv_node.args
)
return stride, padding, dilation, transposed, out_padding, groups
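For concreteness, the unpacking above yields values like these (illustrative example, not from the PR):

# For a module like torch.nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1, groups=4),
# the exported `aten.convolution` node would carry roughly:
#   stride=[2, 2], padding=[1, 1], dilation=[1, 1],
#   transposed=False, out_padding=[0, 0], groups=4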

# noinspection PyPep8Naming
def _convert_unpadded_2D(
self, t_op: tflite_model.Operator, conv_params: ConvParameters
) -> conv_utils.ConvConversionResult:
"""Convert the `aten.convolution` into TFLite. The `padding` and `builtin_options` must be converter by the
caller.
"""
common.assign_2d_strides(t_op.builtin_options, conv_params.stride)
common.assign_2d_dilations(t_op.builtin_options, conv_params.dilation)

x: tflite_model.Tensor = t_op.tmp_inputs[0]
w: tflite_model.Tensor = t_op.tmp_inputs[1]
y: tflite_model.Tensor = t_op.tmp_outputs[0]

if (b := try_get_input(t_op, 2)) is None:
            # Operator has no bias. Convolution aten op can omit it, TFLite can't.
            output_channels = w.shape.vector[0]

            if w.type == TensorType.FLOAT32:
                bias_type = np.dtype(np.float32)
            elif w.type in [TensorType.INT8, TensorType.UINT8]:
                bias_type = np.dtype(np.int32)
            else:
                # Should never happen.
                raise NotImplementedError(
                    f"Convolution node with unsupported weight type: {w.type}"
                )

            b = self.builder.create_zeros_tensor(
                [output_channels], "zero_bias", bias_type, True
            )

            # Compute scale and zero point for bias tensor
            input_scale = np.array(x.quantization.scale.vector)
            weight_scale = np.array(w.quantization.scale.vector)
            bias_scale = input_scale * weight_scale
            bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64)

            set_quantization_parameters_to_tensor(
                b, bias_scale, bias_zero_point, quantized_dimension=0
            )

        # Assign the operator its TFLite inputs and outputs
        t_op.tmp_inputs = [x, w, b]
        t_op.tmp_outputs = [y]

conversion_result = ConvConversionResult(x, w, b, y)
conversion_result.ops_list.middle_op = t_op

return conversion_result
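The bias handling above follows the common TFLite quantization convention: the synthetic int32 bias uses `bias_scale = input_scale * weight_scale` with a zero point of 0, so the convolution's integer accumulator can add it directly. A small numeric sketch (values are illustrative):

import numpy as np

input_scale = np.array([0.02])            # per-tensor input scale
weight_scale = np.array([0.5, 0.25])      # per-channel weight scales
bias_scale = input_scale * weight_scale   # -> [0.01, 0.005]

real_bias = np.array([0.07, -0.02])
q_bias = np.round(real_bias / bias_scale).astype(np.int32)  # -> [7, -4]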

def _convert_2d_conv(
self, t_op: tflite_model.Operator, conv_params: ConvParameters
) -> list[tflite_model.Operator]:
if conv_utils.group_conv_convertible_as_depthwise(
t_op, conv_params.groups
): # Convert to `DepthwiseConv2D`.
t_op.builtin_options = depthwise_conv_2d_options.DepthwiseConv2D()

conversion_result = self._convert_unpadded_2D(t_op, conv_params)
t_op.builtin_options.padding, explicit_padding = (
aten_translator.convert_padding(conv_params.padding)
)
if explicit_padding is not None:
# Need to prepend a 'Pad' operator, which adds 0s.
conversion_result.ops_list.add_pre(
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
)

# DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels]
perm = [3, 1, 2, 0]
weight_tensor = conversion_result.conv_weight_tensor
if tensor_has_data(weight_tensor):
# Transpose cloned tensor statically
t_op.tmp_inputs[1] = self.builder.create_transposed_tensor(
weight_tensor, perm
)
            else:
                # `_is_supported_in_IR` only permits the depthwise path when the weights
                # are (effectively) static, so this branch should never be reached.
                raise NotImplementedError("Dynamic Depthwise Conv weights.")

elif conv_utils.group_conv_convertible_into_multiple_convolutions(
t_op, conv_params.groups
):
t_op.builtin_options = conv_2d_options.Conv2D()

return conv_utils.create_separated_convolutions_based_on_group(
t_op,
conv_params,
self.builder,
self._convert_unpadded_2D,
conv_utils.conv_op_factory,
)

else:
# Convert to regular `Conv2D`.
t_op.builtin_options = conv_2d_options.Conv2D()
conversion_result = self._convert_unpadded_2D(t_op, conv_params)
t_op.builtin_options.padding, explicit_padding = (
aten_translator.convert_padding(conv_params.padding)
)
if explicit_padding is not None:
# Need to prepend a 'Pad' operator, which adds 0s.
conversion_result.ops_list.add_pre(
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
)

        return conversion_result.ops_list.flatten()
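To illustrate the static weight permutation in the depthwise branch (a standalone numpy sketch; it assumes the weights arrive in the TFLite `Conv2D` layout `[out_channels, kernel_height, kernel_width, 1]`):

import numpy as np

w = np.zeros((8, 3, 3, 1), dtype=np.float32)  # [out_C, kH, kW, 1]
perm = [3, 1, 2, 0]
w_dw = np.transpose(w, perm)  # -> [1, kH, kW, out_C]
assert w_dw.shape == (1, 3, 3, 8)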

def convert(self, node: Node):
self.assert_convertible(node)

        stride, padding, dilation, _, _, groups = self._get_convolution_arguments(node)

t_op = self._create_tflite_op_with_io_tensors(node)
conv_params = ConvParameters(stride, padding, dilation, groups)

rank = t_op.tmp_inputs[1].shape.len()
if rank == 4: # Conv2D
ops_to_add = self._convert_2d_conv(t_op, conv_params)
else:
raise NotImplementedError(
f"{rank - 2}D convolution is not supported."
) # Should never get here.

self.builder.append_operators(ops_to_add)