Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/nxp/backend/edge_program_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
exir_ops.edge.aten.addmm.default: AddMMConverter, # noqa F405
exir_ops.edge.aten.add.Tensor: AddTensorConverter, # noqa F405
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
exir_ops.edge.aten.bmm.default: BMMConverter, # noqa F405
exir_ops.edge.aten.cat.default: CatConverter, # noqa F405
exir_ops.edge.aten.clamp.default: ClampConverter, # noqa F405
exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.avg_pool_2d_converter import (
AvgPool2dConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.bmm_converter import (
BMMConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.cat_converter import (
CatConverter,
)
Expand Down Expand Up @@ -96,6 +99,7 @@
"AddMMConverter",
"AddTensorConverter",
"AvgPool2dConverter",
"BMMConverter",
"CatConverter",
"ClampConverter",
"CloneConverter",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT
from executorch.backends.nxp.backend.edge_helper import input_rank
from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
from executorch.backends.nxp.backend.ir.converter.node_converter import (
CustomDelegationOptions,
NodeConverter,
)
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
batch_mat_mul_options,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torch.fx import Node
from torch.nn import Parameter


class BMMConverter(NodeConverter):
@staticmethod
def _is_supported_in_IR(
node: Node,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if len(node.all_input_nodes) != 2:
return False

if input_rank(node, 0) != 3 or input_rank(node, 1) != 3:
return False

return True

@staticmethod
def _get_channels_last_shape(node: Node) -> list[int]:
input_shape = node.meta["val"].shape

if node.meta[NXP_NODE_FORMAT].is_channels_first():
input_shape = translator.apply_permutation_to(
input_shape,
translator.create_channels_first_to_channels_last_permutation(
len(input_shape)
),
)

return input_shape

@staticmethod
def _is_supported_on_target(
node: Node,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
_, w1, c1 = BMMConverter._get_channels_last_shape(node.args[0])
_, w2, c2 = BMMConverter._get_channels_last_shape(node.args[1])

num_macs = neutron_target_spec.get_num_macs()

# The Neutron converter requires that every dimension participating in a
# multiplication is divisible by NUM_MACS. If any of the relevant dimensions
# (w1, c1, w2, c2) violates this constraint, the pattern is not supported.
if not all(m % num_macs == 0 for m in [w1, c1, w2, c2]):
return False

return True

def convert(self, node: Node):
"""Convert the `aten.bmm` operator to TFLite `BatchMatMul`."""
self.assert_convertible(node)

t_op = self._create_tflite_op_with_io_tensors(node)

# We set adj_x = adj_y = False because neither the left-hand side (lhs) nor
# the right-hand side (rhs) needs to be transposed for correct delegation.
#
# We also set asymmetric_quantize_inputs = False. This is faster, but it
# requires that both input tensors are quantized symmetrically.
t_op.builtin_options = batch_mat_mul_options.BatchMatMul(False, False, False)

x1 = t_op.tmp_inputs[0]
x2 = t_op.tmp_inputs[1]
y = t_op.tmp_outputs[0]

# Assign the operator its TFLite inputs and outputs
t_op.tmp_inputs = [x1, x2]
t_op.tmp_outputs = [y]

ops = OpsList(middle_op=t_op)

self.builder.append_operators(ops.flatten())
1 change: 1 addition & 0 deletions backends/nxp/neutron_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]):
exir_ops.edge.aten.addmm.default: AddMMConverter, # noqa F405
exir_ops.edge.aten.add.Tensor: AddTensorConverter, # noqa F405
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
exir_ops.edge.aten.bmm.default: BMMConverter, # noqa F405
exir_ops.edge.aten.cat.default: CatConverter, # noqa F405
exir_ops.edge.aten.clamp.default: ClampConverter, # noqa F405
exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405
Expand Down
2 changes: 2 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AddTensorPattern,
AvgPoolPattern,
BatchNormPattern,
BMMPattern,
CatPattern,
ClampPattern,
Conv1dPattern,
Expand Down Expand Up @@ -259,6 +260,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(BatchNormPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(BMMPattern(self, is_qat=is_qat), static_qconfig),
OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(ClampPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
Expand Down
47 changes: 47 additions & 0 deletions backends/nxp/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from torch.fx import Node
from torchao.quantization.pt2e import (
FakeQuantize,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
)
Expand Down Expand Up @@ -298,6 +300,51 @@ def get_anchors(
)


class BMMPattern(QuantizationPattern):
"""
Quantizer for BatchMatMul operator.
"""

def __init__(self, neutron_quantizer, is_qat: bool = False):
super().__init__(is_qat=is_qat)

self.neutron_quantizer = neutron_quantizer
self.neutron_target_info = (
self.neutron_quantizer.neutron_target_spec.neutron_target_info
)

def partition_types(self) -> list[torch.nn.Module]:
return [torch.ops.aten.bmm.default]

def _make_qspec(self):
observer = (
FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
if self.is_qat
else MinMaxObserver
)
return QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
observer_or_fake_quant_ctr=observer,
)

def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors | None:
bmm_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
inputs=[
(bmm_node, NodeArgsIdx(0), self._make_qspec()),
(bmm_node, NodeArgsIdx(1), self._make_qspec()),
],
biases=[],
output=[(bmm_node,)],
)


class SubTensorPattern(QuantizationPattern):
"""
Quantization pattern for Sub Tensor quantization. Accepts 1 or 2 input nodes.
Expand Down
144 changes: 144 additions & 0 deletions backends/nxp/tests/ir/converter/node_converter/test_bmm_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import pytest
import torch
from executorch.backends.nxp.backend.edge_program_converter import (
EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
convert_run_compare,
graph_contains_any_of_ops,
)
from executorch.backends.nxp.tests.models import BatchMatMulConvModel, BatchMatMulModel
from executorch.backends.nxp.tests.use_qat import * # noqa F403
from executorch.exir.dialects._ops import ops as exir_ops


@pytest.fixture(autouse=True)
def reseed_model_per_test_run():
torch.manual_seed(23)
np.random.seed(23)


# noinspection PyProtectedMember
ExecutorchDelegateCall = torch.ops.higher_order.executorch_call_delegate
Bmm = exir_ops.edge.aten.bmm.default


@pytest.mark.parametrize(
"input_shape_x1, input_shape_x2",
[
pytest.param((1, 8, 16), (1, 16, 24), id="3D, one batch."),
pytest.param((4, 8, 16), (4, 16, 24), id="3D, more batches."),
],
)
def test_convert_bmm__supported(mocker, input_shape_x1, input_shape_x2):
model = BatchMatMulModel()

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
delegated_ep = to_quantized_edge_program(
model, [input_shape_x1, input_shape_x2], use_qat=use_qat,
).exported_program()

# Make sure the `bmm` was delegated.
assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall])
assert not graph_contains_any_of_ops(delegated_ep.graph, [Bmm])

# Verify correct behavior of the converted NeutronIR model.
intermediate_ep = converter_spy.call_args.args[1]
neutron_ir_model, _ = converter_spy.spy_return

input_data_1 = (
np.random.random(input_shape_x1).astype(np.float32) * 256.0 - 128.0
).astype(np.int8)
input_data_2 = (
np.random.random(input_shape_x2).astype(np.float32) * 256.0 - 128.0
).astype(np.int8)

# Make sure the tested program contains the `bmm`.
assert graph_contains_any_of_ops(intermediate_ep.graph, [Bmm])

# Verify that the delegated `bmm` node produces correct results
# The delegated `bmm` runs with a numerical tolerance of atol = 1
convert_run_compare(
intermediate_ep,
tfl_model=neutron_ir_model,
input_data={
0: input_data_1,
1: input_data_2,
},
atol=1,
)


@pytest.mark.parametrize(
"input_shape_x1, input_shape_x2",
[
pytest.param((1, 7, 16), (1, 16, 24), id="3D, x1_C not divisible by NUM_MACS."),
pytest.param(
(1, 8, 7), (1, 7, 24), id="3D, x1_W (and x2_C) not divisible by NUM_MACS."
),
pytest.param((1, 8, 16), (1, 16, 7), id="3D, x2_W not divisible by NUM_MACS."),
],
)
def test_convert_bmm__unsupported(input_shape_x1, input_shape_x2):
model = BatchMatMulModel()

delegated_ep = to_quantized_edge_program(
model, [input_shape_x1, input_shape_x2], use_qat=use_qat,
).exported_program()

# Make sure the `bmm` was NOT delegated.
assert graph_contains_any_of_ops(delegated_ep.graph, [Bmm])


@pytest.mark.parametrize(
"conv_input_shape, bmm_input_shape",
[
pytest.param((4, 8, 16), (4, 16, 16), id="3D with conv. quant"),
],
)
def test_convert_bmm__conv_quant(mocker, conv_input_shape, bmm_input_shape):
conv_channels = conv_input_shape[1]
bmm_channels = bmm_input_shape[1]
model = BatchMatMulConvModel(in_channels=conv_channels, out_channels=bmm_channels)

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
delegated_ep = to_quantized_edge_program(
model, [conv_input_shape, bmm_input_shape], use_qat=use_qat,
).exported_program()

# Make sure the `bmm` was delegated.
assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall])
assert not graph_contains_any_of_ops(delegated_ep.graph, [Bmm])

# Verify correct behavior of the converted NeutronIR model.
bmm_intermediate_ep = converter_spy.call_args.args[1]
bmm_neutron_ir_model, _ = converter_spy.spy_return

bmm_input_data_1 = (
np.random.random(bmm_input_shape).astype(np.float32) * 256.0 - 128.0
).astype(np.int8)
bmm_input_data_2 = (
np.random.random(bmm_input_shape).astype(np.float32) * 256.0 - 128.0
).astype(np.int8)

# Make sure the tested program contains the `bmm`.
assert graph_contains_any_of_ops(bmm_intermediate_ep.graph, [Bmm])

# Verify that the delegated `bmm` node produces correct results
# The delegated `bmm` runs with a numerical tolerance of atol = 1
convert_run_compare(
bmm_intermediate_ep,
tfl_model=bmm_neutron_ir_model,
input_data={
0: bmm_input_data_1,
1: bmm_input_data_2,
},
atol=1,
)
24 changes: 24 additions & 0 deletions backends/nxp/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,3 +800,27 @@ def forward(self, x, y):
return torch.squeeze(x + y)
else:
return torch.squeeze(x + y, self.dim)


class BatchMatMulModel(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.bmm(x, y)


class BatchMatMulConvModel(torch.nn.Module):
def __init__(self, in_channels=16, out_channels=8):
super().__init__()
self.conv = Conv1dModule(
in_channels=in_channels,
out_channels=out_channels,
stride=1,
padding=1,
kernel_size=3,
)

def forward(self, x, y):
x = self.conv(x)
return torch.bmm(x, y)
1 change: 1 addition & 0 deletions docs/source/backends/nxp/op-support.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ aten._adaptive_avg_pool2d.default,int8,static int8,"ceil_mode=False, count_inclu
aten.addmm.default,int8,static int8,2D tensor only
aten.add.Tensor,int8,static int8,"alpha = 1, input tensor of name rank"
aten.avg_pool2d.default,int8,static int8,"ceil_mode=False, count_include_pad=False, divisor_override=False"
aten.bmm.default,int8,static int8,"width and channels dim of both args %8 = 0"
aten.cat.default,int8,static int8,"input_channels % 8 = 0, output_channels %8 = 0"
aten.clamp.default,int8,static int8,"Bounds = (-1, 1) or (0, 1) or (0, 6) or (0, None)"
aten.clone.default,int8,static int8,
Expand Down
Loading