diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 8a2e26ee9b94c..fdc10c9ca16a7 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3346,21 +3346,20 @@ def _prepare_convolution_fusion_create(
     function only supports the CPU device since conv post-op fusion kernel is only
     supported on CPU right now.
     """
-
-    x = cls.require_stride1(cls.realize_input(x))
-    weight = cls.require_stride1(cls.realize_input(weight))
-    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
-    inputs = [x, weight]
     stride = tuple(stride_)
     padding = tuple(padding_)
     dilation = tuple(dilation_)
     assert isinstance(groups, int)
-    with FakeTensorMode():
-        output, *_ = cls.process_kernel(
-            torch.ops.aten.convolution,
-            x,
-            weight,
-            bias,
+    with torch._subclasses.FakeTensorMode():
+        x_fake = ir_node_to_tensor(x, guard_shape=True)
+        weight_fake = ir_node_to_tensor(weight, guard_shape=True)
+        bias_fake = (
+            ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
+        )
+        output = torch.ops.aten.convolution(
+            x_fake,
+            weight_fake,
+            bias_fake,
             stride,
             padding,
             dilation,
@@ -3368,29 +3367,18 @@ def _prepare_convolution_fusion_create(
             [0, 0],
             groups,
         )
+        req_stride_order = get_stride_order(output.stride())
 
-    output_size = output.shape
-    weight_shape = [
-        sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size()
-    ]
-    _, _, *kernel_size = weight_shape
-    output_layout_str = (
-        "torch.contiguous_format" if output.is_contiguous() else "torch.channels_last"
-    )
-
-    if output_layout_str == "torch.channels_last":
-        stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1)))
-        if len(stride_order) < len(output_size):
-            # add batch dim if it exists
-            stride_order = [len(stride_order)] + stride_order
-    else:
-        stride_order = list(reversed(range(len(output_size))))
+    x = cls.require_stride_order(x, req_stride_order)
+    weight = cls.require_stride1(cls.realize_input(weight))
+    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
+    inputs = [x, weight]
 
-    kernel_layout = FlexibleLayout(
-        device=inputs[0].get_device(),
-        dtype=inputs[0].get_dtype(),
-        size=output_size,
-        stride_order=stride_order,
+    kernel_layout = FixedLayout(
+        x.get_device(),
+        x.get_dtype(),
+        output.size(),
+        output.stride(),
     )
 
     constant_args = [padding, stride, dilation, groups]
@@ -3398,7 +3386,7 @@ def _prepare_convolution_fusion_create(
         inputs.append(bias)
     else:
         constant_args.insert(0, bias)
-    return inputs, constant_args, kernel_layout
+    return inputs, constant_args, kernel_layout, req_stride_order
 
 
 class ConvolutionUnary(ExternKernelAlloc):
@@ -3436,7 +3424,7 @@ def create(
         algorithm,
     ):
         kernel = "torch.ops.mkldnn._convolution_pointwise"
-        (inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create(
+        (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
             cls, x, weight, bias, padding_, stride_, dilation_, groups
         )
         constant_args = constant_args + [attr, scalars, algorithm]
@@ -3447,13 +3435,6 @@ def create(
             kernel=kernel,
         )
 
-    def apply_constraint(self):
-        x = self.inputs[0]
-        # FixedLayout of input
-        x = self.require_stride_order(x, self.layout.preferred_stride_order)
-        self.inputs[0] = x
-        self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)
-
 
 class ConvolutionBinary(ExternKernelAlloc):
     kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
@@ -3493,10 +3474,15 @@ def create(
         unary_algorithm: Optional[str],
     ):
         kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
-        (inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create(
+        (
+            inputs,
+            constant_args,
+            kernel_layout,
+            req_stride_order,
+        ) = _prepare_convolution_fusion_create(
             cls, x, weight, bias, padding_, stride_, dilation_, groups
         )
-        other = cls.require_stride1(cls.realize_input(other))
+        other = cls.require_stride_order(other, req_stride_order)
         inputs.insert(1, other)
         constant_args = constant_args + [
             binary_attr,
@@ -3512,17 +3498,6 @@ def create(
             kernel=kernel,
         )
 
-    def apply_constraint(self):
-        x = self.inputs[0]
-        # FixedLayout of input
-        x = self.require_stride_order(x, self.layout.preferred_stride_order)
-        self.inputs[0] = x
-        other = self.inputs[1]
-        # FixedLayout of other
-        other = self.require_stride_order(other, self.layout.preferred_stride_order)
-        self.inputs[1] = other
-        self.freeze_layout_with_stride_order(self.layout.preferred_stride_order)
-
 
 class ConvolutionBinaryInplace(ExternKernelAlloc):
     kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
@@ -3530,14 +3505,12 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
     def __init__(
         self,
         kernel_layout,
-        inputs_layout,
         inputs,
         constant_args=(),
         kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
     ):
        super().__init__(kernel_layout, inputs, constant_args)
         self.kernel = kernel
-        self.inputs_layout = inputs_layout
 
     def codegen(self, wrapper):
         wrapper.writeline(
@@ -3566,7 +3539,7 @@ def create(
         unary_algorithm: Optional[str],
     ):
         kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
-        (inputs, constant_args, inputs_layout,) = _prepare_convolution_fusion_create(
+        (inputs, constant_args, _, _) = _prepare_convolution_fusion_create(
             cls, x, weight, bias, padding_, stride_, dilation_, groups
         )
         other = cls.realize_input(other)
@@ -3581,19 +3554,11 @@ def create(
         ]
         return ConvolutionBinaryInplace(
             kernel_layout=MutationLayout(inputs[1]),
-            inputs_layout=inputs_layout,
             inputs=inputs,
             constant_args=constant_args,
             kernel=kernel,
         )
 
-    def apply_constraint(self):
-        x = self.inputs[0]
-        # FixedLayout of input
-        x = self.require_stride_order(x, self.inputs_layout.preferred_stride_order)
-        self.inputs[0] = x
-        self.freeze_layout_with_stride_order(self.inputs_layout.preferred_stride_order)
-
 
 class LinearUnary(ExternKernelAlloc):
     kernel = "torch.ops.mkldnn._linear_pointwise"