
Commit 360c9bb

NXP backend: Create NeutronAtenPassManager with initial BatchNorm fusing passes (#10579)
1 parent 33d4790 commit 360c9bb

File tree

6 files changed: +558 -12 lines changed

backends/nxp/aten_passes/fuse_batch_norm_with_conv_pass.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_conv_bn_weights


class FuseBatchNormWithConvPass(PassBase):
    """The executorch batch normalization carries out the following computation [1].

        (x - mean) / sqrt(var + eps) * W + B

    Which can be expressed as

        x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

    So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
    and the terms can be precomputed. If there is a `Conv` operator before the batch normalization, this scale and
    bias can be statically integrated into the weights and bias of the `Conv`, which allows the batch norm to be
    completely removed.


            ┌─────────────▼─────────────┐
            │ aten.conv1d | aten.conv2d │
            └─────────────┬─────────────┘
                          │                                                        │
    ┌─────────────────────▼─────────────────────┐    replace with     ┌─────────────▼─────────────┐
    │              aten.batch_norm              │   ──────────────►   │ aten.conv1d | aten.conv2d │
    └─────────────────────┬─────────────────────┘                     └─────────────┬─────────────┘
                          │                                                          ▼

    [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None:
        """Get the static data from a given node. If it doesn't have any data, return `None`."""
        if node is None or node.op != "get_attr":
            return None

        target_atoms = node.target.split(".")
        attr_itr = graph_module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        def _is_batch_norm(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.batch_norm.default
            )

        def _is_conv(node_: Node):
            is_conv = node_.op == "call_function" and node_.target in (
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
            )
            # The fusion rewrites the `Conv` weights, so it is only safe when the
            # `Conv` output feeds nothing but the `BatchNorm`. Check `node_` (the
            # candidate Conv), not the enclosing loop variable `node`.
            has_single_user = len(node_.users) == 1

            return is_conv and has_single_user

        made_changes = False

        if not any(map(_is_batch_norm, graph_module.graph.nodes)):
            return PassResult(
                graph_module, made_changes
            )  # No batch norm nodes in the model.

        for node in graph_module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not BatchNorm.

            bn_node = node

            if not _is_conv(bn_node.args[0]):
                continue  # Something other than a Conv node comes before the BatchNorm.

            conv_node = bn_node.args[0]
            conv_weight_node = conv_node.args[1]
            conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None

            # conv args: input, weight, bias, stride, padding, dilation, ...
            conv_w = self._get_tensor_constant_from_node(graph_module, conv_weight_node)
            conv_b = self._get_tensor_constant_from_node(graph_module, conv_bias_node)

            # batch norm args: input, weight, bias, running mean, running var, training, momentum, eps
            bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1])
            bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2])
            bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3])
            bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4])
            bn_eps = bn_node.args[7]

            if any(
                t is None for t in (conv_w, bn_rm, bn_rv)
            ):  # The other inputs can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).

            fused_weight, fused_bias = fuse_conv_bn_weights(
                conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
            )

            # Update the weight and bias for Conv.
            conv_args = list(conv_node.args)
            if len(conv_args) == 2:
                # Fill in the default bias argument.
                conv_args.append(None)

            weight_attr_name = conv_weight_node.target
            _assign_attr(
                fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER
            )

            if conv_bias_node is not None:
                bias_attr_name = conv_bias_node.target
                _assign_attr(
                    fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER
                )
            else:
                # The Conv doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(
                    fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER
                )
                with graph_module.graph.inserting_before(conv_node):
                    get_bias_node = graph_module.graph.get_attr(bias_attr_name)

                conv_args[2] = get_bias_node

            conv_node.args = tuple(conv_args)

            # Replace the uses of the BatchNorm with the Conv.
            bn_node.replace_all_uses_with(conv_node)

            made_changes = True

        return PassResult(graph_module, made_changes)
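
The scale/shift folding described in the docstring can be sanity-checked with plain torch. The following standalone sketch (not part of the commit; shapes and tolerances are illustrative) compares a Conv2d followed by BatchNorm2d against a single convolution that uses the fused weights:

# Standalone sanity check of the Conv + BatchNorm folding identity.
import torch
from torch.nn.utils import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # Give the BatchNorm non-trivial statistics.
bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))

    fused_w, fused_b = fuse_conv_bn_weights(
        conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
    )
    fused = torch.nn.functional.conv2d(x, fused_w, fused_b)

torch.testing.assert_close(fused, reference, rtol=1e-4, atol=1e-5)
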
backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_linear_bn_weights


class FuseBatchNormWithLinearPass(PassBase):
    """The executorch batch normalization carries out the following computation [1].

        (x - mean) / sqrt(var + eps) * W + B

    Which can be expressed as

        x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

    So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
    and the terms can be precomputed. If there is a `Linear` operator before the batch normalization, this scale
    and bias can be statically integrated into the weights and bias of the `Linear`, which allows the batch norm
    to be completely removed.


                   ┌──────▼──────┐
                   │ aten.linear │
                   └──────┬──────┘
                          │                                                   │
    ┌─────────────────────▼─────────────────────┐    replace with     ┌──────▼──────┐
    │              aten.batch_norm              │   ──────────────►   │ aten.linear │
    └─────────────────────┬─────────────────────┘                     └──────┬──────┘

    [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None:
        """Get the static data from a given node. If it doesn't have any data, return `None`."""
        if node is None or node.op != "get_attr":
            return None

        target_atoms = node.target.split(".")
        attr_itr = graph_module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        def _is_batch_norm(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.batch_norm.default
            )

        def _is_linear(node_: Node):
            is_linear = (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.linear.default
            )
            # The fusion rewrites the `Linear` weights, so it is only safe when the
            # `Linear` output feeds nothing but the `BatchNorm`. Check `node_` (the
            # candidate Linear), not the enclosing loop variable `node`.
            has_single_user = len(node_.users) == 1

            return is_linear and has_single_user

        made_changes = False

        if not any(map(_is_batch_norm, graph_module.graph.nodes)):
            return PassResult(
                graph_module, made_changes
            )  # No batch norm nodes in the model.

        for node in graph_module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not BatchNorm.

            bn_node = node

            if not _is_linear(bn_node.args[0]):
                continue  # Something other than a Linear node comes before the BatchNorm.

            linear_node = bn_node.args[0]
            linear_weight_node = linear_node.args[1]
            linear_bias_node = (
                linear_node.args[2] if len(linear_node.args) > 2 else None
            )

            linear_w = self._get_tensor_constant_from_node(
                graph_module, linear_weight_node
            )
            linear_b = self._get_tensor_constant_from_node(
                graph_module, linear_bias_node
            )

            # batch norm args: input, weight, bias, running mean, running var, training, momentum, eps
            bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1])
            bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2])
            bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3])
            bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4])
            bn_eps = bn_node.args[7]

            if any(
                t is None for t in (linear_w, bn_w, bn_b, bn_rm, bn_rv)
            ):  # The Linear bias can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).

            fused_weight, fused_bias = fuse_linear_bn_weights(
                linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
            )

            # Update the weight and bias for Linear.
            linear_args = list(linear_node.args)
            if len(linear_args) == 2:
                # Fill in the default bias argument.
                linear_args.append(None)

            weight_attr_name = linear_weight_node.target
            _assign_attr(
                fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER
            )

            if linear_bias_node is not None:
                bias_attr_name = linear_bias_node.target
                _assign_attr(
                    fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER
                )
            else:
                # The Linear doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(
                    fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER
                )
                with graph_module.graph.inserting_before(linear_node):
                    get_bias_node = graph_module.graph.get_attr(bias_attr_name)

                linear_args[2] = get_bias_node

            linear_node.args = tuple(linear_args)

            # Replace the uses of the BatchNorm with the Linear.
            bn_node.replace_all_uses_with(linear_node)

            made_changes = True

        return PassResult(graph_module, made_changes)
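
The same identity holds for `Linear`. A minimal check (again not part of the commit; shapes are illustrative) using `fuse_linear_bn_weights`:

# Standalone sanity check of the Linear + BatchNorm folding identity.
import torch
from torch.nn.utils import fuse_linear_bn_weights

linear = torch.nn.Linear(16, 8).eval()
bn = torch.nn.BatchNorm1d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(4, 16)
with torch.no_grad():
    reference = bn(linear(x))

    fused_w, fused_b = fuse_linear_bn_weights(
        linear.weight, linear.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
    )
    fused = torch.nn.functional.linear(x, fused_w, fused_b)

torch.testing.assert_close(fused, reference, rtol=1e-4, atol=1e-5)
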
backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch

from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import (
    FuseBatchNormWithConvPass,
)
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)
from executorch.exir.pass_manager import PassManager
from torch import nn
from torch.fx.passes.infra.pass_base import PassResult

PassType = list[type[Callable[[torch.fx.GraphModule], PassResult]]]


class NeutronAtenPassManager(PassManager):

    def __init__(self, passes: list[PassType] = None):
        passes: list[PassType] = passes or [
            FuseBatchNormWithConvPass(),
            FuseBatchNormWithLinearPass(),
        ]

        super().__init__(passes)

    def __call__(self, module: nn.Module) -> PassResult:
        pass_result: PassResult = super().__call__(module)

        graph_module = pass_result.graph_module
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

        return pass_result
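
A rough usage sketch for the new pass manager (the model and shapes below are illustrative; only `NeutronAtenPassManager` and its import path come from this commit). The capture step mirrors the NXP test pipeline further down; after the passes run, the `aten.batch_norm` node should be folded into the convolution:

import torch
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)

# Illustrative model: a convolution followed by batch normalization.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
).eval()
example_input = (torch.randn(1, 3, 32, 32),)

# Capture an aten-dialect GraphModule, as the NXP test pipeline does.
aten_module = torch.export.export(model, example_input, strict=True).module()

pass_result = NeutronAtenPassManager()(aten_module)
pass_result.graph_module.graph.print_tabular()  # batch_norm should be fused away
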

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,9 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
+    NeutronAtenPassManager,
+)
 
 from executorch.backends.nxp.quantizer.patterns import (
     AddmmPattern,
@@ -203,4 +206,5 @@ def __init__(self):
     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
-        return model
+        pass_runner = NeutronAtenPassManager()
+        return pass_runner(model).graph_module
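
For context, `transform_for_annotation` is called by `prepare_pt2e` before annotation, so the BatchNorm fusing now runs automatically in the PT2E quantization flow. A sketch of that flow under those assumptions (model, shapes, and calibration data are illustrative; the steps mirror `executorch_pipeline.py` below):

import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)).eval()
example_input = (torch.randn(1, 3, 32, 32),)
aten_module = torch.export.export(model, example_input, strict=True).module()

quantizer = NeutronQuantizer()
prepared = prepare_pt2e(aten_module, quantizer)  # runs transform_for_annotation,
                                                 # which now fuses BatchNorm first
prepared(*example_input)                         # calibration
quantized = convert_pt2e(prepared)
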

backends/nxp/tests/executorch_pipeline.py

Lines changed: 2 additions & 11 deletions
@@ -8,9 +8,6 @@
 from executorch import exir
 from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
 from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
-
-# TODO (Robert Kalmar) Uncomment when NXP passes are ported to main
-# from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass_manager import NXPPyTorchPassManager
 from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
 from executorch.exir import (
     EdgeCompileConfig,
@@ -27,7 +24,7 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
     quantizer = NeutronQuantizer()
 
     m = prepare_pt2e(model, quantizer)
-    for _i, data in enumerate(calibration_inputs):
+    for data in calibration_inputs:
         m(*data)
     m = convert_pt2e(m)
 
@@ -48,14 +45,8 @@ def to_quantized_edge_program(
         model, example_input, strict=True
     )
 
-    # TODO(Robert Kalmar) uncoment when NXP passes are ported to main
-    # Run pre-processing passes of the float32 aten dialect program.
-    # pass_manager = NXPPyTorchPassManager(exir_program_aten)
-    # pass_manager.run()  # All passes by default.
-
-    exir_program_aten_module = exir_program_aten.module()
     exir_program_aten__module_quant = _quantize_model(
-        exir_program_aten_module, calibration_inputs
+        exir_program_aten.module(), calibration_inputs
     )
 
     compile_spec = generate_neutron_compile_spec(