NXP backend: Create NeutronAtenPassManager with initial BatchNorm fusing passes #10579

Merged
146 changes: 146 additions & 0 deletions backends/nxp/aten_passes/fuse_batch_norm_with_conv_pass.py
@@ -0,0 +1,146 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_conv_bn_weights


class FuseBatchNormWithConvPass(PassBase):
"""The executorch batch normalization carries out the following computation [1].

(x - mean) / sqrt(var + eps) * W + B

Which can be expressed as

x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
and the terms can be precomputed. If there is a `Conv` operator before the batch normalization, this scale and
bias can be statically integrated into the weights and bias of the `Conv`, which allows the batch norm to be
completely removed.


┌─────────────▼─────────────┐
│ aten.conv1d | aten.conv2d │
└─────────────┬─────────────┘
│ │
┌─────────────────────▼─────────────────────┐ replace with ┌─────────────▼─────────────┐
│ aten.batch_norm │ ──────────────► │ aten.conv1d | aten.conv2d │
└─────────────────────┬─────────────────────┘ └─────────────┬─────────────┘
│ ▼

[1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
"""

def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None:
"""Get the static data from a given node. If it doesn't have any data, return `None`."""
if node is None or node.op != "get_attr":
return None

target_atoms = node.target.split(".")
attr_itr = graph_module
for atom in target_atoms:
if not hasattr(attr_itr, atom):
return None
attr_itr = getattr(attr_itr, atom)
return attr_itr

def call(self, graph_module: GraphModule) -> Optional[PassResult]:
def _is_batch_norm(node_: Node) -> bool:
return (
node_.op == "call_function"
and node_.target == torch.ops.aten.batch_norm.default
)

        def _is_conv(node_: Node) -> bool:
            is_conv = node_.op == "call_function" and node_.target in (
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
            )
            has_single_user = len(node_.users) == 1

            return is_conv and has_single_user

made_changes = False

if not any(map(_is_batch_norm, graph_module.graph.nodes)):
return PassResult(
graph_module, made_changes
) # No batch norm nodes in the model.

for node in graph_module.graph.nodes:
if not _is_batch_norm(node):
continue # Not BatchNorm.

bn_node = node

if not _is_conv(bn_node.args[0]):
continue # Something other than a Conv node comes before the BatchNorm.

conv_node = bn_node.args[0]
conv_weight_node = conv_node.args[1]
conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None

# conv args: input, weight, bias, stride, padding, dilation, ...
conv_w = self._get_tensor_constant_from_node(graph_module, conv_weight_node)
conv_b = self._get_tensor_constant_from_node(graph_module, conv_bias_node)

            # batch norm args: input, weight, bias, running mean, running var, training, momentum, eps
bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1])
bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2])
bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3])
bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4])
bn_eps = bn_node.args[7]

if any(
t is None for t in (conv_w, bn_rm, bn_rv)
): # The other inputs can be None.
continue # The data is not static. Leave this BatchNorm as is (probably a rare case).
fused_weight, fused_bias = fuse_conv_bn_weights(
conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
)

# Update the weight and bias for Conv.
conv_args = list(conv_node.args)
if len(conv_args) == 2:
# Fill in the default bias argument.
conv_args.append(None)

weight_attr_name = conv_weight_node.target
            _assign_attr(
                fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER
            )

@digantdesai (Contributor, May 6, 2025): You can do this without using internal methods, and instead insert new params and delete old params: https://github.com/pytorch/executorch/blob/main/backends/transforms/utils.py#L65 (assuming you are doing this after export).

Author (Collaborator): That would require propagating the ExportedProgram (via the constructor?) to the pass. I don't find that particularly clean, because passes are built upon GraphModules, as you mentioned in some of the previous PRs. We would end up with the ExportedProgram being passed via the constructor and the GraphModule within the call() function, which doesn't sound right to me.

@digantdesai: Yeah, I guess this is for when you do run this pass in delegate.preprocess.

if conv_bias_node is not None:
bias_attr_name = conv_bias_node.target
_assign_attr(
fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER
)
else:
# The Conv doesn't have a bias. Create a new one.
bias_attr_name = weight_attr_name + "_bias"
_assign_attr(
fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER
)
with graph_module.graph.inserting_before(conv_node):
get_bias_node = graph_module.graph.get_attr(bias_attr_name)

conv_args[2] = get_bias_node

conv_node.args = tuple(conv_args)

# Replace the uses of the BatchNorm with the Conv.
bn_node.replace_all_uses_with(conv_node)

made_changes = True

return PassResult(graph_module, made_changes)
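
To make the docstring's rewrite concrete, here is a minimal sanity-check sketch (an editor's illustration, not part of this PR; module shapes are arbitrary) showing that `torch.nn.utils.fuse_conv_bn_weights` — the same helper the pass calls — yields a Conv equivalent to the original Conv + BatchNorm in eval mode:

```python
import torch
from torch.nn.utils import fuse_conv_bn_weights

# A Conv2d followed by a BatchNorm2d in eval mode (static running stats).
conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False)
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # give the stats non-trivial values
bn.running_var.uniform_(0.5, 1.5)

# Same helper and argument order as in the pass above.
fused_w, fused_b = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var,
    bn.eps, bn.weight, bn.bias,
)

fused_conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True)
fused_conv.weight = fused_w
fused_conv.bias = fused_b

x = torch.randn(1, 3, 16, 16)
# The fused Conv reproduces Conv + BatchNorm up to floating-point error.
assert torch.allclose(bn(conv(x)), fused_conv(x), atol=1e-5)
```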
150 changes: 150 additions & 0 deletions backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py
@@ -0,0 +1,150 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_linear_bn_weights


class FuseBatchNormWithLinearPass(PassBase):
"""The executorch batch normalization carries out the following computation [1].

(x - mean) / sqrt(var + eps) * W + B

Which can be expressed as

x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
and the terms can be precomputed. If there is a `Linear` operator before the batch normalization, this scale
and bias can be statically integrated into the weights and bias of the `Linear`, which allows the batch norm
to be completely removed.


┌──────▼──────┐
│ aten.linear │
└──────┬──────┘
│ │
┌─────────────────────▼─────────────────────┐ replace with ┌──────▼──────┐
│ aten.batch_norm │ ──────────────► │ aten.linear │
└─────────────────────┬─────────────────────┘ └──────┬──────┘

[1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
"""

def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None:
"""Get the static data from a given node. If it doesn't have any data, return `None`."""
if node is None or node.op != "get_attr":
return None

target_atoms = node.target.split(".")
attr_itr = graph_module
for atom in target_atoms:
if not hasattr(attr_itr, atom):
return None
attr_itr = getattr(attr_itr, atom)
return attr_itr

def call(self, graph_module: GraphModule) -> Optional[PassResult]:
def _is_batch_norm(node_: Node) -> bool:
return (
node_.op == "call_function"
and node_.target == torch.ops.aten.batch_norm.default
)

        def _is_linear(node_: Node) -> bool:
            is_linear = (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.linear.default
            )
            has_single_user = len(node_.users) == 1

            return is_linear and has_single_user

made_changes = False

if not any(map(_is_batch_norm, graph_module.graph.nodes)):
return PassResult(
graph_module, made_changes
) # No batch norm nodes in the model.

for node in graph_module.graph.nodes:
if not _is_batch_norm(node):
continue # Not BatchNorm.

bn_node = node

if not _is_linear(bn_node.args[0]):
                continue  # Something other than a Linear node comes before the BatchNorm.

Contributor: Here and above, you can be defensive and also make sure args len == 1. There is nothing stopping someone from adding a skip connection after the conv/linear and before the batch norm, or some other op. Can't say I've seen it in practice, though.

Author (Collaborator): Added a check for the number of users of the Conv/Linear op.

linear_node = bn_node.args[0]
linear_weight_node = linear_node.args[1]
linear_bias_node = (
linear_node.args[2] if len(linear_node.args) > 2 else None
)

linear_w = self._get_tensor_constant_from_node(
graph_module, linear_weight_node
)
linear_b = self._get_tensor_constant_from_node(
graph_module, linear_bias_node
)

            # batch norm args: input, weight, bias, running mean, running var, training, momentum, eps
bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1])
bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2])
bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3])
bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4])
bn_eps = bn_node.args[7]

if any(
t is None for t in (linear_w, bn_w, bn_b, bn_rm, bn_rv)
): # The Linear bias can be None.
continue # The data is not static. Leave this BatchNorm as is (probably a rare case).
fused_weight, fused_bias = fuse_linear_bn_weights(
linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
)

# Update the weight and bias for Linear.
linear_args = list(linear_node.args)
if len(linear_args) == 2:
# Fill in the default bias argument.
linear_args.append(None)

weight_attr_name = linear_weight_node.target
_assign_attr(
fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER
)

if linear_bias_node is not None:
bias_attr_name = linear_bias_node.target
_assign_attr(
fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER
)
else:
# The Linear doesn't have a bias. Create a new one.
bias_attr_name = weight_attr_name + "_bias"
_assign_attr(
fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER
)
with graph_module.graph.inserting_before(linear_node):
get_bias_node = graph_module.graph.get_attr(bias_attr_name)

linear_args[2] = get_bias_node

linear_node.args = tuple(linear_args)

# Replace the uses of the BatchNorm with the Linear.
bn_node.replace_all_uses_with(linear_node)

made_changes = True

return PassResult(graph_module, made_changes)
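
A hypothetical end-to-end sketch of this pass (editor's illustration, not part of this PR; the `LinearBN` module, its shapes, and the `export_for_training` flow are assumptions mirroring how `transform_for_annotation` receives an aten-dialect GraphModule):

```python
import torch
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)

class LinearBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 32)
        self.bn = torch.nn.BatchNorm1d(32)

    def forward(self, x):
        return self.bn(self.linear(x))

example_input = (torch.randn(4, 16),)
# The pass operates on an aten-dialect GraphModule with get_attr nodes
# for the weights, re-created here by exporting and unlifting.
module = torch.export.export_for_training(LinearBN().eval(), example_input).module()

result = FuseBatchNormWithLinearPass()(module)
assert result is not None and result.modified
# No aten.batch_norm nodes remain; the Linear absorbed the normalization.
assert all(
    node.target != torch.ops.aten.batch_norm.default
    for node in result.graph_module.graph.nodes
)
```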
40 changes: 40 additions & 0 deletions backends/nxp/aten_passes/neutron_aten_pass_manager.py
@@ -0,0 +1,40 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch

from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import (
    FuseBatchNormWithConvPass,
)

Contributor: nit: this can become unwieldy, since there are a few other permutations of "batch norm" fusion; maybe just one file for fuse_batch_norm would be good.

Author (Collaborator): Agreed. Maybe we can create subdirectories in backends/transforms based on which dialect/IR a pass targets?
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
FuseBatchNormWithLinearPass,
)
from executorch.exir.pass_manager import PassManager
from torch import nn
from torch.fx.passes.infra.pass_base import PassResult

PassType = Callable[[torch.fx.GraphModule], PassResult]


class NeutronAtenPassManager(PassManager):

    def __init__(self, passes: list[PassType] | None = None):
        passes = passes or [
            FuseBatchNormWithConvPass(),
            FuseBatchNormWithLinearPass(),
        ]

super().__init__(passes)

def __call__(self, module: nn.Module) -> PassResult:
pass_result: PassResult = super().__call__(module)

graph_module = pass_result.graph_module
graph_module.graph.eliminate_dead_code()
graph_module.recompile()

return pass_result
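
A usage sketch for the manager (editor's illustration; the `ConvBN` module and shapes are made up), mirroring how the quantizer invokes it in the diff below:

```python
import torch
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)

class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

example_input = (torch.randn(1, 3, 32, 32),)
module = torch.export.export_for_training(ConvBN().eval(), example_input).module()

# Runs both fusion passes, then eliminates dead code and recompiles.
result = NeutronAtenPassManager()(module)
print(result.modified)  # True: the BatchNorm was folded into the Conv
result.graph_module.print_readable()
```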
6 changes: 5 additions & 1 deletion backends/nxp/quantizer/neutron_quantizer.py
@@ -7,6 +7,9 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
+    NeutronAtenPassManager,
+)
 
 from executorch.backends.nxp.quantizer.patterns import (
     AddmmPattern,
@@ -202,4 +205,5 @@ def __init__(self):
     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
-        return model
+        pass_runner = NeutronAtenPassManager()
+        return pass_runner(model).graph_module

Contributor: Thanks!
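
Because `prepare_pt2e` calls `transform_for_annotation`, the BatchNorm fusion now runs automatically during PT2E quantization. A sketch of that flow (editor's illustration; the model and shapes are assumptions):

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
).eval()
example_input = (torch.randn(1, 3, 32, 32),)

aten_module = torch.export.export_for_training(model, example_input).module()

# prepare_pt2e invokes NeutronQuantizer.transform_for_annotation, which now
# fuses the BatchNorm into the Conv before annotation.
prepared = prepare_pt2e(aten_module, NeutronQuantizer())
prepared(*example_input)  # calibration
quantized = convert_pt2e(prepared)
```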
13 changes: 2 additions & 11 deletions backends/nxp/tests/executorch_pipeline.py
@@ -8,9 +8,6 @@
 from executorch import exir
 from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
 from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
-
-# TODO (Robert Kalmar) Uncomment when NXP passes are ported to main
-# from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass_manager import NXPPyTorchPassManager
 from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
 from executorch.exir import (
     EdgeCompileConfig,
@@ -27,7 +24,7 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
     quantizer = NeutronQuantizer()
 
     m = prepare_pt2e(model, quantizer)
-    for _i, data in enumerate(calibration_inputs):
+    for data in calibration_inputs:
         m(*data)
     m = convert_pt2e(m)
 
@@ -48,14 +45,8 @@ def to_quantized_edge_program(
         model, example_input, strict=True
     )
 
-    # TODO(Robert Kalmar) uncoment when NXP passes are ported to main
-    # Run pre-processing passes of the float32 aten dialect program.
-    # pass_manager = NXPPyTorchPassManager(exir_program_aten)
-    # pass_manager.run() # All passes by default.
-
-    exir_program_aten_module = exir_program_aten.module()
     exir_program_aten__module_quant = _quantize_model(
-        exir_program_aten_module, calibration_inputs
+        exir_program_aten.module(), calibration_inputs
     )
 
     compile_spec = generate_neutron_compile_spec(