Skip to content

Commit 2051a15

Browse files
Arm backend: Update fuse_batchnorm_pass to create new placeholders (#8411)
* [ARM backend] Update fuse_batchnorm_pass to create new placeholders - This allows fusing bn+convs with multiple users of the same weights - Adds new util functions create_constant_placeholder/delete_constant_placeholder to take care of updating the GraphSignature and state_dict/constants dict when handling constant placeholders. - Adds and updates related tests Change-Id: I8e550614d9741de840786d9dca9f30af9eb95a64 * Move create/delete_constant_node utils to shared folder Change-Id: I3a82f58f9796e421bd205f030f7d79d72a2f7ed9 * Add buck dependency * Fix bazel build --------- Co-authored-by: Digant Desai <digantdesai@meta.com>
1 parent efd1a06 commit 2051a15

File tree

7 files changed

+349
-54
lines changed

7 files changed

+349
-54
lines changed

backends/arm/_passes/TARGETS

+1
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ python_library(
99
"//executorch/backends/transforms:replace_scalar_with_tensor",
1010
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
1111
"//executorch/exir:lib",
12+
"//executorch/backends/transforms:utils",
1213
],
1314
)

backends/arm/_passes/arm_pass_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.

backends/arm/_passes/fuse_batchnorm2d_pass.py

+82-48
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@
66
# pyre-unsafe
77

88
import torch
9+
from executorch.backends.transforms.utils import (
10+
create_constant_placeholder,
11+
delete_constant_placeholder,
12+
)
913
from executorch.exir import ExportedProgram
1014
from executorch.exir.dialects._ops import ops as exir_ops
1115
from executorch.exir.pass_base import ExportPass, PassResult
1216
from torch._export.utils import get_buffer, get_param
17+
from torch.export.graph_signature import InputKind
1318
from torch.fx import Node
1419
from torch.nn.utils.fusion import fuse_conv_bn_weights
1520

@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
2328
self.exported_program = exported_program
2429
super().__init__()
2530

26-
def is_fuseable_conv_bn(self, node: Node):
31+
def is_fuseable_conv_bn(self, node: Node) -> bool:
2732
"""Returns True if node is a batchnorm that can be fused into
2833
a parent convolution."""
2934
if node.op != "call_function":
@@ -44,15 +49,19 @@ def is_fuseable_conv_bn(self, node: Node):
4449
# Since we change the output of the conv, fuse only if it has single user.
4550
if len(conv.users) > 1:
4651
return False
47-
# For similar reasons, only fuse if conv parameters have single user.
48-
if len(conv.all_input_nodes[1].users) > 1:
49-
return False
50-
if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
51-
return False
5252
return True
5353

54+
def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
55+
if conv_bias_node:
56+
return conv_bias_node.name + "_fused_bn"
57+
elif "weight" in conv_weight_node.name:
58+
return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
59+
else:
60+
return conv_weight_node.name + "_bias_fused_bn"
61+
5462
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
5563
modified = False
64+
constant_placeholders_to_delete = set()
5665
for node in graph_module.graph.nodes:
5766
if not self.is_fuseable_conv_bn(node):
5867
continue
@@ -64,68 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
6473
)
6574

6675
# Get weight, bias, mean, var and epsilon from the batchnorm
67-
bn = node
68-
conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
69-
bn_weight = get_param_or_none(bn_weight_node)
70-
bn_bias = get_param_or_none(bn_bias_node)
71-
72-
running_mean = get_buffer(self.exported_program, bn_mean_node)
73-
running_var = get_buffer(self.exported_program, bn_var_node)
74-
if running_mean is None or running_var is None:
76+
bn_node = node
77+
conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
78+
bn_node.args[0:5]
79+
)
80+
bn_weight_tensor = get_param_or_none(bn_weight_node)
81+
bn_bias_tensor = get_param_or_none(bn_bias_node)
82+
bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
83+
bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
84+
if bn_mean_tensor is None or bn_var_tensor is None:
7585
raise ValueError(
7686
"Parameters running_mean and running_var of batchnorm can't be None."
7787
)
78-
epsilon = bn.args[-1]
88+
epsilon = bn_node.args[-1]
7989

8090
# Get weight and bias from conv
8191
conv_weight_node, conv_bias_node = conv.args[1:3]
82-
conv_weight = get_param(self.exported_program, conv_weight_node)
83-
conv_bias = get_param_or_none(conv_bias_node)
84-
if conv_weight is None:
92+
conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
93+
conv_bias_tensor = get_param_or_none(conv_bias_node)
94+
if conv_weight_tensor is None:
8595
raise ValueError("Parameter weight of convolution can't be None.")
8696

8797
# Compute conv parameters folded with batchnorm
8898
fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
89-
conv_weight,
90-
conv_bias,
91-
running_mean,
92-
running_var,
99+
conv_weight_tensor,
100+
conv_bias_tensor,
101+
bn_mean_tensor,
102+
bn_var_tensor,
93103
epsilon,
94-
bn_weight,
95-
bn_bias,
104+
bn_weight_tensor,
105+
bn_bias_tensor,
96106
)
97107

98-
# Set the conv parameters to fused value
99-
def try_set_param(
100-
param_node: Node | None, param_value: torch.nn.Parameter
101-
) -> bool:
102-
"""set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
103-
if param_node is not None:
104-
param_name = (
105-
self.exported_program.graph_signature.inputs_to_parameters[
106-
param_node.name
107-
]
108+
# Create fused weights and bias to conv and replace conv args
109+
with graph_module.graph.inserting_before(conv_weight_node):
110+
fused_conv_weight_node = create_constant_placeholder(
111+
exp_program=self.exported_program,
112+
graph=graph_module.graph,
113+
kind=InputKind.PARAMETER,
114+
name=conv_weight_node.name + "_fused_bn",
115+
data=fused_conv_weight,
116+
)
117+
118+
if fused_conv_bias is not None:
119+
fused_conv_bias_node = create_constant_placeholder(
120+
exp_program=self.exported_program,
121+
graph=graph_module.graph,
122+
kind=InputKind.PARAMETER,
123+
name=self.get_bias_name(conv_weight_node, conv_bias_node),
124+
data=fused_conv_bias,
108125
)
109-
self.exported_program.state_dict[param_name] = param_value
110-
return True
111-
return False
126+
else:
127+
fused_conv_bias_node = None
128+
129+
conv.args = (
130+
conv.args[0],
131+
fused_conv_weight_node,
132+
fused_conv_bias_node,
133+
*conv.args[3:],
134+
)
112135

113-
try_set_param(conv_weight_node, fused_conv_weight)
114-
if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
115-
bn_bias_node, fused_conv_bias
116-
):
117-
# pyre-ignore[60]
118-
# Conv didn't have bias but batchnorm did, steal bias from batchnorm.
119-
conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
120-
conv.args = conv_args
121-
122-
# Erasing nodes is handled by dead-code elimination.
123-
for user in bn.users:
136+
# Erasing batch-norm nodes is handled by dead-code elimination. After that we may remove their constant placeholder inputs
137+
for user in bn_node.users:
124138
user.replace_all_uses_with(conv)
139+
140+
constant_placeholders_to_delete.update(
141+
[
142+
bn_weight_node,
143+
bn_bias_node,
144+
bn_mean_node,
145+
bn_var_node,
146+
conv_weight_node,
147+
conv_bias_node,
148+
]
149+
)
125150
modified = True
126151

127152
if modified:
128153
graph_module.graph.eliminate_dead_code()
154+
for constant_placeholder in constant_placeholders_to_delete:
155+
if (constant_placeholder is not None) and (
156+
len(constant_placeholder.users) == 0
157+
):
158+
delete_constant_placeholder(
159+
self.exported_program, constant_placeholder
160+
)
161+
129162
graph_module.recompile()
130163
graph_module = super().call(graph_module).graph_module
164+
131165
return PassResult(graph_module=graph_module, modified=modified)

backends/arm/test/passes/test_fuse_batchnorm_pass.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,13 @@ def forward(self, x):
8585
return x
8686

8787

88-
class MergeNoBN(torch.nn.Module):
88+
class MergeMultipleUsersBN(torch.nn.Module):
8989
ops_before_pass = {
9090
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
9191
"executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
9292
}
9393
ops_after_pass = {
94-
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
94+
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1,
9595
"executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
9696
}
9797

@@ -122,7 +122,7 @@ def forward(self, x):
122122
z = self.conv2d2(x)
123123
a = self.batch_norm2d(
124124
y
125-
) # Can't be fused since paramters of conv2d2 have multiple users.
125+
) # Can be fused despite parameters of conv2d2 having multiple users.
126126

127127
return z, a
128128

@@ -131,7 +131,7 @@ def forward(self, x):
131131
"merge_one_of_two_bn_affine": MergeOneOfTwoBN(True),
132132
"merge_one_of_two_bn": MergeOneOfTwoBN(False),
133133
"merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True),
134-
"merge_no_bn_affine": MergeNoBN(True),
134+
"merge_multiple_users_bn_affine": MergeMultipleUsersBN(True),
135135
}
136136

137137

backends/transforms/targets.bzl

+3
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ def define_common_targets():
149149
runtime.python_library(
150150
name = "utils",
151151
srcs = ["utils.py"],
152+
visibility = [
153+
"//executorch/backends/...",
154+
],
152155
deps = [
153156
"//caffe2:torch",
154157
"//executorch/exir:lib",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.transforms.utils import (
8+
create_constant_placeholder,
9+
delete_constant_placeholder,
10+
)
11+
from executorch.exir import to_edge
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch.export import export
14+
from torch.export.graph_signature import InputKind
15+
16+
17+
class EmptyNetwork(torch.nn.Module):
    """Identity module whose ``forward`` returns its input unchanged."""

    # Example positional arguments for the module: a 1-tuple holding a single
    # tensor (not a bare tensor), so it can be handed to export() as ``args``.
    test_data: tuple[torch.Tensor, ...] = (torch.zeros(1),)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x
23+
24+
25+
def _test_create_delete(kind: InputKind, persistent_buffer: bool | None = None):
    """
    Tests the utility functions create_constant_placeholder and delete_constant_placeholder

    The test exports an identity network, inserts a constant placeholder of the
    requested ``kind`` holding the value 1, wires it into an add op, and checks
    that the graph signature, state_dict and constants of the exported program
    reflect the new placeholder. It then unwires the add op, deletes the
    placeholder and checks that the program is back to its initial state.

    Args:
        kind: Input kind of the placeholder to create. PARAMETER, BUFFER and
            CONSTANT_TENSOR are handled; anything else raises RuntimeError.
        persistent_buffer: Forwarded to create_constant_placeholder. For
            BUFFER it selects which set of assertions is checked (state_dict
            for a truthy value, constants otherwise).
    """

    # Toy network with two nodes, input and output
    # The result should be 0 = 0
    module = EmptyNetwork()
    exported_program = export(module, args=module.test_data)
    exported_program = to_edge(exported_program).exported_program()
    graph = exported_program.graph_module.graph
    assert len(graph.nodes) == 2
    assert exported_program.module()(torch.zeros(1)) == 0
    assert len(exported_program.graph_signature.input_specs) == 1
    assert len(exported_program.state_dict) == 0
    assert len(exported_program.constants) == 0

    const_name = "test_node"

    # Create one const node with value 1 and add it to the input
    input_node = list(graph.nodes)[0]
    with graph.inserting_before(input_node):
        const_node = create_constant_placeholder(
            exp_program=exported_program,
            graph=graph,
            kind=kind,
            name=const_name,
            data=torch.ones(1),
            persistent_buffer=persistent_buffer,
        )
    assert "val" in const_node.meta

    with graph.inserting_after(input_node):
        add_node = graph.create_node(
            "call_function",
            exir_ops.edge.aten.add.Tensor,
            args=(input_node, const_node),
            kwargs={},
        )

    output_node = list(graph.nodes)[-1]
    output_node.replace_input_with(input_node, add_node)

    # We should now have four nodes: test_node, input, add, output
    # The result should be 0 + 1 = 1
    assert exported_program.module()(torch.zeros(1)) == 1
    assert len(graph.nodes) == 4

    # Each input kind is expected to be registered in a different signature
    # mapping and stored in either state_dict or constants.
    if kind == InputKind.PARAMETER:
        assert const_name in exported_program.graph_signature.inputs_to_parameters
        assert const_name in exported_program.state_dict
        assert len(exported_program.constants) == 0
    elif kind == InputKind.BUFFER and persistent_buffer:
        assert const_name in exported_program.graph_signature.inputs_to_buffers
        assert const_name in exported_program.state_dict
        assert len(exported_program.constants) == 0
    elif kind == InputKind.BUFFER and not persistent_buffer:
        assert const_name in exported_program.graph_signature.inputs_to_buffers
        assert len(exported_program.state_dict) == 0
        assert const_name in exported_program.constants
    elif kind == InputKind.CONSTANT_TENSOR:
        assert (
            const_name
            in exported_program.graph_signature.inputs_to_lifted_tensor_constants
        )
        assert len(exported_program.state_dict) == 0
        assert const_name in exported_program.constants
    else:
        raise RuntimeError("Wrong input kind")

    # Bypassing the add op and running eliminate_dead_code() removes the add op
    # but leaves the now-unused constant placeholder in the graph
    output_node.replace_input_with(add_node, input_node)
    graph.eliminate_dead_code()
    assert len(graph.nodes) == 3

    # Delete the constant placeholder manually
    # The result should again be 0 = 0
    delete_constant_placeholder(exported_program, const_node)
    assert exported_program.module()(torch.zeros(1)) == 0
    assert len(graph.nodes) == 2
    assert len(exported_program.graph_signature.input_specs) == 1
    assert len(exported_program.state_dict) == 0
    assert len(exported_program.constants) == 0
108+
109+
110+
def test_create_delete_parameter():
    """Run the create/delete round trip for a PARAMETER placeholder."""
    _test_create_delete(kind=InputKind.PARAMETER)
112+
113+
114+
def test_create_delete_persistent_buffer():
    """Run the create/delete round trip for a persistent BUFFER placeholder."""
    _test_create_delete(kind=InputKind.BUFFER, persistent_buffer=True)
116+
117+
118+
def test_create_delete_non_persistent_buffer():
    """Run the create/delete round trip for a non-persistent BUFFER placeholder."""
    _test_create_delete(kind=InputKind.BUFFER, persistent_buffer=False)
120+
121+
122+
def test_create_delete_constant_tensor():
    """Run the create/delete round trip for a CONSTANT_TENSOR placeholder."""
    _test_create_delete(kind=InputKind.CONSTANT_TENSOR)

0 commit comments

Comments
 (0)