NXP backend: Create NeutronAtenPassManager with initial BatchNorm fusing passes #10579
@@ -0,0 +1,146 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_conv_bn_weights


class FuseBatchNormWithConvPass(PassBase):
    """The executorch batch normalization carries out the following computation [1].

        (x - mean) / sqrt(var + eps) * W + B

    Which can be expressed as

        x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

    So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
    and the terms can be precomputed. If there is a `Conv` operator before the batch normalization, this scale and
    bias can be statically integrated into the weights and bias of the `Conv`, which allows the batch norm to be
    completely removed.

                          │
            ┌─────────────▼─────────────┐
            │ aten.conv1d | aten.conv2d │
            └─────────────┬─────────────┘
                          │                                                       │
    ┌─────────────────────▼─────────────────────┐   replace with    ┌─────────────▼─────────────┐
    │              aten.batch_norm              │ ──────────────►   │ aten.conv1d | aten.conv2d │
    └─────────────────────┬─────────────────────┘                   └─────────────┬─────────────┘
                          │                                                       ▼
                          ▼

    [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None:
        """Get the static data from a given node. If it doesn't have any data, return `None`."""
        if node is None or node.op != "get_attr":
            return None

        target_atoms = node.target.split(".")
        attr_itr = graph_module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        def _is_batch_norm(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.batch_norm.default
            )

        def _is_conv(node_: Node):
            is_conv = node_.op == "call_function" and node_.target in (
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
            )
            has_single_user = len(node_.users) == 1

            return is_conv and has_single_user

        made_changes = False

        if not any(map(_is_batch_norm, graph_module.graph.nodes)):
            return PassResult(
                graph_module, made_changes
            )  # No batch norm nodes in the model.

        for node in graph_module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not BatchNorm.

            bn_node = node

            if not _is_conv(bn_node.args[0]):
                continue  # Something other than a Conv node comes before the BatchNorm.

            conv_node = bn_node.args[0]
            conv_weight_node = conv_node.args[1]
            conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None

            # conv args: input, weight, bias, stride, padding, dilation, ...
            conv_w = self._get_tensor_constant_from_node(graph_module, conv_weight_node)
            conv_b = self._get_tensor_constant_from_node(graph_module, conv_bias_node)

            # batch norm args: input, weight, bias, running_mean, running_var, training, momentum, eps
            bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1])
            bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2])
            bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3])
            bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4])
            bn_eps = bn_node.args[7]

            if any(
                t is None for t in (conv_w, bn_rm, bn_rv)
            ):  # The other inputs can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).
            fused_weight, fused_bias = fuse_conv_bn_weights(
                conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
            )

            # Update the weight and bias for Conv.
            conv_args = list(conv_node.args)
            if len(conv_args) == 2:
                # Fill in the default bias argument.
                conv_args.append(None)

            weight_attr_name = conv_weight_node.target
            _assign_attr(
                fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER
            )

Review comment: You can do this without using internal methods, and insert new params and delete old params: https://github.com/pytorch/executorch/blob/main/backends/transforms/utils.py#L65 (assuming you are doing this after export).

Reply: That would require propagating the exported program (via the constructor?) to the pass. I don't find that particularly clean, because passes are built upon GraphModules, as you've mentioned in some of the previous PRs. We will end up with

Reply: Yeah, I guess this is for when you do run this pass in

            if conv_bias_node is not None:
                bias_attr_name = conv_bias_node.target
                _assign_attr(
                    fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER
                )
            else:
                # The Conv doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(
                    fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER
                )
                with graph_module.graph.inserting_before(conv_node):
                    get_bias_node = graph_module.graph.get_attr(bias_attr_name)

                conv_args[2] = get_bias_node

            conv_node.args = tuple(conv_args)

            # Replace the uses of the BatchNorm with the Conv.
            bn_node.replace_all_uses_with(conv_node)

            made_changes = True

        return PassResult(graph_module, made_changes)
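
The algebra in the docstring above can be sanity-checked numerically with the same `torch.nn.utils.fuse_conv_bn_weights` helper the pass calls. The following is an illustrative sketch, not part of this diff; the toy shapes, random statistics, and tolerances are arbitrary choices.

```python
# Verify that Conv2d -> BatchNorm2d (in eval mode, i.e. using running statistics)
# matches a single Conv2d with the fused weight and bias.
import torch
from torch.nn.utils import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # give the BatchNorm non-trivial statistics
bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(1, 3, 16, 16)
expected = bn(conv(x))

fused_w, fused_b = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)
actual = torch.nn.functional.conv2d(x, fused_w, fused_b)

torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
```

This is exactly the rewrite the pass performs on the graph: the scale `W / sqrt(var + eps)` is baked into the convolution weights, and the remaining term is folded into its bias.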
@@ -0,0 +1,150 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_linear_bn_weights


class FuseBatchNormWithLinearPass(PassBase):
    """The executorch batch normalization carries out the following computation [1].

        (x - mean) / sqrt(var + eps) * W + B

    Which can be expressed as

        x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

    So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static,
    and the terms can be precomputed. If there is a `Linear` operator before the batch normalization, this scale
    and bias can be statically integrated into the weights and bias of the `Linear`, which allows the batch norm
    to be completely removed.

                          │
                   ┌──────▼──────┐
                   │ aten.linear │
                   └──────┬──────┘
                          │                                                │
    ┌─────────────────────▼─────────────────────┐   replace with    ┌──────▼──────┐
    │              aten.batch_norm              │ ──────────────►   │ aten.linear │
    └─────────────────────┬─────────────────────┘                   └──────┬──────┘
                          ▼                                                ▼

    [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None:
        """Get the static data from a given node. If it doesn't have any data, return `None`."""
        if node is None or node.op != "get_attr":
            return None

        target_atoms = node.target.split(".")
        attr_itr = graph_module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        def _is_batch_norm(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.batch_norm.default
            )

        def _is_linear(node_: Node):
            is_linear = (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.linear.default
            )
            has_single_user = len(node_.users) == 1

            return is_linear and has_single_user

        made_changes = False

        if not any(map(_is_batch_norm, graph_module.graph.nodes)):
            return PassResult(
                graph_module, made_changes
            )  # No batch norm nodes in the model.

        for node in graph_module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not BatchNorm.

            bn_node = node

            if not _is_linear(bn_node.args[0]):
                continue  # Something other than a Linear node comes before the BatchNorm.

Review comment: Here and above, you can be defensive and also make sure args len == 1. There is nothing stopping someone from adding a skip connection after conv/linear and before batchnorm, or some other op. Can't say I've seen it in practice though.

Reply: Added a check for the number of users of the Conv/Linear op.

            linear_node = bn_node.args[0]
            linear_weight_node = linear_node.args[1]
            linear_bias_node = (
                linear_node.args[2] if len(linear_node.args) > 2 else None
            )

            linear_w = self._get_tensor_constant_from_node(
                graph_module, linear_weight_node
            )
            linear_b = self._get_tensor_constant_from_node(
                graph_module, linear_bias_node
            )

            # batch norm args: input, weight, bias, running_mean, running_var, training, momentum, eps
            bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1])
            bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2])
            bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3])
            bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4])
            bn_eps = bn_node.args[7]

            if any(
                t is None for t in (linear_w, bn_w, bn_b, bn_rm, bn_rv)
            ):  # The Linear bias can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).
            fused_weight, fused_bias = fuse_linear_bn_weights(
                linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
            )

            # Update the weight and bias for Linear.
            linear_args = list(linear_node.args)
            if len(linear_args) == 2:
                # Fill in the default bias argument.
                linear_args.append(None)

            weight_attr_name = linear_weight_node.target
            _assign_attr(
                fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER
            )

            if linear_bias_node is not None:
                bias_attr_name = linear_bias_node.target
                _assign_attr(
                    fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER
                )
            else:
                # The Linear doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(
                    fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER
                )
                with graph_module.graph.inserting_before(linear_node):
                    get_bias_node = graph_module.graph.get_attr(bias_attr_name)

                linear_args[2] = get_bias_node

            linear_node.args = tuple(linear_args)

            # Replace the uses of the BatchNorm with the Linear.
            bn_node.replace_all_uses_with(linear_node)

            made_changes = True

        return PassResult(graph_module, made_changes)
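
The `Linear` variant can be verified the same way. The sketch below (again, not part of the diff) uses a bias-less `Linear` on purpose: `fuse_linear_bn_weights` still returns a non-`None` fused bias, which is why the pass needs the `else` branch that registers a new bias parameter and inserts a fresh `get_attr` node.

```python
# Linear (without bias) -> BatchNorm1d matches a single Linear with fused parameters.
import torch
from torch.nn.utils import fuse_linear_bn_weights

linear = torch.nn.Linear(16, 32, bias=False).eval()
bn = torch.nn.BatchNorm1d(32).eval()
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(4, 16)
expected = bn(linear(x))

fused_w, fused_b = fuse_linear_bn_weights(
    linear.weight, None, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)
assert fused_b is not None  # a bias appears even though the original Linear had none
actual = torch.nn.functional.linear(x, fused_w, fused_b)

torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
```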
@@ -0,0 +1,40 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch

from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import (
    FuseBatchNormWithConvPass,
)

Review comment: Nit: this can become unwieldy, since there are a few other permutations of "batch norm" fusion; maybe just one file for fuse_batch_norm would be good.

Reply: Agree with you. Maybe we can create subdirectories in

from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)
from executorch.exir.pass_manager import PassManager
from torch import nn
from torch.fx.passes.infra.pass_base import PassResult

PassType = list[type[Callable[[torch.fx.GraphModule], PassResult]]]


class NeutronAtenPassManager(PassManager):

    def __init__(self, passes: list[PassType] = None):
        passes: list[PassType] = passes or [
            FuseBatchNormWithConvPass(),
            FuseBatchNormWithLinearPass(),
        ]

        super().__init__(passes)

    def __call__(self, module: nn.Module) -> PassResult:
        pass_result: PassResult = super().__call__(module)

        graph_module = pass_result.graph_module
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

        return pass_result
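
Not part of the diff: a minimal usage sketch of the new pass manager. It assumes the pre-dispatch, aten-level graph produced by `torch.export.export_for_training(...).module()` (the same form the PT2E quantization flow hands to the quantizer), since that is where `aten.conv2d` / `aten.linear` and `aten.batch_norm` appear as single `call_function` nodes with `get_attr` parameters. The toy model and shapes are placeholders.

```python
import torch
from torch import nn

from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)

# A toy model with a Conv -> BatchNorm pair for the pass manager to fuse.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Pre-dispatch export keeps aten.conv2d / aten.batch_norm as single nodes
# (assumed flow; requires a recent PyTorch with torch.export.export_for_training).
graph_module = torch.export.export_for_training(model, example_inputs).module()

result = NeutronAtenPassManager()(graph_module)
print(result.graph_module.graph)  # no aten.batch_norm node should remain
```

This is the same call the NXP quantizer makes in `transform_for_annotation` in the next file.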
@@ -7,6 +7,9 @@
 from typing import List, Optional, Tuple, Union

 import torch
+from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
+    NeutronAtenPassManager,
+)

 from executorch.backends.nxp.quantizer.patterns import (
     AddmmPattern,

@@ -202,4 +205,5 @@ def __init__(self):
     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
-        return model
+        pass_runner = NeutronAtenPassManager()
+        return pass_runner(model).graph_module

Review comment: Thanks!
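
For context, a hedged sketch of where `transform_for_annotation`, and with it the new pass manager, runs inside the PT2E quantization flow. The `NeutronQuantizer` class name and import path are assumptions inferred from the surrounding imports; only the two-line change to `transform_for_annotation` above is actually part of this PR.

```python
import torch
from torch import nn
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Assumed import path / class name for the NXP quantizer edited in the diff above.
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

exported = torch.export.export_for_training(model, example_inputs).module()

# prepare_pt2e calls quantizer.transform_for_annotation() first, which now runs
# NeutronAtenPassManager and fuses the BatchNorm into the Conv before any
# quantization annotations are placed.
prepared = prepare_pt2e(exported, NeutronQuantizer())
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared)
```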