Arm backend: Update fuse_batchnorm_pass to create new placeholders #8411

Merged: 15 commits, Mar 4, 2025
1 change: 1 addition & 0 deletions backends/arm/_passes/TARGETS
@@ -9,5 +9,6 @@ python_library(
"//executorch/backends/transforms:replace_scalar_with_tensor",
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
"//executorch/exir:lib",
"//executorch/backends/transforms:utils",
],
)
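The new `//executorch/backends/transforms:utils` dependency (together with the visibility change in backends/transforms/targets.bzl further down) is what lets the Arm passes use the placeholder helpers. A minimal sketch of the import this enables, matching the one added in fuse_batchnorm2d_pass.py below:

# Import made available by the new :utils dependency
from executorch.backends.transforms.utils import (
    create_constant_placeholder,
    delete_constant_placeholder,
)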
2 changes: 1 addition & 1 deletion backends/arm/_passes/arm_pass_utils.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
130 changes: 82 additions & 48 deletions backends/arm/_passes/fuse_batchnorm2d_pass.py
@@ -6,10 +6,15 @@
# pyre-unsafe

import torch
from executorch.backends.transforms.utils import (
create_constant_placeholder,
delete_constant_placeholder,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._export.utils import get_buffer, get_param
from torch.export.graph_signature import InputKind
from torch.fx import Node
from torch.nn.utils.fusion import fuse_conv_bn_weights

@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
self.exported_program = exported_program
super().__init__()

def is_fuseable_conv_bn(self, node: Node):
def is_fuseable_conv_bn(self, node: Node) -> bool:
"""Returns True if node is a batchnorm that can be fused into
a parent convolution."""
if node.op != "call_function":
@@ -44,15 +49,19 @@ def is_fuseable_conv_bn(self, node: Node):
# Since we change the output of the conv, fuse only if it has single user.
if len(conv.users) > 1:
return False
# For similar reasons, only fuse if conv parameters have single user.
if len(conv.all_input_nodes[1].users) > 1:
return False
if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
return False
return True

def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
if conv_bias_node:
return conv_bias_node.name + "_fused_bn"
elif "weight" in conv_weight_node.name:
return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
else:
return conv_weight_node.name + "_bias_fused_bn"

def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
modified = False
constant_placeholders_to_delete = set()
for node in graph_module.graph.nodes:
if not self.is_fuseable_conv_bn(node):
continue
@@ -64,68 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
)

# Get weight, bias, mean, var and epsilon from the batchnorm
bn = node
conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
bn_weight = get_param_or_none(bn_weight_node)
bn_bias = get_param_or_none(bn_bias_node)

running_mean = get_buffer(self.exported_program, bn_mean_node)
running_var = get_buffer(self.exported_program, bn_var_node)
if running_mean is None or running_var is None:
bn_node = node
conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
bn_node.args[0:5]
)
bn_weight_tensor = get_param_or_none(bn_weight_node)
bn_bias_tensor = get_param_or_none(bn_bias_node)
bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
if bn_mean_tensor is None or bn_var_tensor is None:
raise ValueError(
"Parameters running_mean and running_var of batchnorm can't be None."
)
epsilon = bn.args[-1]
epsilon = bn_node.args[-1]

# Get weight and bias from conv
conv_weight_node, conv_bias_node = conv.args[1:3]
conv_weight = get_param(self.exported_program, conv_weight_node)
conv_bias = get_param_or_none(conv_bias_node)
if conv_weight is None:
conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
conv_bias_tensor = get_param_or_none(conv_bias_node)
if conv_weight_tensor is None:
raise ValueError("Parameter weight of convolution can't be None.")

# Compute conv parameters folded with batchnorm
fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
conv_weight,
conv_bias,
running_mean,
running_var,
conv_weight_tensor,
conv_bias_tensor,
bn_mean_tensor,
bn_var_tensor,
epsilon,
bn_weight,
bn_bias,
bn_weight_tensor,
bn_bias_tensor,
)

# Set the conv parameters to fused value
def try_set_param(
param_node: Node | None, param_value: torch.nn.Parameter
) -> bool:
"""set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
if param_node is not None:
param_name = (
self.exported_program.graph_signature.inputs_to_parameters[
param_node.name
]
# Create fused weights and bias to conv and replace conv args
with graph_module.graph.inserting_before(conv_weight_node):
fused_conv_weight_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph_module.graph,
kind=InputKind.PARAMETER,
name=conv_weight_node.name + "_fused_bn",
data=fused_conv_weight,
)

if fused_conv_bias is not None:
fused_conv_bias_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph_module.graph,
kind=InputKind.PARAMETER,
name=self.get_bias_name(conv_weight_node, conv_bias_node),
data=fused_conv_bias,
)
self.exported_program.state_dict[param_name] = param_value
return True
return False
else:
fused_conv_bias_node = None

conv.args = (
conv.args[0],
fused_conv_weight_node,
fused_conv_bias_node,
*conv.args[3:],
)

try_set_param(conv_weight_node, fused_conv_weight)
if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
bn_bias_node, fused_conv_bias
):
# pyre-ignore[60]
# Conv didn't have bias but batchnorm did, steal bias from batchnorm.
conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
conv.args = conv_args

# Erasing nodes is handled by dead-code elimination.
for user in bn.users:
# Erasing batch-norm nodes is handled by dead-code elimination. After that we may remove their constant placeholder inputs.
for user in bn_node.users:
user.replace_all_uses_with(conv)

constant_placeholders_to_delete.update(
[
bn_weight_node,
bn_bias_node,
bn_mean_node,
bn_var_node,
conv_weight_node,
conv_bias_node,
]
)
modified = True

if modified:
graph_module.graph.eliminate_dead_code()
for constant_placeholder in constant_placeholders_to_delete:
if (constant_placeholder is not None) and (
len(constant_placeholder.users) == 0
):
delete_constant_placeholder(
self.exported_program, constant_placeholder
)

graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module=graph_module, modified=modified)
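For reference, `torch.nn.utils.fusion.fuse_conv_bn_weights` used by the pass implements the standard batch-norm fold: W' = W · γ/√(σ² + ε) and b' = (b − μ) · γ/√(σ² + ε) + β. A minimal sketch checking that against a manual computation (tensor shapes here are illustrative, not taken from the PR):

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

# Illustrative shapes: 8 output channels, 3 input channels, 3x3 kernel
conv_w = torch.randn(8, 3, 3, 3)
conv_b = torch.randn(8)
mean, var = torch.randn(8), torch.rand(8) + 0.1
gamma, beta, eps = torch.randn(8), torch.randn(8), 1e-5

# Manual fold: scale each output channel by gamma / sqrt(var + eps)
scale = gamma / torch.sqrt(var + eps)
expected_w = conv_w * scale.reshape(-1, 1, 1, 1)
expected_b = (conv_b - mean) * scale + beta

fused_w, fused_b = fuse_conv_bn_weights(conv_w, conv_b, mean, var, eps, gamma, beta)
torch.testing.assert_close(fused_w.data, expected_w)
torch.testing.assert_close(fused_b.data, expected_b)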
8 changes: 4 additions & 4 deletions backends/arm/test/passes/test_fuse_batchnorm_pass.py
@@ -85,13 +85,13 @@ def forward(self, x):
return x


class MergeNoBN(torch.nn.Module):
class MergeMultipleUsersBN(torch.nn.Module):
ops_before_pass = {
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
"executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
}
ops_after_pass = {
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1,
"executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
}

@@ -122,7 +122,7 @@ def forward(self, x):
z = self.conv2d2(x)
a = self.batch_norm2d(
y
) # Can't be fused since parameters of conv2d2 have multiple users.
) # Can be fused despite parameters of conv2d2 having multiple users.

return z, a

@@ -131,7 +131,7 @@
"merge_one_of_two_bn_affine": MergeOneOfTwoBN(True),
"merge_one_of_two_bn": MergeOneOfTwoBN(False),
"merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True),
"merge_no_bn_affine": MergeNoBN(True),
"merge_multiple_users_bn_affine": MergeMultipleUsersBN(True),
}


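The renamed test reflects why the pass can now fuse in this case: the old version mutated the existing parameter in the state dict, which would have corrupted every other user of a shared parameter, so such convolutions were skipped; creating a fresh fused placeholder leaves the original parameter untouched for its other users. A hypothetical sketch of the sharing pattern in question (module and attribute names are illustrative, not from the test):

import torch

class SharedWeightConvs(torch.nn.Module):
    """Hypothetical: two convs sharing one weight, one followed by batch norm."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 4, 3, bias=False)
        self.conv2 = torch.nn.Conv2d(3, 4, 3, bias=False)
        self.conv2.weight = self.conv1.weight  # shared parameter
        self.bn = torch.nn.BatchNorm2d(4)

    def forward(self, x):
        # Folding bn into conv1 by mutating the shared weight in place would
        # silently change conv2; a new fused placeholder for conv1 does not.
        return self.bn(self.conv1(x)), self.conv2(x)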
3 changes: 3 additions & 0 deletions backends/transforms/targets.bzl
@@ -149,6 +149,9 @@ def define_common_targets():
runtime.python_library(
name = "utils",
srcs = ["utils.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
123 changes: 123 additions & 0 deletions backends/transforms/test/test_create_delete_constant_placeholder.py
@@ -0,0 +1,123 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.transforms.utils import (
create_constant_placeholder,
delete_constant_placeholder,
)
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import export
from torch.export.graph_signature import InputKind


class EmptyNetwork(torch.nn.Module):

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x

test_data: tuple[torch.Tensor] = (torch.zeros(1),)


def _test_create_delete(kind: InputKind, persistent_buffer: bool | None = None):
"""
Tests the utility functions create_constant_placeholder and delete_constant_placeholder
"""

# Toy network with two nodes, input and output
# The result should be 0 = 0
module = EmptyNetwork()
exported_program = export(module, args=module.test_data)
exported_program = to_edge(exported_program).exported_program()
graph = exported_program.graph_module.graph
assert len(graph.nodes) == 2
assert exported_program.module()(torch.zeros(1)) == 0
assert len(exported_program.graph_signature.input_specs) == 1
assert len(exported_program.state_dict) == 0
assert len(exported_program.constants) == 0

const_name = "test_node"

# Create one const node with value 1 and add it to the input
input_node = list(graph.nodes)[0]
with graph.inserting_before(input_node):
const_node = create_constant_placeholder(
exp_program=exported_program,
graph=graph,
kind=kind,
name=const_name,
data=torch.ones(1),
persistent_buffer=persistent_buffer,
)
assert "val" in const_node.meta

with graph.inserting_after(input_node):
add_node = graph.create_node(
"call_function",
exir_ops.edge.aten.add.Tensor,
args=(input_node, const_node),
kwargs={},
)

output_node = list(graph.nodes)[-1]
output_node.replace_input_with(input_node, add_node)

# We should now have four nodes: test_node, input, add, output
# The result should be 0 + 1 = 1
assert exported_program.module()(torch.zeros(1)) == 1
assert len(graph.nodes) == 4

if kind == InputKind.PARAMETER:
assert const_name in exported_program.graph_signature.inputs_to_parameters
assert const_name in exported_program.state_dict
assert len(exported_program.constants) == 0
elif kind == InputKind.BUFFER and persistent_buffer:
assert const_name in exported_program.graph_signature.inputs_to_buffers
assert const_name in exported_program.state_dict
assert len(exported_program.constants) == 0
elif kind == InputKind.BUFFER and not persistent_buffer:
assert const_name in exported_program.graph_signature.inputs_to_buffers
assert len(exported_program.state_dict) == 0
assert const_name in exported_program.constants
elif kind == InputKind.CONSTANT_TENSOR:
assert (
const_name
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
)
assert len(exported_program.state_dict) == 0
assert const_name in exported_program.constants
else:
raise RuntimeError("Wrong input kind")

# Replacing the add op and using eliminate_dead_code() deletes the add op but not the input op
output_node.replace_input_with(add_node, input_node)
graph.eliminate_dead_code()
assert len(graph.nodes) == 3

# Delete the input op manually
# The result should again be 0 = 0
delete_constant_placeholder(exported_program, const_node)
assert exported_program.module()(torch.zeros(1)) == 0
assert len(graph.nodes) == 2
assert len(exported_program.graph_signature.input_specs) == 1
assert len(exported_program.state_dict) == 0
assert len(exported_program.constants) == 0


def test_create_delete_parameter():
_test_create_delete(InputKind.PARAMETER)


def test_create_delete_persistent_buffer():
_test_create_delete(InputKind.BUFFER, True)


def test_create_delete_non_persistent_buffer():
_test_create_delete(InputKind.BUFFER, False)


def test_create_delete_constant_tensor():
_test_create_delete(InputKind.CONSTANT_TENSOR)