Skip to content

Commit 51d6140

Browse files
[MLIR][TORCH] Add E2E support for max_pool2d_with_indices op
This commit adds the lowering of the `max_pool2d_with_indices` op. Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
1 parent 51d4d55 commit 51d6140

File tree

7 files changed

+314
-84
lines changed

7 files changed

+314
-84
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2876,6 +2876,35 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
28762876
}];
28772877
}
28782878

2879+
// ODS definition of the `aten.max_pool2d_with_indices` op. It takes six
// operands (input tensor, kernel_size, stride, padding, dilation, ceil_mode)
// and produces two tensor results. Parsing and printing are delegated to
// parseDefaultTorchOp / printDefaultTorchOp with the operand and result
// counts (6, 2). The summary string marks this as a generated op, so edits
// presumably belong in the generator rather than here.
def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
2880+
AllowsTypeRefinement,
2881+
HasValueSemantics,
2882+
ReadOnly
2883+
]> {
2884+
let summary = "Generated op for `aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
2885+
let arguments = (ins
2886+
AnyTorchTensorType:$self,
2887+
TorchIntListType:$kernel_size,
2888+
TorchIntListType:$stride,
2889+
TorchIntListType:$padding,
2890+
TorchIntListType:$dilation,
2891+
Torch_BoolType:$ceil_mode
2892+
);
2893+
let results = (outs
2894+
AnyTorchTensorType:$result0,
2895+
AnyTorchTensorType:$result1
2896+
);
2897+
let hasCustomAssemblyFormat = 1;
2898+
let extraClassDefinition = [{
2899+
ParseResult AtenMaxPool2dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
2900+
return parseDefaultTorchOp(parser, result, 6, 2);
2901+
}
2902+
void AtenMaxPool2dWithIndicesOp::print(OpAsmPrinter &printer) {
2903+
printDefaultTorchOp(printer, *this, 6, 2);
2904+
}
2905+
}];
2906+
}
2907+
28792908
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
28802909
AllowsTypeRefinement,
28812910
HasValueSemantics,

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 194 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,91 @@ using namespace mlir;
2727
using namespace mlir::torch;
2828
using namespace mlir::torch::Torch;
2929

30+
// Computes maxpool2d for AtenMaxPool2dOp and AtenMaxPool2dWithIndicesOp.
31+
template <typename OpTy>
32+
static LogicalResult computeMaxPool2d(OpTy op,
33+
ConversionPatternRewriter &rewriter,
34+
Value self, Value &result) {
35+
Location loc = op.getLoc();
36+
Value ceilMode = op.ceil_mode();
37+
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
38+
if (!elementType.isa<mlir::FloatType>())
39+
return op.emitError("unimplemented: non-floating point type");
40+
41+
// Pattern match against the op's original operands, because otherwise we
42+
// will get the lowered version of the operands which is harder to pattern
43+
// match.
44+
SmallVector<int64_t, 2> strideInts;
45+
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
46+
return rewriter.notifyMatchFailure(op, "only support constant int strides");
47+
SmallVector<int64_t, 2> dilationInts;
48+
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
49+
return rewriter.notifyMatchFailure(op,
50+
"only support constant int dilations");
51+
SmallVector<int64_t, 2> paddingInts;
52+
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
53+
return rewriter.notifyMatchFailure(op,
54+
"only support constant int paddings");
55+
SmallVector<int64_t, 2> kernelSizeInts;
56+
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
57+
return rewriter.notifyMatchFailure(op, "only support kernel size ints");
58+
bool ceilModeFalse = false;
59+
if (!matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilModeFalse)))
60+
return rewriter.notifyMatchFailure(op, "only ceil_mode false is supported");
61+
62+
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
63+
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
64+
paddingInts.end());
65+
Value paddedInput =
66+
torch_to_linalg::getPaddedTensor(op, rewriter, self, paddingIncludingNC);
67+
68+
Value N = getDimOp(rewriter, loc, self, 0);
69+
Value C = getDimOp(rewriter, loc, self, 1);
70+
Value H = getDimOp(rewriter, loc, self, 2);
71+
Value W = getDimOp(rewriter, loc, self, 3);
72+
73+
SmallVector<Value> paddingIntValues =
74+
getAsConstantIntValues(rewriter, loc, paddingInts);
75+
SmallVector<Value> dilationIntValues =
76+
getAsConstantIntValues(rewriter, loc, dilationInts);
77+
SmallVector<Value> kernelSizeIntValues =
78+
getAsConstantIntValues(rewriter, loc, kernelSizeInts);
79+
SmallVector<Value> strideIntValues =
80+
getAsConstantIntValues(rewriter, loc, strideInts);
81+
82+
Value Hout = torch_to_linalg::getOutputDimForConvOps(
83+
rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
84+
kernelSizeIntValues[0], strideIntValues[0]);
85+
Value Wout = torch_to_linalg::getOutputDimForConvOps(
86+
rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
87+
kernelSizeIntValues[1], strideIntValues[1]);
88+
89+
// Initialize output tensor with smallest floating point value
90+
Value outTensor = rewriter.create<linalg::InitTensorOp>(
91+
loc, ValueRange{N, C, Hout, Wout}, elementType);
92+
auto initialAttr = rewriter.getFloatAttr(
93+
elementType, APFloat::getSmallest(
94+
elementType.cast<mlir::FloatType>().getFloatSemantics(),
95+
/*Negative*/ true));
96+
Value initValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
97+
Value outTensorInitialized =
98+
rewriter.create<linalg::FillOp>(loc, initValue, outTensor).getResult(0);
99+
100+
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
101+
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
102+
Value windowTensor = rewriter.create<linalg::InitTensorOp>(
103+
loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
104+
elementType);
105+
106+
result = rewriter
107+
.create<linalg::PoolingNchwMaxOp>(
108+
loc, outTensorInitialized.getType(),
109+
ValueRange{paddedInput, windowTensor}, outTensorInitialized,
110+
stridesAttr, dilationAttr)
111+
.getResult(0);
112+
return success();
113+
}
114+
30115
namespace {
31116
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
32117
public:
@@ -36,94 +121,117 @@ class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
36121
ConversionPatternRewriter &rewriter) const override {
37122
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
38123
return failure();
124+
125+
Value self = adaptor.self();
126+
Value maxPool2d;
127+
if (failed(
128+
computeMaxPool2d<AtenMaxPool2dOp>(op, rewriter, self, maxPool2d)))
129+
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
130+
Type newResultType = getTypeConverter()->convertType(op.getType());
131+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
132+
return success();
133+
}
134+
};
135+
} // namespace
136+
137+
namespace {
138+
// Returns the result of maxpool2d over the input tensor. And the corresponding
139+
// indices of the input tensor for the values of the result tensor.
140+
//
141+
// The result of the maxpool2d operation is calculated using the helper function
142+
// written above. For finding the indices, we follow the below method:
143+
//
144+
// Let's say the input tensor is a 4-d tensor. The maxpool2d and indices will
145+
// also be a 4-d tensor. Then:
146+
// for i in input.size[0]:
147+
// for j in input.size[1]:
148+
// for k in input.size[2]:
149+
// for l in input.size[3]:
150+
// for p in maxpool2d.size[1]:
151+
// for q in maxpool2d.size[2]:
152+
// for r in maxpool2d.size[3]:
153+
// if input[i, j, k, l] == maxpool2d[i, p, q, r]:
154+
// indices[i, p, q, r] = (k * input.size[3] + l)
155+
//
156+
class ConvertAtenMaxPool2dWithIndicesOp
157+
: public OpConversionPattern<AtenMaxPool2dWithIndicesOp> {
158+
public:
159+
using OpConversionPattern::OpConversionPattern;
160+
LogicalResult
161+
matchAndRewrite(AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
162+
ConversionPatternRewriter &rewriter) const override {
163+
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
164+
return failure();
39165
Location loc = op->getLoc();
40166
Value self = adaptor.self();
41-
Value ceilMode = adaptor.ceil_mode();
42167

43-
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
44-
if (!elementType.isa<mlir::FloatType>())
45-
return op.emitError("unimplemented: non-floating point type");
168+
// Contains the result of maxpool2d operation over the input.
169+
Value maxPool2d;
170+
if (failed(computeMaxPool2d<AtenMaxPool2dWithIndicesOp>(op, rewriter, self,
171+
maxPool2d)))
172+
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
46173

47-
// Pattern match against the op's original operands, because otherwise we
48-
// will get the lowered version of the operands which is harder to pattern
49-
// match.
50-
SmallVector<int64_t, 2> strideInts;
51-
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
52-
return rewriter.notifyMatchFailure(op,
53-
"only support constant int strides");
54-
SmallVector<int64_t, 2> dilationInts;
55-
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
56-
return rewriter.notifyMatchFailure(op,
57-
"only support constant int dilations");
58-
SmallVector<int64_t, 2> paddingInts;
59-
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
60-
return rewriter.notifyMatchFailure(op,
61-
"only support constant int paddings");
62-
SmallVector<int64_t, 2> kernelSizeInts;
63-
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
64-
return rewriter.notifyMatchFailure(op, "only support kernel size ints");
65-
66-
Value falseValue = rewriter.create<arith::ConstantOp>(
67-
loc, IntegerAttr::get(rewriter.getIntegerType(1), 0));
68-
Value ceilModeFalse = rewriter.create<arith::CmpIOp>(
69-
loc, arith::CmpIPredicate::eq, ceilMode, falseValue);
70-
rewriter.create<cf::AssertOp>(
71-
loc, ceilModeFalse,
72-
rewriter.getStringAttr("only ceil_mode false is supported"));
73-
74-
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
75-
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
76-
paddingInts.end());
77-
Value paddedInput = torch_to_linalg::getPaddedTensor(op, rewriter, self,
78-
paddingIncludingNC);
79-
80-
Value N = getDimOp(rewriter, loc, self, 0);
81-
Value C = getDimOp(rewriter, loc, self, 1);
82-
Value H = getDimOp(rewriter, loc, self, 2);
83-
Value W = getDimOp(rewriter, loc, self, 3);
84-
85-
SmallVector<Value> paddingIntValues =
86-
getAsConstantIntValues(rewriter, loc, paddingInts);
87-
SmallVector<Value> dilationIntValues =
88-
getAsConstantIntValues(rewriter, loc, dilationInts);
89-
SmallVector<Value> kernelSizeIntValues =
90-
getAsConstantIntValues(rewriter, loc, kernelSizeInts);
91-
SmallVector<Value> strideIntValues =
92-
getAsConstantIntValues(rewriter, loc, strideInts);
93-
94-
Value Hout = torch_to_linalg::getOutputDimForConvOps(
95-
rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
96-
kernelSizeIntValues[0], strideIntValues[0]);
97-
Value Wout = torch_to_linalg::getOutputDimForConvOps(
98-
rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
99-
kernelSizeIntValues[1], strideIntValues[1]);
100-
101-
// Initialize output tensor with smallest floating point value
102-
Value outTensor = rewriter.create<linalg::InitTensorOp>(
103-
loc, ValueRange{N, C, Hout, Wout}, elementType);
104-
auto initialAttr = rewriter.getFloatAttr(
105-
elementType,
106-
APFloat::getSmallest(
107-
elementType.cast<mlir::FloatType>().getFloatSemantics(),
108-
/*Negative*/ true));
109-
Value initValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
110-
Value outTensorInitialized =
111-
rewriter.create<linalg::FillOp>(loc, initValue, outTensor).getResult(0);
112-
113-
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
114-
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
115-
Value windowTensor = rewriter.create<linalg::InitTensorOp>(
116-
loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
117-
elementType);
118-
119-
Value maxPool2d = rewriter
120-
.create<linalg::PoolingNchwMaxOp>(
121-
loc, outTensorInitialized.getType(),
122-
ValueRange{paddedInput, windowTensor},
123-
outTensorInitialized, stridesAttr, dilationAttr)
124-
.getResult(0);
125-
Type newResultType = getTypeConverter()->convertType(op.getType());
126-
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
174+
RankedTensorType resultType = getTypeConverter()
175+
->convertType(op->getResult(0).getType())
176+
.cast<RankedTensorType>();
177+
RankedTensorType indicesType = getTypeConverter()
178+
->convertType(op->getResult(1).getType())
179+
.cast<RankedTensorType>();
180+
unsigned resultRank = resultType.getRank();
181+
SmallVector<Value> inputShape(getTensorSizes(rewriter, loc, self));
182+
SmallVector<Value> resultShape(getTensorSizes(rewriter, loc, maxPool2d));
183+
184+
Value indicesTensor = createZeroInitTensor(rewriter, loc, resultShape,
185+
indicesType.getElementType());
186+
187+
SmallVector<AffineExpr> inputExprs, maxPoolExprs, indicesExprs;
188+
SmallVector<StringRef> iteratorTypes(2 * resultRank - 1,
189+
getParallelIteratorTypeName());
190+
AffineExpr zeroDimExpr = rewriter.getAffineDimExpr(0);
191+
inputExprs.push_back(zeroDimExpr);
192+
maxPoolExprs.push_back(zeroDimExpr);
193+
indicesExprs.push_back(zeroDimExpr);
194+
195+
for (unsigned i = 1; i < resultRank; i++) {
196+
inputExprs.push_back(rewriter.getAffineDimExpr(i));
197+
maxPoolExprs.push_back(rewriter.getAffineDimExpr(i + resultRank - 1));
198+
indicesExprs.push_back(rewriter.getAffineDimExpr(i + resultRank - 1));
199+
}
200+
201+
auto indexingMaps =
202+
AffineMap::inferFromExprList({inputExprs, maxPoolExprs, indicesExprs});
203+
204+
auto indicesResult =
205+
rewriter
206+
.create<linalg::GenericOp>(
207+
loc, /*resultTensorTypes=*/indicesTensor.getType(),
208+
/*inputs=*/ValueRange({self, maxPool2d}),
209+
/*outputs=*/indicesTensor,
210+
/*indexingMaps=*/indexingMaps,
211+
/*iteratorTypes=*/iteratorTypes,
212+
[&](OpBuilder &b, Location loc, ValueRange args) {
213+
Value out = args[2];
214+
Value index = b.create<linalg::IndexOp>(loc, resultRank - 2);
215+
index = b.create<arith::MulIOp>(
216+
loc, index, inputShape[resultRank - 2 + 1]);
217+
index = b.create<arith::AddIOp>(
218+
loc, index,
219+
b.create<linalg::IndexOp>(loc, resultRank - 1));
220+
index = castIndexToInt(b, loc, index);
221+
Value predicate;
222+
if (resultType.getElementType().isa<mlir::FloatType>())
223+
predicate = b.create<arith::CmpFOp>(
224+
loc, arith::CmpFPredicate::OEQ, args[0], args[1]);
225+
else
226+
predicate = b.create<arith::CmpIOp>(
227+
loc, arith::CmpIPredicate::eq, args[0], args[1]);
228+
229+
Value result =
230+
b.create<arith::SelectOp>(loc, predicate, index, out);
231+
b.create<linalg::YieldOp>(loc, result);
232+
})
233+
.getResult(0);
234+
rewriter.replaceOp(op, {maxPool2d, indicesResult});
127235
return success();
128236
}
129237
};
@@ -234,6 +342,8 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
234342
MLIRContext *context = patterns.getContext();
235343
target.addIllegalOp<AtenMaxPool2dOp>();
236344
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
345+
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
346+
patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
237347
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
238348
patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
239349
}

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,21 @@ ChangeResult TypeAnalyzer::visitOperation(
668668
return changed;
669669
}
670670

671+
if (isa<AtenMaxPool2dWithIndicesOp>(op)) {
672+
auto self = operands[0]->getValue();
673+
auto result0Knowledge =
674+
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
675+
result0Knowledge.dtype = self.dtype;
676+
auto result1Knowledge =
677+
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
678+
result1Knowledge.dtype =
679+
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
680+
;
681+
auto changed = incorporateKnowledge(op->getResult(0), result0Knowledge);
682+
changed |= incorporateKnowledge(op->getResult(1), result1Knowledge);
683+
return changed;
684+
}
685+
671686
if (auto arange = dyn_cast<AtenArangeOp>(op)) {
672687
return visitAtenArangeOp(arange);
673688
}

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,6 +1572,11 @@ module {
15721572
}
15731573
return %none : !torch.none
15741574
}
1575+
func @"__torch_mlir_shape_fn.aten.max_pool2d_with_indices"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<list<int>, list<int>> {
1576+
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>
1577+
%1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>
1578+
return %1 : !torch.tuple<list<int>, list<int>>
1579+
}
15751580
func @"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
15761581
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
15771582
return %0 : !torch.list<int>

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,10 @@ def aten〇resize_(self: List[int], size: List[int], memory_format: Optional[int
563563
def aten〇max_pool2d(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> List[int]:
564564
return upstream_shape_helpers.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)
565565

566+
def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[List[int], List[int]]:
    # Both results (pooled values and their indices) have the same shape, so
    # the max_pool2d shape helper is evaluated once and returned twice.
    pooled_shape = upstream_shape_helpers.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)
    return pooled_shape, pooled_shape
569+
566570
def aten〇adaptive_avg_pool2d(self: List[int], output_size: List[int]) -> List[int]:
567571
return upstream_shape_helpers.adaptive_avg_pool2d(self, output_size)
568572

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,9 @@ def emit_with_mutating_variants(key, **kwargs):
322322
emit(
323323
"aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
324324
)
325+
emit(
326+
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
327+
)
325328
emit(
326329
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
327330
)

0 commit comments

Comments
 (0)