# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter
from torch.nn.utils import fuse_conv_bn_weights

class FuseBatchNormWithConvPass(PassBase):
    """The ExecuTorch batch normalization carries out the following computation [1]:

        (x - mean) / sqrt(var + eps) * W + B

    which can be expressed as

        x * (W / sqrt(var + eps)) + (B - mean * (W / sqrt(var + eps)))

    So the batch norm can be carried out as 1 multiplication and 1 addition, provided that the
    parameters are static and the terms can be precomputed. If there is a `Conv` operator before
    the batch normalization, this scale and bias can be statically integrated into the weights and
    bias of the `Conv`, which allows the batch norm to be completely removed. (An illustrative
    sketch at the bottom of this file exercises the pass end to end.)

                          │
            ┌─────────────▼─────────────┐
            │ aten.conv1d | aten.conv2d │
            └─────────────┬─────────────┘
                          │                                                          │
    ┌─────────────────────▼─────────────────────┐     replace with     ┌─────────────▼─────────────┐
    │              aten.batch_norm              │   ──────────────►    │ aten.conv1d | aten.conv2d │
    └─────────────────────┬─────────────────────┘                      └─────────────┬─────────────┘
                          │                                                          ▼
                          ▼

    [1] https://github.com/pytorch/executorch/blob/v0.5.0-rc2/kernels/portable/cpu/op_native_batch_norm.cpp#L118-L128
    """

    def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None:
        """Get the static data from a given node. If it doesn't have any data, return `None`."""
        if node is None or node.op != "get_attr":
            return None

        target_atoms = node.target.split(".")
        attr_itr = graph_module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        def _is_batch_norm(node_: Node) -> bool:
            return (
                node_.op == "call_function"
                and node_.target == torch.ops.aten.batch_norm.default
            )

        def _is_conv(node_: Node) -> bool:
            is_conv = node_.op == "call_function" and node_.target in (
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
            )
            # The fusion is only valid if the batch norm is the sole consumer
            # of the convolution's output.
            has_single_user = len(node_.users) == 1

            return is_conv and has_single_user

        made_changes = False

        if not any(map(_is_batch_norm, graph_module.graph.nodes)):
            return PassResult(
                graph_module, made_changes
            )  # No batch norm nodes in the model.

        for node in graph_module.graph.nodes:
            if not _is_batch_norm(node):
                continue  # Not BatchNorm.

            bn_node = node

            if not _is_conv(bn_node.args[0]):
                continue  # Something other than a Conv node comes before the BatchNorm.

            # conv args: input, weight, bias, stride, padding, dilation, ...
            conv_node = bn_node.args[0]
            conv_weight_node = conv_node.args[1]
            conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None

            conv_w = self._get_tensor_constant_from_node(graph_module, conv_weight_node)
            conv_b = self._get_tensor_constant_from_node(graph_module, conv_bias_node)

            # batch norm args: input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled
            bn_w = self._get_tensor_constant_from_node(graph_module, bn_node.args[1])
            bn_b = self._get_tensor_constant_from_node(graph_module, bn_node.args[2])
            bn_rm = self._get_tensor_constant_from_node(graph_module, bn_node.args[3])
            bn_rv = self._get_tensor_constant_from_node(graph_module, bn_node.args[4])
            bn_eps = bn_node.args[7]

            if any(
                t is None for t in (conv_w, bn_rm, bn_rv)
            ):  # The other inputs can be None.
                continue  # The data is not static. Leave this BatchNorm as is (probably a rare case).
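
            # `fuse_conv_bn_weights` folds the batch norm into the convolution:
            #   scale        = bn_w / sqrt(bn_rv + bn_eps)
            #   fused_weight = conv_w * scale  (broadcast over the output channels)
            #   fused_bias   = (conv_b - bn_rm) * scale + bn_b
            # `None` values for conv_b, bn_w and bn_b are treated as zeros/ones.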
            fused_weight, fused_bias = fuse_conv_bn_weights(
                conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b
            )

            # Update the weight and bias for Conv.
            conv_args = list(conv_node.args)
            if len(conv_args) == 2:
                # Fill in the default bias argument.
                conv_args.append(None)

            weight_attr_name = conv_weight_node.target
            _assign_attr(
                fused_weight, graph_module, weight_attr_name, _AttrKind.PARAMETER
            )

            if conv_bias_node is not None:
                bias_attr_name = conv_bias_node.target
                _assign_attr(
                    fused_bias, graph_module, str(bias_attr_name), _AttrKind.PARAMETER
                )
            else:
                # The Conv doesn't have a bias. Create a new one.
                bias_attr_name = weight_attr_name + "_bias"
                _assign_attr(
                    fused_bias, graph_module, bias_attr_name, _AttrKind.PARAMETER
                )
                with graph_module.graph.inserting_before(conv_node):
                    get_bias_node = graph_module.graph.get_attr(bias_attr_name)

                conv_args[2] = get_bias_node

            conv_node.args = tuple(conv_args)

            # Replace the uses of the BatchNorm with the Conv.
            bn_node.replace_all_uses_with(conv_node)
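            # The BatchNorm node itself is now dead; a later cleanup such as
            # `graph_module.graph.eliminate_dead_code()` can remove it.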

            made_changes = True

        return PassResult(graph_module, made_changes)
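

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the pass): one way to exercise the fusion
# end to end. `ConvBNModel` and the shapes below are hypothetical, and this
# assumes an export pipeline that still contains `aten.batch_norm.default`
# nodes at this stage (e.g. the training/pre-dispatch ATen dialect produced by
# `torch.export.export_for_training`, available in recent PyTorch); pipelines
# that decompose batch norm earlier will not match.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ConvBNModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
            self.bn = torch.nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))

    model = ConvBNModel().eval()
    example_input = (torch.randn(1, 3, 32, 32),)

    # `ExportedProgram.module()` returns a GraphModule with the parameters
    # restored as attributes, which is what the pass expects.
    exported = torch.export.export_for_training(model, example_input)
    graph_module = exported.module()

    reference_output = graph_module(*example_input)

    result = FuseBatchNormWithConvPass().call(graph_module)
    if result is not None and result.modified:
        result.graph_module.graph.eliminate_dead_code()
        result.graph_module.recompile()

        # The fused graph should be numerically equivalent up to float error.
        fused_output = result.graph_module(*example_input)
        torch.testing.assert_close(
            fused_output, reference_output, rtol=1e-4, atol=1e-4
        )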