
Commit 931c24c

Pull request pytorch#64: Feature/EIEX-90 quantization and conversion of aten native batch norm legit no training

Merge in AITEC/executorch from feature/EIEX-90-quantization-and-conversion-of-aten-_native_batch_norm_legit_no_training to main-nxp

* commit '246b61f9e3b8fa50210460cf9f419c7eb670fa8b':
  Add pre-processing pass to fuse BatchNorm into preceding Linear nodes.
  Add pre-processing pass to fuse BatchNorm into preceding Conv nodes.
  Add infrastructure for pre-processing passes of aten programs.

2 parents fc0e620 + 246b61f commit 931c24c

File tree

7 files changed: +520 -4 lines changed
backends/nxp/pytorch_passes/fuse_batch_norm_with_conv_pass.py

Lines changed: 125 additions & 0 deletions

@@ -0,0 +1,125 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.export.unflatten import _AttrKind, _assign_attr
from torch.fx import Node
from torch.nn.utils import fuse_conv_bn_weights

from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass import NXPPyTorchPass


class FuseBatchNormWithConvPass(NXPPyTorchPass):
    """ The executorch batch normalization carries out the following computation [1].

            (x - mean) / sqrt(var + eps) * W + B

        Which can be expressed as

            x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

        So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
        and the terms can be precomputed. If there is a `Conv` operator before the batch normalization, this scale and
        bias can be statically integrated into the weights and bias of the `Conv`, which allows the batch norm to be
        completely removed.


                              │
                ┌─────────────▼─────────────┐
                │ aten.conv1d | aten.conv2d │
                └─────────────┬─────────────┘
                              │                                                        │
        ┌─────────────────────▼─────────────────────┐   replace with   ┌─────────────▼─────────────┐
        │ aten._native_batch_norm_legit_no_training │ ──────────────►  │ aten.conv1d | aten.conv2d │
        └─────────────────────┬─────────────────────┘                  └─────────────┬─────────────┘
                              │                                                      ▼
                        ┌─────▼──────┐
                        │ getitem(0) │
                        └─────┬──────┘
                              │

        [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def run(self) -> bool:
        def _is_batch_norm(node_: Node) -> bool:
            return node_.op == "call_function" and node_.target == torch.ops.aten._native_batch_norm_legit_no_training.default

        def _is_conv(node_: Node) -> bool:
            return node_.op == "call_function" and node_.target in (
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default
            )

        def _is_getitem(node_: Node) -> bool:
            return node_.op == "call_function" and node_.target.__name__ == "getitem"

        made_changes = False

        if not any(map(_is_batch_norm, self.module.graph.nodes)):
            return made_changes  # No batch norm nodes in the model.

        for node in self.module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not BatchNorm.

            bn_node = node
            if not all(_is_getitem(user) and user.args[1] == 0 for user in bn_node.users):
                # Nodes other than just `getitem(0)` follow after the BatchNorm. Probably `getitem` nodes accessing
                # other outputs of the BN. After the fusion with a Conv op, only the first output can be accessed.
                continue

            if not _is_conv(bn_node.args[0]):
                continue  # Something other than a Conv node comes before the BatchNorm.

            # conv args: input, weight, bias, stride, padding, dilation, ...
            conv_node = bn_node.args[0]
            conv_weight_node = conv_node.args[1]
            conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None

            conv_w = self.get_tensor_constant_from_node(conv_weight_node)
            conv_b = self.get_tensor_constant_from_node(conv_bias_node)

            # batch norm legit no training args: input, weight, bias, running mean, running var, momentum, eps
            bn_w = self.get_tensor_constant_from_node(bn_node.args[1])
            bn_b = self.get_tensor_constant_from_node(bn_node.args[2])
            bn_rm = self.get_tensor_constant_from_node(bn_node.args[3])
            bn_rv = self.get_tensor_constant_from_node(bn_node.args[4])
            bn_eps = bn_node.args[6]

            if any(t is None for t in (conv_w, bn_rm, bn_rv)):  # The other inputs can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).

            fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)

            # Update the weight and bias for Conv.
            conv_args = list(conv_node.args)
            if len(conv_args) == 2:
                # Fill in the default bias argument.
                conv_args.append(None)

            weight_attr_name = conv_weight_node.target
            _assign_attr(fused_weight, self.module, weight_attr_name, _AttrKind.PARAMETER)

            if conv_bias_node is not None:
                bias_attr_name = conv_bias_node.target
                _assign_attr(fused_bias, self.module, str(bias_attr_name), _AttrKind.PARAMETER)
            else:
                # The Conv doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(fused_bias, self.module, bias_attr_name, _AttrKind.PARAMETER)
                with self.module.graph.inserting_before(conv_node):
                    get_bias_node = self.module.graph.get_attr(bias_attr_name)

                conv_args[2] = get_bias_node

            conv_node.args = tuple(conv_args)

            # Replace the uses of the BatchNorm with the Conv.
            for user in bn_node.users:
                user.replace_all_uses_with(conv_node)

            made_changes = True

        return made_changes
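
Editor's note: a quick numeric sanity check of the fusion identity above. This sketch is not part of the commit; it uses only the same public `fuse_conv_bn_weights` helper the pass imports, and all tensor shapes are made up for illustration.

import torch
import torch.nn.functional as F
from torch.nn.utils import fuse_conv_bn_weights

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)
conv_w, conv_b = torch.randn(4, 3, 3, 3), torch.randn(4)
bn_w, bn_b = torch.randn(4), torch.randn(4)
bn_rm, bn_rv = torch.randn(4), torch.rand(4) + 0.1  # Running variance must be positive.
eps = 1e-5

# Reference: Conv2d followed by inference-mode batch norm.
reference = F.batch_norm(F.conv2d(x, conv_w, conv_b), bn_rm, bn_rv, bn_w, bn_b, training=False, eps=eps)

# Fused: a single Conv2d with the scale and bias statically folded into its weights.
fused_w, fused_b = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, eps, bn_w, bn_b)
fused = F.conv2d(x, fused_w, fused_b)

assert torch.allclose(reference, fused, atol=1e-5)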
backends/nxp/pytorch_passes/fuse_batch_norm_with_linear_pass.py

Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.export.unflatten import _AttrKind, _assign_attr
from torch.fx import Node
from torch.nn.utils import fuse_linear_bn_weights

from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass import NXPPyTorchPass


class FuseBatchNormWithLinearPass(NXPPyTorchPass):
    """ The executorch batch normalization carries out the following computation [1].

            (x - mean) / sqrt(var + eps) * W + B

        Which can be expressed as

            x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

        So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
        and the terms can be precomputed. If there is a `Linear` operator before the batch normalization, this scale
        and bias can be statically integrated into the weights and bias of the `Linear`, which allows the batch norm
        to be completely removed.


                              │
                       ┌──────▼──────┐
                       │ aten.linear │
                       └──────┬──────┘
                              │                                             │
        ┌─────────────────────▼─────────────────────┐   replace with   ┌──────▼──────┐
        │ aten._native_batch_norm_legit_no_training │ ──────────────►  │ aten.linear │
        └─────────────────────┬─────────────────────┘                  └──────┬──────┘
                              │                                               ▼
                        ┌─────▼──────┐
                        │ getitem(0) │
                        └─────┬──────┘
                              │

        [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def run(self) -> bool:
        def _is_batch_norm(node_: Node) -> bool:
            return node_.op == "call_function" and node_.target == torch.ops.aten._native_batch_norm_legit_no_training.default

        def _is_linear(node_: Node) -> bool:
            return node_.op == "call_function" and node_.target == torch.ops.aten.linear.default

        def _is_getitem(node_: Node) -> bool:
            return node_.op == "call_function" and node_.target.__name__ == "getitem"

        made_changes = False

        if not any(map(_is_batch_norm, self.module.graph.nodes)):
            return made_changes  # No batch norm nodes in the model.

        for node in self.module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not BatchNorm.

            bn_node = node
            if not all(_is_getitem(user) and user.args[1] == 0 for user in bn_node.users):
                # Nodes other than just `getitem(0)` follow after the BatchNorm. Probably `getitem` nodes accessing
                # other outputs of the BN. After the fusion with a Linear op, only the first output can be accessed.
                continue

            if not _is_linear(bn_node.args[0]):
                continue  # Something other than a Linear node comes before the BatchNorm.

            linear_node = bn_node.args[0]
            linear_weight_node = linear_node.args[1]
            linear_bias_node = linear_node.args[2] if len(linear_node.args) > 2 else None

            linear_w = self.get_tensor_constant_from_node(linear_weight_node)
            linear_b = self.get_tensor_constant_from_node(linear_bias_node)

            # batch norm legit no training args: input, weight, bias, running mean, running var, momentum, eps
            bn_w = self.get_tensor_constant_from_node(bn_node.args[1])
            bn_b = self.get_tensor_constant_from_node(bn_node.args[2])
            bn_rm = self.get_tensor_constant_from_node(bn_node.args[3])
            bn_rv = self.get_tensor_constant_from_node(bn_node.args[4])
            bn_eps = bn_node.args[6]

            if any(t is None for t in (linear_w, bn_w, bn_b, bn_rm, bn_rv)):  # The Linear bias can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).

            fused_weight, fused_bias = fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)

            # Update the weight and bias for Linear.
            linear_args = list(linear_node.args)
            if len(linear_args) == 2:
                # Fill in the default bias argument.
                linear_args.append(None)

            weight_attr_name = linear_weight_node.target
            _assign_attr(fused_weight, self.module, weight_attr_name, _AttrKind.PARAMETER)

            if linear_bias_node is not None:
                bias_attr_name = linear_bias_node.target
                _assign_attr(fused_bias, self.module, str(bias_attr_name), _AttrKind.PARAMETER)
            else:
                # The Linear doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(fused_bias, self.module, bias_attr_name, _AttrKind.PARAMETER)
                with self.module.graph.inserting_before(linear_node):
                    get_bias_node = self.module.graph.get_attr(bias_attr_name)

                linear_args[2] = get_bias_node

            linear_node.args = tuple(linear_args)

            # Replace the uses of the BatchNorm with the Linear.
            for user in bn_node.users:
                user.replace_all_uses_with(linear_node)

            made_changes = True

        return made_changes
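
Editor's note: the same identity holds for `Linear`. A minimal check of `fuse_linear_bn_weights` (again a sketch, not part of the commit, with illustrative shapes):

import torch
import torch.nn.functional as F
from torch.nn.utils import fuse_linear_bn_weights

torch.manual_seed(0)
x = torch.randn(2, 6)
lin_w, lin_b = torch.randn(4, 6), torch.randn(4)
bn_w, bn_b = torch.randn(4), torch.randn(4)
bn_rm, bn_rv = torch.randn(4), torch.rand(4) + 0.1  # Running variance must be positive.
eps = 1e-5

# Reference: Linear followed by inference-mode batch norm over the feature dimension.
reference = F.batch_norm(F.linear(x, lin_w, lin_b), bn_rm, bn_rv, bn_w, bn_b, training=False, eps=eps)

# Fused: a single Linear with the folded weights and bias.
fused_w, fused_b = fuse_linear_bn_weights(lin_w, lin_b, bn_rm, bn_rv, eps, bn_w, bn_b)
assert torch.allclose(reference, F.linear(x, fused_w, fused_b), atol=1e-5)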
backends/nxp/pytorch_passes/nxp_pytorch_pass.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import abstractmethod, ABC

from torch.fx import GraphModule
from torch.nn.parameter import Parameter


class NXPPyTorchPass(ABC):
    """ Abstract parent class for pre-processing passes on the aten dialect level. """

    def __init__(self, module: GraphModule) -> None:
        super().__init__()
        self.module = module

    @abstractmethod
    def run(self) -> bool:
        """ Execute the pass and return a bool indicating if any changes have been made. """
        pass

    def get_tensor_constant_from_node(self, node) -> Parameter | None:
        """ Get the static data from a given node. If it doesn't have any data, return `None`. """
        if node is None or node.op != 'get_attr':
            return None

        target_atoms = node.target.split('.')
        attr_itr = self.module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr
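
Editor's note: `get_tensor_constant_from_node` resolves the dotted path stored in a `get_attr` node's `target` by walking the module hierarchy one attribute at a time. A small illustration (the `Block` module and its attribute names are made up for this sketch):

import torch
import torch.fx

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Module()
        self.inner.weight = torch.nn.Parameter(torch.randn(4, 3))

    def forward(self, x):
        return x @ self.inner.weight.t()

module = torch.fx.symbolic_trace(Block())

# Accessing `self.inner.weight` in `forward` produces a `get_attr` node whose
# `target` is the dotted path "inner.weight"; the pass walks it atom by atom.
for node in module.graph.nodes:
    if node.op == "get_attr":
        attr = module
        for atom in node.target.split("."):
            attr = getattr(attr, atom)
        print(node.target, tuple(attr.shape))  # inner.weight (4, 3)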
backends/nxp/pytorch_passes/nxp_pytorch_pass_manager.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
from typing import Iterable

from torch.fx import GraphModule

from executorch.backends.nxp.pytorch_passes.fuse_batch_norm_with_conv_pass import FuseBatchNormWithConvPass
from executorch.backends.nxp.pytorch_passes.fuse_batch_norm_with_linear_pass import FuseBatchNormWithLinearPass
from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass import NXPPyTorchPass


class NXPPyTorchPassManager:
    """ Class that iteratively calls the provided passes, which inherit from the `NXPPyTorchPass` class. """

    def __init__(self, module: GraphModule, passes: Iterable[type[NXPPyTorchPass]] | None = None):
        self.module = module
        self.passes = passes or [  # New passes should be added here.
            FuseBatchNormWithConvPass,
            FuseBatchNormWithLinearPass
        ]

    def _clean_up_graph_module(self):
        self.module.graph.eliminate_dead_code()
        self.module.recompile()

    def run(self) -> GraphModule:
        """ Iteratively apply all available passes for as long as they are changing the graph. """
        graph_module = self.module
        num_passes = len(self.passes)
        hard_limit = 10 * num_passes  # Empirical value.
        num_passes_since_last_change = 0

        self._clean_up_graph_module()

        # Cycle through all passes for as long as they are making changes.
        for i, pass_class in enumerate(itertools.cycle(self.passes)):
            try:
                pass_ = pass_class(graph_module)
                made_changes = pass_.run()
                self._clean_up_graph_module()

                if made_changes:
                    num_passes_since_last_change = 0
                else:
                    num_passes_since_last_change += 1

            except Exception as e:
                logging.warning(f'An exception occurred during the pre-processing pass `{pass_class}`. '
                                'Please report this issue.\n' + str(e))
                # Treat a failed pass as "no change", so a persistently failing pass cannot loop forever.
                num_passes_since_last_change += 1

            if num_passes_since_last_change >= num_passes or i >= hard_limit:
                break

        return graph_module
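
Editor's note: a hedged end-to-end sketch of driving the manager, mirroring the test pipeline change below. The toy `ConvBN` model is invented for illustration; the sketch assumes, as the commit's target torch version does, that eval-mode BatchNorm captures to `aten._native_batch_norm_legit_no_training`.

import torch
from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass_manager import NXPPyTorchPassManager

class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

model = ConvBN().eval()
example_input = (torch.ones(1, 3, 32, 32),)
graph_module = torch._export.capture_pre_autograd_graph(model, example_input)

# Run all default passes; the BatchNorm should be folded into the Conv.
NXPPyTorchPassManager(graph_module).run()
assert not any(
    node.op == "call_function"
    and node.target == torch.ops.aten._native_batch_norm_legit_no_training.default
    for node in graph_module.graph.nodes
)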

backends/nxp/tests/executorch_pipeline.py

Lines changed: 11 additions & 0 deletions

@@ -1,10 +1,16 @@
+# Copyright 2024-2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 import torch
 from torch import nn
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

 from executorch import exir
 from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
 from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
+from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass_manager import NXPPyTorchPassManager
 from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
 from executorch.extension.export_util.utils import export_to_edge
 from executorch.exir import EdgeProgramManager, ExecutorchBackendConfig, ExecutorchProgramManager
@@ -26,6 +32,11 @@ def to_quantized_edge_program(model: torch.nn.Module, input_shape: tuple, target
     example_input = (torch.ones(*input_shape),)

     exir_program_aten = torch._export.capture_pre_autograd_graph(model, example_input)
+
+    # Run pre-processing passes of the float32 aten dialect program.
+    pass_manager = NXPPyTorchPassManager(exir_program_aten)
+    pass_manager.run()  # All passes by default.
+
     exir_program_aten_quant = _quantize_model(exir_program_aten, calibration_inputs)
     edge_program_manager = export_to_edge(exir_program_aten_quant, example_input)
