
Commit b320626

XiaobingSuper authored and pytorchmergebot committed
TorchDynamo: enable convolution and batchnorm folding for inference path (#87435)
Pull Request resolved: #87435
Approved by: https://github.com/jgong5, https://github.com/jansel
1 parent fbd08fb commit b320626
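
Background: in eval mode BatchNorm is a per-channel affine transform, so it can be folded ahead of time into the preceding convolution's weight and bias. The snippet below is not part of this commit; it is a minimal sketch of the arithmetic, checked against torch.nn.utils.fusion.fuse_conv_bn_eval, the existing helper this change reuses.

import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)    # pretend the BN has seen real data
bn.running_var.uniform_(0.5, 2.0)

# Fold y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta into the conv itself.
folded = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # one factor per output channel
    folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    folded.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    ref = bn(conv(x))
    torch.testing.assert_close(folded(x), ref, rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(fuse_conv_bn_eval(conv, bn)(x), ref, rtol=1e-5, atol=1e-5)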

File tree

2 files changed: +90 −0 lines changed


test/inductor/test_torchinductor.py

Lines changed: 56 additions & 0 deletions
@@ -1334,6 +1334,62 @@ def fn(a, b):
             check_lowp=False,
         )
 
+    # For the GPU path there is an accuracy issue.
+    @unittest.skipIf(HAS_CUDA, "only support cpu conv bn test")
+    def test_conv_bn_fuse(self):
+        input_shapes = {1: (112,), 2: (112, 112), 3: (55, 55, 55)}
+        conv_modules = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        bn_modules = {
+            1: torch.nn.BatchNorm1d,
+            2: torch.nn.BatchNorm2d,
+            3: torch.nn.BatchNorm3d,
+        }
+        options = itertools.product(
+            [1, 2, 3],
+            [True, False],
+            [1, 3],
+            [1, 2],
+            [1, 4],
+        )
+
+        for (
+            dim,
+            bias,
+            kernel_size,
+            dilation,
+            groups,
+        ) in options:
+            oC = 32 * groups
+            iC = 3 * groups
+            x_shape = (1, iC) + input_shapes[dim]
+            mod = torch.nn.Sequential(
+                conv_modules[dim](
+                    iC,
+                    oC,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                    groups=groups,
+                    bias=bias,
+                ),
+                bn_modules[dim](oC),
+            ).eval()
+            test_memory_format = [torch.contiguous_format]
+            # TODO: GPU path doesn't support channels_last now.
+            if not HAS_CUDA and dim > 1:
+                channels_last = (
+                    torch.channels_last if dim == 2 else torch.channels_last_3d
+                )
+                test_memory_format.append(channels_last)
+            for memory_format in test_memory_format:
+                v = torch.randn(x_shape, dtype=torch.float32).to(
+                    memory_format=memory_format
+                )
+                with torch.no_grad():
+                    self.common(
+                        mod,
+                        (v,),
+                    )
+
     # For the GPU path there is an accuracy issue,
     # see https://github.com/pytorch/pytorch/issues/87745.
     @unittest.skipIf(HAS_CUDA, "only support cpu conv2d unary test")
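
The new test drives the fold through the full inductor pipeline via the self.common helper, which compiles the module and compares against eager execution across dims, bias, kernel size, dilation, groups, and memory formats. Outside the test suite, roughly the same check can be written as the sketch below, assuming a PyTorch build where torch.compile with the "inductor" backend is available (the PR itself predates that frontend and goes through the test harness instead).

import torch

mod = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=3, bias=False),
    torch.nn.BatchNorm2d(32),
).eval()

# In the CPU inference path this commit targets, conv/BN folding happens during compilation.
compiled = torch.compile(mod, backend="inductor")
x = torch.randn(1, 3, 112, 112)
with torch.no_grad():
    torch.testing.assert_close(compiled(x), mod(x), rtol=1e-4, atol=1e-4)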

torch/_inductor/overrides.py

Lines changed: 34 additions & 0 deletions
@@ -16,6 +16,7 @@
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.nn import functional as F
 from torch.nn.modules.utils import _pair
+from torch.nn.utils.fusion import fuse_conv_bn_eval
 from torch.overrides import TorchFunctionMode
 
 log = logging.getLogger(__name__)
@@ -310,6 +311,7 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
     )
     if not is_cpu:
         return gm
+    gm = fuse_conv_bn(gm)
     # For binary fusion, we need to check inputs info to make sure
     # the binary inputs have same tensor info(device, dtype, and layout).
     ShapeProp(gm).propagate(*example_inputs)
@@ -319,6 +321,38 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs):
     return gm
 
 
+def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False):
+    """
+    Fuses Convolution/BN layers for inference purposes.
+    """
+    patterns = [
+        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
+        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
+        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
+    ]
+    modules = dict(gm.named_modules())
+
+    for pattern in patterns:
+        for node in gm.graph.nodes:
+            if matches_module_pattern(pattern, node, modules):
+                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
+                    continue
+                conv = modules[node.args[0].target]
+                bn = modules[node.target]
+                eval_mode = all(not n.training for n in [conv, bn])
+                if not eval_mode:
+                    continue
+                if not bn.track_running_stats:
+                    continue
+                fused_conv = fuse_conv_bn_eval(conv, bn)
+                replace_node_module(node.args[0], modules, fused_conv)
+                node.replace_all_uses_with(node.args[0])
+                gm.graph.erase_node(node)
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+
 def fuse_unary(gm: torch.fx.GraphModule):
     modules = dict(gm.named_modules())
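
For reference, the new pass can also be exercised directly on an FX-traced module. The sketch below is a hypothetical standalone use, assuming torch._inductor.overrides still exposes fuse_conv_bn in your build (it is an internal module and may change between releases).

import torch
from torch._inductor.overrides import fuse_conv_bn

mod = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),
    torch.nn.BatchNorm2d(16),
).eval()

gm = torch.fx.symbolic_trace(mod)
gm = fuse_conv_bn(gm)   # the BN call_module node is erased, the conv submodule is replaced by a fused copy
print(gm.graph)         # only the convolution remains in the graph

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    torch.testing.assert_close(gm(x), mod(x), rtol=1e-5, atol=1e-5)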
