Take input striding for conv fusion op based on eager output (pytorch#88864)

As in pytorch#88706, we also change the input stride check to use the eager-mode output.

Pull Request resolved: pytorch#88864
Approved by: https://github.com/jgong5, https://github.com/jansel
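
Roughly, the approach is: run the plain aten.convolution on fake tensors under FakeTensorMode and read the required stride order off the eager output, instead of hard-coding contiguous/channels-last stride-order rules. Below is a minimal standalone sketch of that pattern (illustrative only — the shapes are made up, and stride_order is a simplified stand-in for inductor's get_stride_order helper):

import torch
from torch._subclasses import FakeTensorMode

def stride_order(strides):
    # Simplified stand-in for inductor's get_stride_order: rank each dim
    # from smallest stride to largest.
    dims_by_stride = sorted(range(len(strides)), key=lambda d: strides[d])
    order = [0] * len(strides)
    for rank, dim in enumerate(dims_by_stride):
        order[dim] = rank
    return order

with FakeTensorMode():
    # Fake tensors carry shapes/strides but no data, so no real compute runs.
    x = torch.empty(1, 3, 8, 8, memory_format=torch.channels_last)
    w = torch.empty(16, 3, 3, 3)
    out = torch.ops.aten.convolution(
        x, w, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1
    )
    # The output's strides (typically channels-last here, following the input's
    # memory format) give the stride order the fused conv op is constrained to.
    print(stride_order(out.stride()))

In the diff below, this is what the new FakeTensorMode block does: the fake output's stride() feeds get_stride_order(), the resulting req_stride_order constrains x (and other for the binary op), and the kernel layout becomes a FixedLayout built from the eager output's size and strides.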
XiaobingSuper authored and pytorchmergebot committed Nov 15, 2022
1 parent 0544a32 commit 7a37bbe
Showing 1 changed file with 30 additions and 65 deletions.
torch/_inductor/ir.py: 30 additions & 65 deletions
@@ -3346,59 +3346,47 @@ def _prepare_convolution_fusion_create(
function only supports the CPU device since conv post-op fusion kernel is only
supported on CPU right now.
"""

x = cls.require_stride1(cls.realize_input(x))
weight = cls.require_stride1(cls.realize_input(weight))
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
inputs = [x, weight]
stride = tuple(stride_)
padding = tuple(padding_)
dilation = tuple(dilation_)
assert isinstance(groups, int)
with FakeTensorMode():
output, *_ = cls.process_kernel(
torch.ops.aten.convolution,
x,
weight,
bias,
with torch._subclasses.FakeTensorMode():
x_fake = ir_node_to_tensor(x, guard_shape=True)
weight_fake = ir_node_to_tensor(weight, guard_shape=True)
bias_fake = (
ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
)
output = torch.ops.aten.convolution(
x_fake,
weight_fake,
bias_fake,
stride,
padding,
dilation,
False,
[0, 0],
groups,
)
req_stride_order = get_stride_order(output.stride())

output_size = output.shape
weight_shape = [
sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size()
]
_, _, *kernel_size = weight_shape
output_layout_str = (
"torch.contiguous_format" if output.is_contiguous() else "torch.channels_last"
)

if output_layout_str == "torch.channels_last":
stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1)))
if len(stride_order) < len(output_size):
# add batch dim if it exists
stride_order = [len(stride_order)] + stride_order
else:
stride_order = list(reversed(range(len(output_size))))
x = cls.require_stride_order(x, req_stride_order)
weight = cls.require_stride1(cls.realize_input(weight))
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
inputs = [x, weight]

kernel_layout = FlexibleLayout(
device=inputs[0].get_device(),
dtype=inputs[0].get_dtype(),
size=output_size,
stride_order=stride_order,
kernel_layout = FixedLayout(
x.get_device(),
x.get_dtype(),
output.size(),
output.stride(),
)
constant_args = [padding, stride, dilation, groups]

if bias is not None:
inputs.append(bias)
else:
constant_args.insert(0, bias)
return inputs, constant_args, kernel_layout
return inputs, constant_args, kernel_layout, req_stride_order


class ConvolutionUnary(ExternKernelAlloc):
@@ -3436,7 +3424,7 @@ def create(
algorithm,
):
kernel = "torch.ops.mkldnn._convolution_pointwise"
(inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create(
(inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
)
constant_args = constant_args + [attr, scalars, algorithm]
@@ -3447,13 +3435,6 @@ def create(
kernel=kernel,
)

def apply_constraint(self):
x = self.inputs[0]
# FixedLayout of input
x = self.require_stride_order(x, self.layout.preferred_stride_order)
self.inputs[0] = x
self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)


class ConvolutionBinary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
@@ -3493,10 +3474,15 @@ def create(
unary_algorithm: Optional[str],
):
kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
(inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create(
(
inputs,
constant_args,
kernel_layout,
req_stride_order,
) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
)
other = cls.require_stride1(cls.realize_input(other))
other = cls.require_stride_order(other, req_stride_order)
inputs.insert(1, other)
constant_args = constant_args + [
binary_attr,
@@ -3512,32 +3498,19 @@ def create(
kernel=kernel,
)

def apply_constraint(self):
x = self.inputs[0]
# FixedLayout of input
x = self.require_stride_order(x, self.layout.preferred_stride_order)
self.inputs[0] = x
other = self.inputs[1]
# FixedLayout of other
other = self.require_stride_order(other, self.layout.preferred_stride_order)
self.inputs[1] = other
self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)


class ConvolutionBinaryInplace(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"

def __init__(
self,
kernel_layout,
inputs_layout,
inputs,
constant_args=(),
kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
):
super().__init__(kernel_layout, inputs, constant_args)
self.kernel = kernel
self.inputs_layout = inputs_layout

def codegen(self, wrapper):
wrapper.writeline(
@@ -3566,7 +3539,7 @@ def create(
unary_algorithm: Optional[str],
):
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
(inputs, constant_args, inputs_layout,) = _prepare_convolution_fusion_create(
(inputs, constant_args, _, _) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
)
other = cls.realize_input(other)
@@ -3581,19 +3554,11 @@ def create(
]
return ConvolutionBinaryInplace(
kernel_layout=MutationLayout(inputs[1]),
inputs_layout=inputs_layout,
inputs=inputs,
constant_args=constant_args,
kernel=kernel,
)

def apply_constraint(self):
x = self.inputs[0]
# FixedLayout of input
x = self.require_stride_order(x, self.inputs_layout.preferred_stride_order)
self.inputs[0] = x
self.freeze_layout_with_stride_order(self.inputs_layout.preferred_stride_order)


class LinearUnary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._linear_pointwise"
