
TorchDynamo: enable convolution and batchnorm folding for inference path #87435

Closed
56 changes: 56 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -1334,6 +1334,62 @@ def fn(a, b):
            check_lowp=False,
        )

    # For the GPU path, there is an accuracy issue,
    @unittest.skipIf(HAS_CUDA, "only support cpu conv bn test")
    def test_conv_bn_fuse(self):
        input_shapes = {1: (112,), 2: (112, 112), 3: (55, 55, 55)}
        conv_modules = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
        bn_modules = {
            1: torch.nn.BatchNorm1d,
            2: torch.nn.BatchNorm2d,
            3: torch.nn.BatchNorm3d,
        }
        options = itertools.product(
            [1, 2, 3],
            [True, False],
            [1, 3],
            [1, 2],
            [1, 4],
        )

        for (
            dim,
            bias,
            kernel_size,
            dilation,
            groups,
        ) in options:
            oC = 32 * groups
            iC = 3 * groups
            x_shape = (1, iC) + input_shapes[dim]
            mod = torch.nn.Sequential(
                conv_modules[dim](
                    iC,
                    oC,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    groups=groups,
                    bias=bias,
                ),
                bn_modules[dim](oC),
            ).eval()
            test_memory_format = [torch.contiguous_format]
            # TODO: GPU path doesn't support channels_last now.
            if not HAS_CUDA and dim > 1:
                channels_last = (
                    torch.channels_last if dim == 2 else torch.channels_last_3d
                )
                test_memory_format.append(channels_last)
            for memory_format in test_memory_format:
                v = torch.randn(x_shape, dtype=torch.float32).to(
                    memory_format=memory_format
                )
                with torch.no_grad():
                    self.common(
                        mod,
                        (v,),
                    )

    # For the GPU path, there is an accuracy issue,
    # see https://github.com/pytorch/pytorch/issues/87745.
    @unittest.skipIf(HAS_CUDA, "only support cpu conv2d unary test")
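The new test above drives eval-mode conv + bn pairs through the self.common harness, which compiles the module with Inductor and compares the result against eager mode. Outside the test suite, roughly the same check can be written as a standalone script. This is a minimal sketch, not part of the PR: it assumes a build where torch.compile(backend="inductor") is available (on the branch this PR targets, the equivalent entry point is torch._dynamo.optimize("inductor")), and the shapes and tolerances are illustrative.

import torch

# Eval-mode Conv2d -> BatchNorm2d pair: the pattern the new fuse_conv_bn pass folds.
mod = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=3, bias=False),
    torch.nn.BatchNorm2d(32),
).eval()

x = torch.randn(1, 3, 112, 112)

with torch.no_grad():
    eager_out = mod(x)
    # Routes through TorchDynamo + Inductor, which now runs fuse_conv_bn on the CPU path.
    compiled = torch.compile(mod, backend="inductor")
    compiled_out = compiled(x)

# The folded graph should stay numerically close to eager mode.
torch.testing.assert_close(eager_out, compiled_out, rtol=1e-4, atol=1e-4)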
34 changes: 34 additions & 0 deletions torch/_inductor/overrides.py
@@ -16,6 +16,7 @@
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from torch.nn.utils.fusion import fuse_conv_bn_eval
from torch.overrides import TorchFunctionMode

log = logging.getLogger(__name__)
@@ -310,6 +311,7 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
    )
    if not is_cpu:
        return gm
    gm = fuse_conv_bn(gm)
    # For binary fusion, we need to check the inputs' info to make sure
    # the binary inputs have the same tensor info (device, dtype, and layout).
    ShapeProp(gm).propagate(*example_inputs)
@@ -319,6 +321,38 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
    return gm


def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False):
    """
    Fuses Convolution/BN layers for inference purposes.
    """
    patterns = [
        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
    ]
    modules = dict(gm.named_modules())

    for pattern in patterns:
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                eval_mode = all(not n.training for n in [conv, bn])
                if not eval_mode:
                    continue
                if not bn.track_running_stats:
                    continue
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


def fuse_unary(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())

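For reference, fuse_conv_bn_eval (imported at the top of overrides.py) performs the standard inference-time folding: with batch-norm running statistics mu and var, affine parameters gamma and beta, and epsilon eps, each conv output channel is rescaled by scale = gamma / sqrt(var + eps) and the conv bias becomes (b - mu) * scale + beta. The sketch below is a hand-rolled check, not part of the PR; the module sizes are arbitrary, and it simply verifies the manual folding against the helper the new pass relies on.

import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()
# Give the BN layer non-trivial running statistics and affine parameters,
# since the defaults (mean=0, var=1, gamma=1, beta=0) make folding a no-op.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1.0, 1.0)

# Manual folding: scale each output channel by gamma / sqrt(running_var + eps).
folded = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    folded.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    ref = bn(conv(x))                            # original conv -> bn
    via_helper = fuse_conv_bn_eval(conv, bn)(x)  # helper used by fuse_conv_bn
    manual = folded(x)                           # hand-folded weights

torch.testing.assert_close(via_helper, ref, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(manual, ref, rtol=1e-4, atol=1e-4)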