Skip to content

Commit 62fe70a

Browse files
[MLIR][TORCH] Add E2E support for max_pool2d_with_indices op
This commit adds lowering of `max_pool2d_with_indices` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
1 parent d46f169 commit 62fe70a

File tree

13 files changed

+590
-149
lines changed

13 files changed

+590
-149
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2938,6 +2938,35 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
29382938
}];
29392939
}
29402940

2941+
def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
2942+
AllowsTypeRefinement,
2943+
HasValueSemantics,
2944+
ReadOnly
2945+
]> {
2946+
let summary = "Generated op for `aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
2947+
let arguments = (ins
2948+
AnyTorchTensorType:$self,
2949+
ListOfTorchIntType:$kernel_size,
2950+
ListOfTorchIntType:$stride,
2951+
ListOfTorchIntType:$padding,
2952+
ListOfTorchIntType:$dilation,
2953+
Torch_BoolType:$ceil_mode
2954+
);
2955+
let results = (outs
2956+
AnyTorchTensorType:$result0,
2957+
AnyTorchTensorType:$result1
2958+
);
2959+
let hasCustomAssemblyFormat = 1;
2960+
let extraClassDefinition = [{
2961+
ParseResult AtenMaxPool2dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
2962+
return parseDefaultTorchOp(parser, result, 6, 2);
2963+
}
2964+
void AtenMaxPool2dWithIndicesOp::print(OpAsmPrinter &printer) {
2965+
printDefaultTorchOp(printer, *this, 6, 2);
2966+
}
2967+
}];
2968+
}
2969+
29412970
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
29422971
AllowsTypeRefinement,
29432972
HasValueSemantics,

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
513513
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
514514
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
515515
paddingInts.end());
516-
Value paddedInput = torch_to_linalg::getPaddedTensor(op, rewriter, input,
517-
paddingIncludingNC);
516+
Value paddedInput = torch_to_linalg::getZeroPaddedTensor(
517+
op, rewriter, input, paddingIncludingNC);
518518

519519
SmallVector<Value> paddingIntValues =
520520
getAsConstantIntValues(rewriter, loc, paddingInts);

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 281 additions & 76 deletions
Large diffs are not rendered by default.

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ Value torch_to_linalg::getPaddedTensor(
5555
// Helper function to get the padding tensor given the padding int values.
5656
// It's assumed that the padding on the low end and high end are the same,
5757
// and that zero padding is required.
58-
Value torch_to_linalg::getPaddedTensor(Operation *op, OpBuilder &b,
59-
Value &input,
60-
SmallVectorImpl<int64_t> &paddingInts) {
58+
Value torch_to_linalg::getZeroPaddedTensor(
59+
Operation *op, OpBuilder &b, Value &input,
60+
SmallVectorImpl<int64_t> &paddingInts) {
6161
assert(input.getType().isa<RankedTensorType>() &&
6262
"input must be RankedTensorType");
6363
Location loc = op->getLoc();

lib/Conversion/TorchToLinalg/Utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
2121
// Helper function to get the padding tensor given the padding int values.
2222
// It's assumed that the padding on the low end and high end are the same,
2323
// and that zero padding is required.
24-
Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
25-
SmallVectorImpl<int64_t> &paddingInts);
24+
Value getZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input,
25+
SmallVectorImpl<int64_t> &paddingInts);
2626

2727
// Helper function to caculate the output tensor dims for convolution-like ops.
2828
// Along each dim:

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,21 @@ ChangeResult TypeAnalyzer::visitOperation(
669669
return changed;
670670
}
671671

672+
if (isa<AtenMaxPool2dWithIndicesOp>(op)) {
673+
auto self = operands[0]->getValue();
674+
auto result0Knowledge =
675+
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
676+
result0Knowledge.dtype = self.dtype;
677+
auto result1Knowledge =
678+
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
679+
result1Knowledge.dtype =
680+
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
681+
;
682+
auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
683+
changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);
684+
return changed;
685+
}
686+
672687
if (auto arange = dyn_cast<AtenArangeOp>(op)) {
673688
return visitAtenArangeOp(arange);
674689
}

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,6 +1572,11 @@ module {
15721572
}
15731573
return %none : !torch.none
15741574
}
1575+
func @"__torch_mlir_shape_fn.aten.max_pool2d_with_indices"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<list<int>, list<int>> {
1576+
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>
1577+
%1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>
1578+
return %1 : !torch.tuple<list<int>, list<int>>
1579+
}
15751580
func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
15761581
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
15771582
return %0 : !torch.list<int>

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,10 @@ def aten〇resize_(self: List[int], size: List[int], memory_format: Optional[int
563563
def aten〇max_pool2d(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> List[int]:
564564
return upstream_shape_helpers.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)
565565

566+
def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[List[int], List[int]]:
567+
maxpool2d = indices = upstream_shape_helpers.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)
568+
return maxpool2d, indices
569+
566570
def aten〇adaptive_avg_pool2d(self: List[int], output_size: List[int]) -> List[int]:
567571
return upstream_shape_helpers.adaptive_avg_pool2d(self, output_size)
568572

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ def emit_with_mutating_variants(key, **kwargs):
325325
emit(
326326
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
327327
)
328+
emit(
329+
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
330+
)
328331
emit(
329332
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
330333
)

python/torch_mlir_e2e_test/test_suite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ def register_all_tests():
3535
from . import rng
3636
from . import cast
3737
from . import index_put
38+
from . import pooling

0 commit comments

Comments
 (0)