Skip to content

Commit d6fe983

Browse files
chunyuan-wpytorchmergebot
authored andcommitted
[inductor] add conv_transpose2d unary fusion for cpu in inference mode (#90265)
An FX transformation is added to fuse ConvTranspose2d with eltwise OPs in torchinductor for CPU in inference mode, following the implementation in #87063. The fusion OP is implemented in #90264 and will be treated as an extern kernel call in torchinductor. The fusion of ConvTranspose2d with the below OPs is supported: - relu - sigmoid - tanh - hardswish - leaky_relu - hardtanh - gelu Pull Request resolved: #90265 Approved by: https://github.com/jgong5, https://github.com/jansel
1 parent 85698d0 commit d6fe983

File tree

4 files changed

+215
-2
lines changed

4 files changed

+215
-2
lines changed

test/inductor/test_torchinductor.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,6 +1856,53 @@ def forward(self, x, y):
18561856
with torch.no_grad():
18571857
self.common(mod, (v, other), atol=2e-3, rtol=0.016)
18581858

1859+
@unittest.skipIf(HAS_CUDA, "only support cpu conv_transpose2d unary test")
1860+
def test_conv_transpose2d_unary(self):
1861+
test_memory_format = [torch.contiguous_format, torch.channels_last]
1862+
options = itertools.product(
1863+
unary_list,
1864+
[True, False],
1865+
[1, 3],
1866+
[1, 2],
1867+
[1, 4],
1868+
[0, 1],
1869+
test_memory_format,
1870+
)
1871+
1872+
for (
1873+
unary_fn,
1874+
bias,
1875+
kernel_size,
1876+
dilation,
1877+
groups,
1878+
padding,
1879+
memory_format,
1880+
) in options:
1881+
oC = 32 * groups
1882+
iC = 3 * groups
1883+
x_shape = (1, iC, 28, 28)
1884+
mod = torch.nn.Sequential(
1885+
torch.nn.ConvTranspose2d(
1886+
iC,
1887+
oC,
1888+
kernel_size=kernel_size,
1889+
padding=padding,
1890+
dilation=dilation,
1891+
groups=groups,
1892+
bias=bias,
1893+
),
1894+
unary_fn,
1895+
).eval()
1896+
1897+
v = torch.randn(x_shape, dtype=torch.float32).to(
1898+
memory_format=memory_format
1899+
)
1900+
with torch.no_grad():
1901+
self.common(
1902+
mod,
1903+
(v,),
1904+
)
1905+
18591906
def test_gather1(self):
18601907
def fn(a, b):
18611908
return (

torch/_inductor/ir.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3417,6 +3417,8 @@ def _prepare_convolution_fusion_create(
34173417
stride_: List[int],
34183418
dilation_: List[int],
34193419
groups: int,
3420+
transposed: bool = False,
3421+
output_padding_: List[int] = None,
34203422
):
34213423
"""
34223424
This function is a helper function to prepare inputs, layout and constant args
@@ -3429,6 +3431,8 @@ def _prepare_convolution_fusion_create(
34293431
padding = tuple(padding_)
34303432
dilation = tuple(dilation_)
34313433
assert isinstance(groups, int)
3434+
output_padding = tuple(output_padding_) if output_padding_ else (0, 0)
3435+
34323436
with V.graph.fake_mode:
34333437
x_fake = ir_node_to_tensor(x, guard_shape=True)
34343438
weight_fake = ir_node_to_tensor(weight, guard_shape=True)
@@ -3442,8 +3446,8 @@ def _prepare_convolution_fusion_create(
34423446
stride,
34433447
padding,
34443448
dilation,
3445-
False,
3446-
[0, 0],
3449+
transposed,
3450+
output_padding,
34473451
groups,
34483452
)
34493453
output_size = output.size()
@@ -3462,6 +3466,8 @@ def _prepare_convolution_fusion_create(
34623466
output_stride,
34633467
)
34643468
constant_args = [padding, stride, dilation, groups]
3469+
if transposed:
3470+
constant_args.insert(1, output_padding)
34653471

34663472
if bias is not None:
34673473
inputs.append(bias)
@@ -3796,6 +3802,62 @@ def apply_constraint(self):
37963802
pass
37973803

37983804

3805+
class ConvolutionTransposeUnary(ExternKernelAlloc):
3806+
kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
3807+
3808+
def __init__(
3809+
self,
3810+
layout,
3811+
inputs,
3812+
constant_args=(),
3813+
kernel="torch.ops.mkldnn._convolution_transpose_pointwise",
3814+
):
3815+
super().__init__(layout, inputs, constant_args)
3816+
self.kernel = kernel
3817+
3818+
def codegen(self, wrapper):
3819+
wrapper.writeline(
3820+
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
3821+
)
3822+
3823+
@classmethod
3824+
def create(
3825+
cls,
3826+
x: "TensorBox",
3827+
weight: "TensorBox",
3828+
bias: "TensorBox",
3829+
padding_: List[int],
3830+
output_padding_: List[int],
3831+
stride_: List[int],
3832+
dilation_: List[int],
3833+
groups_: int,
3834+
attr,
3835+
scalars,
3836+
algorithm,
3837+
):
3838+
kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
3839+
transposed = True
3840+
(inputs, constant_args, kernel_layout, _,) = _prepare_convolution_fusion_create(
3841+
cls,
3842+
x,
3843+
weight,
3844+
bias,
3845+
padding_,
3846+
stride_,
3847+
dilation_,
3848+
groups_,
3849+
transposed,
3850+
output_padding_,
3851+
)
3852+
constant_args = constant_args + [attr, scalars, algorithm]
3853+
return ConvolutionTransposeUnary(
3854+
layout=kernel_layout,
3855+
inputs=inputs,
3856+
constant_args=constant_args,
3857+
kernel=kernel,
3858+
)
3859+
3860+
37993861
@dataclasses.dataclass
38003862
class MutableBox(IRNode):
38013863
"""

torch/_inductor/lowering.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,36 @@ def linear_unary(
962962
def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr):
963963
return TensorBox.create(ir.LinearBinary.create(x, y, w, b, attr))
964964

965+
@register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
966+
def convolution_transpose_unary(
967+
x: TensorBox,
968+
weight: TensorBox,
969+
bias: TensorBox,
970+
padding,
971+
output_padding,
972+
stride,
973+
dilation,
974+
groups,
975+
attr,
976+
scalars,
977+
algorithm,
978+
):
979+
return TensorBox.create(
980+
ir.ConvolutionTransposeUnary.create(
981+
x,
982+
weight,
983+
bias,
984+
padding,
985+
output_padding,
986+
stride,
987+
dilation,
988+
groups,
989+
attr,
990+
scalars,
991+
algorithm,
992+
)
993+
)
994+
965995
if torch._C.has_mkl:
966996

967997
@register_lowering(torch.ops.mkl._mkl_linear)

torch/_inductor/overrides.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,69 @@ def forward(self, input, other):
406406
return y
407407

408408

409+
class ConvTransposeUnary2d(nn.ConvTranspose2d):
410+
def __init__(
411+
self,
412+
conv_transpose: nn.Module,
413+
unary: nn.Module,
414+
):
415+
super(ConvTransposeUnary2d, self).__init__(
416+
conv_transpose.in_channels,
417+
conv_transpose.out_channels,
418+
conv_transpose.kernel_size,
419+
conv_transpose.stride,
420+
conv_transpose.padding,
421+
conv_transpose.output_padding,
422+
conv_transpose.groups,
423+
conv_transpose.bias is not None,
424+
conv_transpose.dilation,
425+
conv_transpose.padding_mode,
426+
conv_transpose.weight.device,
427+
conv_transpose.weight.dtype,
428+
)
429+
self._update_module_params(conv_transpose, unary)
430+
431+
def _update_module_params(self, conv_transpose, unary):
432+
self.__dict__ = copy.deepcopy(conv_transpose.__dict__)
433+
self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
434+
unary
435+
)
436+
437+
def _conv_transpose_forward(self, input, weight, bias):
438+
if self.padding_mode != "zeros":
439+
return torch.ops.mkldnn._convolution_transpose_pointwise(
440+
F.pad(
441+
input, self._reversed_padding_repeated_twice, mode=self.padding_mode
442+
),
443+
weight,
444+
bias,
445+
_pair(0),
446+
self.output_padding,
447+
self.stride,
448+
self.dilation,
449+
self.groups,
450+
self.attr,
451+
self.scalars,
452+
self.algorithm,
453+
)
454+
return torch.ops.mkldnn._convolution_transpose_pointwise(
455+
input,
456+
weight,
457+
bias,
458+
self.padding,
459+
self.output_padding,
460+
self.stride,
461+
self.dilation,
462+
self.groups,
463+
self.attr,
464+
self.scalars,
465+
self.algorithm,
466+
)
467+
468+
def forward(self, input):
469+
return self._conv_transpose_forward(input, self.weight, self.bias)
470+
471+
409472
def packed_conv_eval(conv: nn.Module, input_size: list):
410473
assert not (conv.training), "Fusion only for eval!"
411474
return ConvUnary2d(
@@ -481,6 +544,16 @@ def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list):
481544
return linear_binary
482545

483546

547+
def fused_conv_transpose_unary_eval(
548+
conv_transpose: nn.Module, unary: nn.Module, input_size: list
549+
):
550+
assert not (conv_transpose.training), "Fusion only for eval!"
551+
return ConvTransposeUnary2d(
552+
conv_transpose,
553+
unary,
554+
)
555+
556+
484557
def check_node_kind(current_node, modules, node_kind):
485558
if not isinstance(current_node, torch.fx.Node):
486559
return False
@@ -1262,6 +1335,7 @@ def rand_like(x, **kwargs):
12621335
nn.Linear: fused_linear_unary_eval,
12631336
ConvBinary2d: fused_conv_binary_unary_eval,
12641337
ConvBinaryInplace2d: fused_conv_binary_unary_eval,
1338+
nn.ConvTranspose2d: fused_conv_transpose_unary_eval,
12651339
}
12661340

12671341

0 commit comments

Comments
 (0)