[XNNPACK] Add support for Linear fused BatchNorm #11805


Open · wants to merge 5 commits into base: main
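
For background, fusing a BatchNorm1d into the preceding Linear folds the normalization's per-feature scale and shift into the linear weight and bias, so the pair collapses into a single affine op at inference time. Below is a minimal sketch of that arithmetic (an illustration only, with a throwaway helper name; as the diffs below show, the pass itself delegates to `torch.nn.utils.fusion.fuse_linear_bn_weights`):

```python
import torch

def fold_bn_into_linear(w, b, running_mean, running_var, eps, gamma, beta):
    # Throwaway helper for illustration (not part of this PR).
    # Per-output-feature scale applied by the batch norm at inference time.
    scale = gamma / torch.sqrt(running_var + eps)
    fused_w = w * scale[:, None]                 # scale each output row of W
    fused_b = (b - running_mean) * scale + beta  # fold mean/shift into the bias
    return fused_w, fused_b

# Sanity check against an eval-mode Linear -> BatchNorm1d chain.
linear = torch.nn.Linear(4, 3)
bn = torch.nn.BatchNorm1d(3).eval()
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)

x = torch.randn(8, 4)
fused_w, fused_b = fold_bn_into_linear(
    linear.weight, linear.bias, bn.running_mean, bn.running_var,
    bn.eps, bn.weight, bn.bias,
)
assert torch.allclose(
    bn(linear(x)), torch.nn.functional.linear(x, fused_w, fused_b), atol=1e-5
)
```
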
4 changes: 4 additions & 0 deletions backends/xnnpack/_passes/__init__.py
@@ -24,6 +24,9 @@
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import (
FuseBatchNormWithLinearPass,
)
from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
TagImplicitQDqPass,
@@ -64,6 +67,7 @@ def __init__(
ConvertToSDPAPass,
ConstPropPass,
FuseBatchNormWithConvPass,
FuseBatchNormWithLinearPass,
FuseActivationPass,
DecomposeConcatenate,
RemoveGetItemPass,
187 changes: 187 additions & 0 deletions backends/xnnpack/_passes/fuse_batch_norm_with_linear.py
@@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

import torch
from executorch.backends.transforms.utils import (
create_constant_placeholder,
delete_constant_placeholder,
)

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

from executorch.backends.xnnpack.utils.utils import (
get_param_tensor,
get_tensor_name,
is_param_node,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.export.graph_signature import InputKind

from torch.nn.utils.fusion import fuse_linear_bn_weights


class FuseBatchNormWithLinearPass(XNNPACKPass):
def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
constant_placeholders_to_delete = set()
for linear in graph.nodes:
# We want to discover a chain of linear -> batch_norm.
# Only proceed if the current node is a linear node, and has a single
# user/successor.
if (
linear.target != exir_ops.edge.aten.linear.default
or len(linear.users) != 1
):
continue

# Single user of the linear op must be batch_norm. If not, bail.
bn = list(linear.users.keys())[0]
if (
bn.target != exir_ops.edge.aten.native_batch_norm.default
and bn.target
!= exir_ops.edge.aten._native_batch_norm_legit_no_training.default
):
continue

if not self.can_fuse(linear, bn, self.exported_program):
continue

# Get the parameters from linear op
assert len(linear.args) == 3

linear_weight = get_param_tensor(self.exported_program, linear.args[1])
linear_weight_name = get_tensor_name(self.exported_program, linear.args[1])
assert linear_weight is not None

linear_bias = get_param_tensor(self.exported_program, linear.args[2])
linear_bias_name = get_tensor_name(self.exported_program, linear.args[2])

# Get the parameters from the batchnorm op
assert (
bn.target == exir_ops.edge.aten.native_batch_norm.default
and len(bn.args) == 8
) or (
bn.target
== exir_ops.edge.aten._native_batch_norm_legit_no_training.default
and len(bn.args) == 7
)
bn_weight = get_param_tensor(self.exported_program, bn.args[1])
bn_bias = get_param_tensor(self.exported_program, bn.args[2])

running_mean = get_param_tensor(self.exported_program, bn.args[3])
assert running_mean is not None

running_var = get_param_tensor(self.exported_program, bn.args[4])
assert running_var is not None

# args[7] for native_batch_norm, but args[6] for
# _native_batch_norm_legit_no_training (which doesn't have training
# as an arg)
eps = bn.args[-1]

fused_weight, fused_bias = fuse_linear_bn_weights(
linear_weight,
linear_bias,
running_mean,
running_var,
eps,
bn_weight,
bn_bias,
)
fused_weight_name = (linear_weight_name + "_fused_bn").replace(".", "_")
if linear_bias_name == "":
fused_bias_name = (linear_weight_name + "_bias_fused_bn").replace(
".", "_"
)
else:
fused_bias_name = (linear_bias_name + "_fused_bn").replace(".", "_")

# Modify the graph by updating the weight and bias of the linear op
# with the fused weight and bias params, and replacing all the users
# of getitem(batchnorm) with the linear op.

with graph.inserting_before(linear.args[1]):
fused_linear_weight_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph_module.graph,
kind=InputKind.PARAMETER,
name=fused_weight_name,
data=fused_weight,
)
if fused_bias is not None:
fused_linear_bias_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph_module.graph,
kind=InputKind.PARAMETER,
name=fused_bias_name,
data=fused_bias,
)
else:
fused_linear_bias_node = None

linear.args = (
linear.args[0],
fused_linear_weight_node,
fused_linear_bias_node,
)

# Remove any use of batchnorm from the graph
for user in bn.users.copy():
assert user.target == operator.getitem
user.replace_all_uses_with(linear)
graph.erase_node(user)

graph.erase_node(bn)
constant_placeholders_to_delete.update(linear.args[1:3] + bn.args[1:5])

if len(constant_placeholders_to_delete) > 0:
graph_module.graph.eliminate_dead_code()
for node in constant_placeholders_to_delete:
if (node is not None) and (len(node.users) == 0):
delete_constant_placeholder(self.exported_program, node)

graph_module.recompile()
        # To regenerate metadata and shape information, retrace the module
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)

@staticmethod
def can_fuse(
linear: torch.fx.Node,
bn: torch.fx.Node,
program: ExportedProgram,
) -> bool:
"""
Determine whether a batch norm node can be fused with a preceding linear node.
"""

# All the users of the batchnorm node must be getitem ops. batchnorm
# returns a 3-element tuple. Each user must only access the first
# element of the tuple.
if [
(user.target == operator.getitem and user.args[1] == 0) for user in bn.users
].count(False):
return False

linear_weights = linear.args[1]
bn_weights = bn.args[1]

# Check that the weights for linear and batchnorm are both params
if not isinstance(linear_weights, torch.fx.Node) or not isinstance(
bn_weights, torch.fx.Node
):
return False

if [
is_param_node(program, node) for node in {linear_weights, bn_weights}
].count(False):
return False
return True
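
For reference, the torch helper this pass calls (imported above from `torch.nn.utils.fusion`) can be exercised on its own to confirm that the fused parameters reproduce the eval-mode linear -> batch-norm output. A small standalone sketch, using the same seven-argument order as the call in the pass:

```python
import torch
from torch.nn.utils.fusion import fuse_linear_bn_weights

linear = torch.nn.Linear(2, 2)
bn = torch.nn.BatchNorm1d(2)
# Populate non-trivial running statistics, then freeze them by switching to eval.
with torch.no_grad():
    bn(torch.randn(32, 2) * 2 + 2)
bn.eval()

fused_w, fused_b = fuse_linear_bn_weights(
    linear.weight, linear.bias,        # linear parameters
    bn.running_mean, bn.running_var,   # batch-norm running statistics
    bn.eps, bn.weight, bn.bias,        # batch-norm eps and affine parameters
)

x = torch.randn(4, 2)
assert torch.allclose(
    bn(linear(x)), torch.nn.functional.linear(x, fused_w, fused_b), atol=1e-5
)
```
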
19 changes: 12 additions & 7 deletions backends/xnnpack/partition/config/node_configs.py
@@ -12,6 +12,9 @@
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import (
FuseBatchNormWithLinearPass,
)
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
XNNPartitionerConfig,
@@ -35,20 +38,22 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
return False

bn = node
conv = node.all_input_nodes[0]
input_node = node.all_input_nodes[0]

if conv.op != "call_function":
if input_node.op != "call_function":
return False

conv_name = format_target_name(conv.target.__name__) # pyre-ignore
input_name = format_target_name(input_node.target.__name__) # pyre-ignore

if conv_name not in ["convolution.default"]:
why(node, f"Invalid conv target {conv_name}")
if input_name not in ["convolution.default", "linear.default"]:
why(node, f"Invalid input target {input_name.split('.')[0]}")
return False

can_fuse = FuseBatchNormWithConvPass.can_fuse(conv, bn, ep)
can_fuse = FuseBatchNormWithConvPass.can_fuse(
input_node, bn, ep
) or FuseBatchNormWithLinearPass.can_fuse(input_node, bn, ep)
if not can_fuse:
why(node, "BatchNorm cannot be fused with Convolution")
why(node, f"BatchNorm cannot be fused with {input_name.split('.')[0]}")
return False

return True
65 changes: 59 additions & 6 deletions backends/xnnpack/test/passes/test_batch_norm_fusion.py
@@ -11,11 +11,15 @@
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import (
FuseBatchNormWithLinearPass,
)
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestBatchNormFusion(unittest.TestCase):
PassStage = RunPasses([FuseBatchNormWithConvPass])
ConvPassStage = RunPasses([FuseBatchNormWithConvPass])
LinearPassStage = RunPasses([FuseBatchNormWithLinearPass])
bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"

def setUp(self):
Expand All @@ -42,7 +46,22 @@ def forward(self, x):
y = y + y
return self.bn(y)

def test_fp32_batch_norm_fusion(self):
class ModelLinearBN(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
op = torch.nn.Linear
self.linear = op(in_features, out_features)
self.bn = torch.nn.BatchNorm1d(out_features)
self.forward(torch.randn(2, 2) * 2 + 2) # update the BN stats

def forward(self, x):
y = self.linear(x)
y = self.bn(y)
y = self.linear(y)
y = y + y
return self.bn(y)

def test_fp32_conv_batch_norm_fusion(self):
for transpose in [False, True]:
(
Tester(
Expand All @@ -51,12 +70,12 @@ def test_fp32_batch_norm_fusion(self):
)
.export()
.to_edge()
.run_passes(self.PassStage)
.run_passes(self.ConvPassStage)
.check_count({self.bn_name: 1})
.run_method_and_compare_outputs()
)

def test_q8_batch_norm_fusion(self):
def test_q8_conv_batch_norm_fusion(self):
for transpose in [False, True]:
(
Tester(
Expand All @@ -66,12 +85,12 @@ def test_q8_batch_norm_fusion(self):
.quantize()
.export()
.to_edge()
.run_passes(self.PassStage)
.run_passes(self.ConvPassStage)
.check_count({self.bn_name: 1})
.run_method_and_compare_outputs()
)

def test_fp32_batch_norm_no_fusion_doesnt_partition(self):
def test_fp32_conv_batch_norm_no_fusion_doesnt_partition(self):
"""
We do not currently support standalone batch norms (i.e. batch norms that are
not fused with a conv). This is planned, but until implemented, this test ensures
@@ -94,3 +113,37 @@ def forward(self, x):
.partition()
.check_count({self.bn_name: 1})
)

def test_fp32_linear_batch_norm_fusion(self):
(
Tester(
self.ModelLinearBN(2, 2).eval(),
(torch.randn(2, 2),),
)
.export()
.to_edge_transform_and_lower()
.check_count({self.bn_name: 1})
.run_method_and_compare_outputs()
)

def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self):
"""
We do not currently support standalone batch norms (i.e. batch norms that are
not fused with a linear). This is planned, but until implemented, this test ensures
that we do not partition the standalone batch norm and then fail to lower.
"""

class BN(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn = torch.nn.BatchNorm1d(2)

def forward(self, x):
return self.bn(x)

(
Tester(BN(), (torch.randn(2, 2),))
.export()
.to_edge_transform_and_lower()
.check_count({self.bn_name: 1})
)