TorchDynamo: Add convolution unary fusion for cpu in inference mode (p…
XiaobingSuper authored and kulinseth committed Dec 9, 2022
1 parent 1ff7e31 commit d3c3665
Showing 5 changed files with 389 additions and 0 deletions.
60 changes: 60 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -3,6 +3,7 @@
import dataclasses
import functools
import importlib
import itertools
import os
import random
import sys
@@ -1292,6 +1293,65 @@ def fn(a, b):
check_lowp=False,
)

# For the GPU path there is an accuracy issue;
# see https://github.com/pytorch/pytorch/issues/87745.
@unittest.skipIf(HAS_CUDA, "only supports the CPU conv2d unary test")
def test_conv2d_unary(self):
def _unary_list():
unary_list = [
torch.nn.ReLU(),
torch.nn.Sigmoid(),
torch.nn.Tanh(),
torch.nn.Hardswish(),
torch.nn.LeakyReLU(0.1, inplace=False),
torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False),
torch.nn.GELU(approximate="none"),
torch.nn.GELU(approximate="tanh"),
]
return unary_list

test_memory_format = [torch.contiguous_format, torch.channels_last]
options = itertools.product(
_unary_list(),
[True, False],
[1, 3],
[1, 2],
[1, 4],
test_memory_format,
)

for (
unary_fn,
bias,
kernel_size,
dilation,
groups,
memory_format,
) in options:
oC = 32 * groups
iC = 3 * groups
x_shape = (1, iC, 112, 112)
mod = torch.nn.Sequential(
torch.nn.Conv2d(
iC,
oC,
kernel_size=kernel_size,
dilation=dilation,
groups=groups,
bias=bias,
),
unary_fn,
).eval()

# TODO: add bf16 test for cpu path?
v = torch.randn(x_shape, dtype=torch.float32).to(
memory_format=memory_format
)
self.common(
mod,
(v,),
)

def test_gather1(self):
def fn(a, b):
return (
1 change: 1 addition & 0 deletions torch/_inductor/compile_fx.py
@@ -340,6 +340,7 @@ def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]
with overrides.patch_functions():
model_ = normalize_ir(model_, example_inputs_)
model_ = overrides.replace_fx(model_)
model_ = overrides.fuse_fx(model_, example_inputs_)
num_example_inputs = len(example_inputs_)
cudagraphs = BoxedBool(config.triton.cudagraphs and not config.dynamic_shapes)
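
The one functional change here wires a new fuse_fx pass into compile_fx, after IR normalization and the replace_fx overrides. For orientation only, a conv + unary fusion pass of this kind typically pattern-matches a Conv2d module whose single consumer is a supported unary module and swaps the pair for a fused module. The sketch below is a hypothetical minimal shape of such a pass; the function name and structure are assumptions, not the commit's actual overrides.fuse_fx:

import torch
import torch.fx

def fuse_conv_unary_sketch(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        # match a Conv2d call whose only consumer is a supported unary module
        if node.op != "call_module" or not isinstance(
            modules.get(node.target), torch.nn.Conv2d
        ):
            continue
        if len(node.users) != 1:
            continue
        user = next(iter(node.users))
        if user.op == "call_module" and isinstance(
            modules.get(user.target), torch.nn.ReLU
        ):
            # a real pass would replace the pair with a fused module whose
            # forward calls torch.ops.mkldnn._convolution_pointwise(..., attr="relu")
            pass
    gm.graph.lint()
    return gm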

146 changes: 146 additions & 0 deletions torch/_inductor/ir.py
@@ -3295,6 +3295,152 @@ def get_template_tiling(self):
)


def _prepare_convolution_fusion_create(
cls,
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
groups: int,
):
"""
This function is a helper function to prepare inputs, layout and constant args
for convolution post-op fusion's create function, including deciding the output
layout (channels first or channels last), realizing inputs and make them etc. The
function only supports the CPU device since conv post-op fusion kernel is only
supported on CPU right now.
"""

x = cls.require_stride1(cls.realize_input(x))
weight = cls.require_stride1(cls.realize_input(weight))
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
inputs = [x, weight]
stride = tuple(stride_)
padding = tuple(padding_)
dilation = tuple(dilation_)
assert isinstance(groups, int)

weight_shape = [
sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size()
]

out_channels, in_channels1, *kernel_size = weight_shape
in_channels1 = in_channels1 * groups
assert len(x.get_size()) == 2 + len(kernel_size)
batch, in_channels2, *input_size = x.get_size()
output_size = [batch]
V.graph.sizevars.guard_equals(in_channels1, in_channels2)

output_size.append(out_channels)
assert (
len(stride)
== len(padding)
== len(dilation)
== len(kernel_size)
== len(input_size)
)
for i in range(len(stride)):
output_size.append(
IndexingDiv(
input_size[i]
+ 2 * padding[i]
- dilation[i] * (kernel_size[i] - 1)
- 1
+ stride[i],
stride[i],
)
)
output_size[-1] = sympy.Integer(
V.graph.sizevars.guard_static_shape(output_size[-1])
)

output_layout_str = "torch.contiguous_format"
# If x or weight is in a channels_last (2d or 3d) format, take the channels_last
# path, which aligns with the aten.convolution path (CPU only supports the 2d
# case now).
# TODO: once CPU 3d convolution supports channels_last, this size check can be removed.
if len(x.get_size()) == 4 and (
x.get_layout().is_channels_last_stride_ordered()
or weight.get_layout().is_channels_last_stride_ordered()
):
output_layout_str = "torch.channels_last"

if output_layout_str == "torch.channels_last":
stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1)))
if len(stride_order) < len(output_size):
# add batch dim if it exists
stride_order = [len(stride_order)] + stride_order
else:
stride_order = list(reversed(range(len(output_size))))
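# e.g. a 4-d channels_last output gets stride_order [3, 0, 2, 1] (channels
# take the smallest stride); the contiguous branch yields [3, 2, 1, 0].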

kernel_layout = FlexibleLayout(
device=inputs[0].get_device(),
dtype=inputs[0].get_dtype(),
size=output_size,
stride_order=stride_order,
)
constant_args = [padding, stride, dilation, groups]

if bias is not None:
inputs.append(bias)
else:
constant_args.insert(0, bias)
return inputs, constant_args, kernel_layout
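
The IndexingDiv above is the standard convolution output-size formula, out = (in + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1, with the trailing + 1 folded in as the + stride term. A quick plain-Python sanity check of that arithmetic (illustrative, not part of the commit), using the test's 112x112 input:

def conv_out_size(in_size, pad, dilation, kernel, stride):
    # same expression as the IndexingDiv above
    return (in_size + 2 * pad - dilation * (kernel - 1) - 1 + stride) // stride

assert conv_out_size(112, 0, 1, 3, 1) == 110  # 3x3 kernel, stride 1, no padding
assert conv_out_size(112, 0, 2, 3, 1) == 108  # dilation 2 widens the effective kernel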


class ConvolutionUnary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._convolution_pointwise"

def __init__(
self,
layout,
inputs,
constant_args=(),
kernel="torch.ops.mkldnn._convolution_pointwise",
):
super().__init__(layout, inputs, constant_args)
self.kernel = kernel

def codegen(self, wrapper):
wrapper.writeline(
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
)

@classmethod
def create(
cls,
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
groups: int,
attr,
scalars,
algorithm,
):
kernel = "torch.ops.mkldnn._convolution_pointwise"
(inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
)
constant_args = constant_args + [attr, scalars, algorithm]
return ConvolutionUnary(
layout=kernel_layout,
inputs=inputs,
constant_args=constant_args,
kernel=kernel,
)

def apply_constraint(self):
x = self.inputs[0]
# FixedLayout of input
x = self.require_stride_order(x, self.layout.preferred_stride_order)
self.inputs[0] = x
self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)


@dataclasses.dataclass
class MutableBox(IRNode):
"""
38 changes: 38 additions & 0 deletions torch/_inductor/lowering.py
@@ -886,6 +886,44 @@ def bmm(a: TensorBox, b: TensorBox):
return TensorBox.create(ir.BatchMatrixMultiply.create(a, b))


def register_onednn_fusion_ops():
if torch._C.has_mkldnn:

@register_lowering(torch.ops.mkldnn._convolution_pointwise)
def convolution_unary(
x: TensorBox,
weight: TensorBox,
bias: TensorBox,
padding,
stride,
dilation,
groups,
attr,
scalars,
algorithm,
):
return TensorBox.create(
ir.ConvolutionUnary.create(
x,
weight,
bias,
padding,
stride,
dilation,
groups,
attr,
scalars,
algorithm,
)
)

else:
pass


register_onednn_fusion_ops()
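
For reference, a minimal sketch of the call this lowering ultimately emits. It runs only on an MKLDNN-enabled CPU build; the argument order mirrors the lowering above, and the post-op values ("relu", [], None) are illustrative:

import torch

x = torch.randn(1, 3, 112, 112)
w = torch.randn(32, 3, 3, 3)
b = torch.randn(32)
# fused conv2d + relu in a single oneDNN kernel call
y = torch.ops.mkldnn._convolution_pointwise(
    x, w, b,
    [0, 0],  # padding
    [1, 1],  # stride
    [1, 1],  # dilation
    1,       # groups
    "relu",  # attr: name of the unary post-op
    [],      # scalars, e.g. LeakyReLU's negative slope would go here
    None,    # algorithm, e.g. "tanh" for GELU
)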


def fallback_handler(kernel):
fallbacks.add(kernel)

(diff for the fifth changed file did not load)