Skip to content

Commit 769f3a8

Browse files
[MLIR][TORCH] Add E2E support for max_pool2d_with_indices op
This commit adds lowering of `max_pool2d_with_indices` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
1 parent d3c0837 commit 769f3a8

File tree

16 files changed

+836
-294
lines changed

16 files changed

+836
-294
lines changed

e2e_testing/torchscript/xfail_sets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
"MobilenetV3Module_basic",
2121
"ConvolutionModule3D_basic",
2222
"ConvolutionModule1D_basic",
23+
"MaxPool2dWith3dInputModule_basic",
24+
"MaxPool2dWithIndicesWith3dInputModule_basic",
2325
}
2426
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
2527

include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
3838
Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
3939
Type elemTy, Value initElem);
4040

41+
Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
42+
Type elemTy);
43+
4144
Value castIntToIndex(OpBuilder &b, Location loc, Value v);
4245

4346
Value castIndexToInt(OpBuilder &b, Location loc, Value idx);
@@ -51,9 +54,6 @@ SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor);
5154

5255
Value getTensorSize(OpBuilder &b, Location loc, Value tensor);
5356

54-
Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
55-
Type elemTy);
56-
5757
// Creates a constant of type `elemType` with value `val`.
5858
Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType);
5959

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2938,6 +2938,35 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
29382938
}];
29392939
}
29402940

2941+
def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
2942+
AllowsTypeRefinement,
2943+
HasValueSemantics,
2944+
ReadOnly
2945+
]> {
2946+
let summary = "Generated op for `aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
2947+
let arguments = (ins
2948+
AnyTorchTensorType:$self,
2949+
ListOfTorchIntType:$kernel_size,
2950+
ListOfTorchIntType:$stride,
2951+
ListOfTorchIntType:$padding,
2952+
ListOfTorchIntType:$dilation,
2953+
Torch_BoolType:$ceil_mode
2954+
);
2955+
let results = (outs
2956+
AnyTorchTensorType:$result0,
2957+
AnyTorchTensorType:$result1
2958+
);
2959+
let hasCustomAssemblyFormat = 1;
2960+
let extraClassDefinition = [{
2961+
ParseResult AtenMaxPool2dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
2962+
return parseDefaultTorchOp(parser, result, 6, 2);
2963+
}
2964+
void AtenMaxPool2dWithIndicesOp::print(OpAsmPrinter &printer) {
2965+
printDefaultTorchOp(printer, *this, 6, 2);
2966+
}
2967+
}];
2968+
}
2969+
29412970
def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [
29422971
AllowsTypeRefinement,
29432972
HasValueSemantics,

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
513513
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
514514
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
515515
paddingInts.end());
516-
Value paddedInput = torch_to_linalg::getPaddedTensor(op, rewriter, input,
517-
paddingIncludingNC);
516+
Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
517+
op, rewriter, input, paddingIncludingNC);
518518

519519
SmallVector<Value> paddingIntValues =
520520
getAsConstantIntValues(rewriter, loc, paddingInts);

0 commit comments

Comments
 (0)