# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.export.unflatten import _AttrKind, _assign_attr
from torch.fx import Node
from torch.nn.utils import fuse_conv_bn_weights

from executorch.backends.nxp.pytorch_passes.nxp_pytorch_pass import NXPPyTorchPass


class FuseBatchNormWithConvPass(NXPPyTorchPass):
| 15 | + """ The executorch batch normalization carries out the following computation [1]. |
| 16 | +
|
| 17 | + (x - mean) / (var + eps) * W + B |
| 18 | +
|
| 19 | + Which can be expressed as |
| 20 | +
|
| 21 | + x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps))) |
| 22 | +
|
| 23 | + So the batch norm can be done as 1 multiplication and 1 addition, provided that the parameters are static, |
| 24 | + and the terms can be precomputed. If there is a `Conv` operator before the batch normalization, this scale and |
| 25 | + bias can be statically integrated into the weights and bias of the `Conv`, which allows the batch norm to be |
| 26 | + completely removed. |
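
    Concretely, for a preceding convolution with weight `W_c` and bias `b_c`, and with the per-output-channel
    scale `s = W / sqrt(var + eps)`, the fused parameters are

        W_c' = W_c * s
        b_c' = (b_c - mean) * s + B

    which is what `torch.nn.utils.fuse_conv_bn_weights` computes.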

                              │
                ┌─────────────▼─────────────┐
                │ aten.conv1d | aten.conv2d │
                └─────────────┬─────────────┘
                              │                                                           │
        ┌─────────────────────▼──────────────────────┐      replace with     ┌─────────────▼─────────────┐
        │ aten._native_batch_norm_legit_no_training │    ──────────────►    │ aten.conv1d | aten.conv2d │
        └─────────────────────┬──────────────────────┘                       └─────────────┬─────────────┘
                              │                                                             ▼
                        ┌─────▼──────┐
                        │ getitem(0) │
                        └─────┬──────┘
                              ▼

    [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def run(self) -> bool:
        def _is_batch_norm(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten._native_batch_norm_legit_no_training.default
            )

        def _is_conv(node_: Node) -> bool:
            return node_.op == "call_function" and node_.target in (
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
            )

        def _is_getitem(node_: Node) -> bool:
            # The BatchNorm has multiple outputs, which are accessed through `operator.getitem` nodes.
            return node_.op == "call_function" and node_.target.__name__ == "getitem"

        made_changes = False

        if not any(map(_is_batch_norm, self.module.graph.nodes)):
            return made_changes  # No batch norm nodes in the model.

        for node in self.module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not a BatchNorm.

            bn_node = node
            if not all(_is_getitem(user) and user.args[1] == 0 for user in bn_node.users):
                # Nodes other than just `getitem(0)` follow after the BatchNorm. Probably `getitem` nodes accessing
                # other outputs of the BN. After the fusion with a Conv op, only the first output can be accessed.
                continue

            if not _is_conv(bn_node.args[0]):
                continue  # Something other than a Conv node comes before the BatchNorm.

            conv_node = bn_node.args[0]
            if len(conv_node.users) != 1:
                # The Conv output is also used by nodes other than the BatchNorm, so its weights and bias
                # cannot be modified safely.
                continue

            # Conv args: input, weight, bias, stride, padding, dilation, ...
            conv_weight_node = conv_node.args[1]
            conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
            conv_w = self.get_tensor_constant_from_node(conv_weight_node)
            conv_b = self.get_tensor_constant_from_node(conv_bias_node)

            # Batch norm (legit, no training) args: input, weight, bias, running mean, running var, momentum, eps.
            bn_w = self.get_tensor_constant_from_node(bn_node.args[1])
            bn_b = self.get_tensor_constant_from_node(bn_node.args[2])
            bn_rm = self.get_tensor_constant_from_node(bn_node.args[3])
            bn_rv = self.get_tensor_constant_from_node(bn_node.args[4])
            bn_eps = bn_node.args[6]

            if any(t is None for t in (conv_w, bn_rm, bn_rv)):  # The other inputs can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).
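            # `fuse_conv_bn_weights` folds the BatchNorm into the Conv parameters: the weights are scaled by
            # `bn_w / sqrt(var + eps)` per output channel, and the bias is adjusted accordingly.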
            fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)

            # Update the weight and bias for Conv.
            conv_args = list(conv_node.args)
            if len(conv_args) == 2:
                # Fill in the default bias argument.
                conv_args.append(None)

            weight_attr_name = conv_weight_node.target
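            # `_assign_attr` registers the tensor as a parameter on the module under the given attribute path,
            # replacing the original value referenced by the `get_attr` node.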
            _assign_attr(fused_weight, self.module, weight_attr_name, _AttrKind.PARAMETER)

            if conv_bias_node is not None:
                bias_attr_name = conv_bias_node.target
                _assign_attr(fused_bias, self.module, str(bias_attr_name), _AttrKind.PARAMETER)
            else:
                # The Conv doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(fused_bias, self.module, bias_attr_name, _AttrKind.PARAMETER)
                with self.module.graph.inserting_before(conv_node):
                    get_bias_node = self.module.graph.get_attr(bias_attr_name)

                conv_args[2] = get_bias_node

            conv_node.args = tuple(conv_args)

            # Replace the uses of the BatchNorm with the Conv.
            for user in bn_node.users:
                user.replace_all_uses_with(conv_node)
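            # The `getitem` nodes now have no users and the BatchNorm is dead code; presumably they are cleaned
            # up by later dead-code elimination rather than erased here.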

            made_changes = True

        return made_changes